This commit is contained in:
2025-09-29 15:28:09 +08:00
parent 03941068d8
commit 715033f650
5 changed files with 141 additions and 124 deletions

196
copier.go
View File

@@ -15,7 +15,7 @@ type copier struct {
mu sync.Mutex
}
func newCopier(opts ...option) *copier {
func New(opts ...option) *copier {
opt := getOpt(opts...)
var visited map[uintptr]struct{}
if opt.detectCircularRefs {
@@ -52,29 +52,30 @@ func (c *copier) Copy(dst, src any) error {
func (c *copier) reset() {
if c.opt.detectCircularRefs {
c.visited = make(map[uintptr]struct{})
clear(c.visited)
}
}
func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
if c.ExceedMaxDepth(depth) {
return nil
return ErrMaxDepthExceeded
}
// 复制源为空值,直接返回
if !src.IsValid() || src.IsZero() {
return nil
}
if c.isCircularRefs(src) {
return nil
return ErrCircularReference
}
// 可以直接赋值时直接赋值并返回
srcType, dstType := src.Type(), dst.Type()
if srcType.AssignableTo(dstType) {
dst.Set(src)
return nil
}
// if srcType.AssignableTo(dstType) {
// dst.Set(src)
// return nil
// }
// 处理时间类型
if isTimeType(srcType) || isTimeType(dstType) {
@@ -83,7 +84,7 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
switch src.Kind() {
case reflect.Slice, reflect.Array:
return c.copySlice(dst, src, depth)
return c.copySliceOrArray(dst, src, depth)
case reflect.Map:
return c.copyMap(dst, src, depth)
case reflect.Struct:
@@ -93,26 +94,26 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
case reflect.Interface:
return c.copyInterface(dst, src, depth)
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
return ErrNotSupported
return ErrNotSupported(srcType, dstType)
default:
return c.set(dst, src)
return c.copyBasic(dst, src)
}
}
func (c *copier) copySlice(dst, src reflect.Value, depth int) error {
func (c *copier) copySliceOrArray(dst, src reflect.Value, depth int) error {
switch dst.Kind() {
case reflect.Struct:
return c.copyStruct2Slice(dst, src, depth)
return c.copyStructToSlice(dst, src, depth)
case reflect.Slice, reflect.Array:
return c.copySliceToSlice(dst, src, depth)
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
}
func (c *copier) copyStruct2Slice(dst, src reflect.Value, depth int) error {
func (c *copier) copyStructToSlice(dst, src reflect.Value, depth int) error {
if dst.Kind() != reflect.Slice {
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
// 创建新的slice长度+1
@@ -152,7 +153,6 @@ func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error {
}
copyLen := min(dst.Len(), srcLen)
for i := range copyLen {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil {
return err
@@ -168,7 +168,6 @@ func (c *copier) copyToArray(dst, src reflect.Value, depth int) error {
// 取较小长度
copyLen := min(dstLen, srcLen)
for i := range copyLen {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil {
return err
@@ -178,7 +177,7 @@ func (c *copier) copyToArray(dst, src reflect.Value, depth int) error {
return nil
}
func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
func (c *copier) copyStructToMap(dst, src reflect.Value, depth int) error {
if dst.IsNil() {
dst.Set(reflect.MakeMapWithSize(dst.Type(), src.NumField()))
}
@@ -210,26 +209,23 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
continue
}
if sField.Kind() == reflect.Pointer || sField.Kind() == reflect.Struct {
var newDst reflect.Value
var copiedValue reflect.Value
if isTimeType(sField.Type()) {
newDst = reflect.New(sField.Type()).Elem()
if err := c.copyTime(newDst, sField); err != nil {
copiedValue = reflect.New(sField.Type()).Elem()
if err := c.copyTime(copiedValue, sField); err != nil {
return err
}
} else {
var newDst = reflect.ValueOf(make(map[string]any))
} else if sField.Kind() == reflect.Struct {
copiedValue = reflect.ValueOf(make(map[string]any))
sField = indirect(sField)
if err := c.deepCopy(newDst, sField, depth+1); err != nil {
if err := c.deepCopy(copiedValue, sField, depth+1); err != nil {
return err
}
} else {
copiedValue = sField
}
dst.SetMapIndex(reflect.ValueOf(name), newDst)
} else {
dst.SetMapIndex(reflect.ValueOf(name), sField)
}
dst.SetMapIndex(reflect.ValueOf(name), copiedValue)
}
return nil
@@ -256,7 +252,7 @@ func (c *copier) copyMap(dst, src reflect.Value, depth int) error {
}
return nil
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
}
@@ -266,8 +262,8 @@ func (c *copier) copyMapToMap(dst, src reflect.Value, depth int) error {
}
dstType, _ := indirectType(dst.Type())
iter := src.MapRange()
for iter.Next() {
key := iter.Key()
value := iter.Value()
@@ -285,40 +281,34 @@ func (c *copier) copyMapToMap(dst, src reflect.Value, depth int) error {
key = reflect.ValueOf(fieldName)
}
var copitedValue reflect.Value
switch dstType.Elem().Kind() {
case reflect.Interface:
switch value.Kind() {
case reflect.Interface:
copitedValue = reflect.New(value.Elem().Type()).Elem()
default:
copitedValue = reflect.New(value.Type()).Elem()
}
c.set(copitedValue, value)
default:
copitedValue = reflect.New(dstType.Elem()).Elem()
}
if err := c.deepCopy(copitedValue, value, depth+1); err != nil {
copiedValue := c.prepareMapValue(dstType.Elem(), value)
if err := c.deepCopy(copiedValue, value, depth+1); err != nil {
return err
}
dst.SetMapIndex(key, copitedValue)
dst.SetMapIndex(key, copiedValue)
}
return nil
}
func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
func (c *copier) prepareMapValue(elemType reflect.Type, value reflect.Value) reflect.Value {
if elemType.Kind() == reflect.Interface {
if value.Kind() == reflect.Interface {
return reflect.New(value.Elem().Type()).Elem()
}
return reflect.New(value.Type()).Elem()
}
return reflect.New(elemType).Elem()
}
func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
fields := c.deepFields(dst.Type())
for _, sf := range fields {
if nestedAnonymousField(dst, sf.Field.Name) {
continue
}
field := dst.FieldByName(sf.Field.Name)
field := dst.FieldByName(sf.Field.Name)
if !field.CanSet() {
continue
}
@@ -334,6 +324,7 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
if c.opt.caseSensitive {
mapValue = src.MapIndex(reflect.ValueOf(name))
} else {
// fix: 忽略大小写查询
mapValue = src.MapIndex(reflect.ValueOf(strings.ToLower(name)))
}
@@ -356,31 +347,21 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
switch dst.Kind() {
case reflect.Struct:
return c.copyStruct2Struct(dst, src, depth)
return c.copyStructToStruct(dst, src, depth)
case reflect.Map:
return c.copyStruct2Map(dst, src, depth)
return c.copyStructToMap(dst, src, depth)
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
}
func (c *copier) copyStruct2Struct(dst, src reflect.Value, depth int) error {
if dst.CanSet() {
if src.Type().AssignableTo(dst.Type()) {
func (c *copier) copyStructToStruct(dst, src reflect.Value, depth int) error {
if _, ok := src.Interface().(time.Time); ok && src.Type().AssignableTo(dst.Type()) {
dst.Set(src)
return nil
}
if _, ok := src.Interface().(time.Time); ok {
if dst.Type().AssignableTo(src.Type()) {
dst.Set(src)
return nil
}
}
}
typ := src.Type()
fields := c.deepFields(typ)
fields := c.deepFields(src.Type())
for _, sf := range fields {
name := c.getFieldName(sf.Field.Name, sf.Tag)
@@ -450,7 +431,7 @@ func (c *copier) copyPointer(dst, src reflect.Value, depth int) error {
return c.deepCopy(dst, srcElem, depth+1)
}
func (c *copier) set(dst, src reflect.Value) error {
func (c *copier) copyBasic(dst, src reflect.Value) error {
if !src.IsValid() {
return ErrInvalidCopyFrom
}
@@ -479,7 +460,7 @@ func (c *copier) set(dst, src reflect.Value) error {
case reflect.Bool:
return c.setBool(dst, src)
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
}
@@ -500,16 +481,7 @@ func (c *copier) setString(dst, src reflect.Value) error {
dst.SetString(string(src.Bytes()))
return nil
}
// 其他slice类型转换为字符串表示
dst.SetString(fmt.Sprintf("%v", src.Interface()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
dst.SetString(strconv.FormatInt(src.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
dst.SetString(strconv.FormatUint(src.Uint(), 10))
case reflect.Float32, reflect.Float64:
dst.SetString(strconv.FormatFloat(src.Float(), 'f', -1, 64))
case reflect.Bool:
dst.SetString(strconv.FormatBool(src.Bool()))
fallthrough
default:
dst.SetString(fmt.Sprintf("%v", src.Interface()))
}
@@ -525,7 +497,7 @@ func (c *copier) setInt(dst, src reflect.Value) error {
case reflect.Float32, reflect.Float64:
dst.SetInt(int64(src.Float()))
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
return nil
@@ -544,9 +516,9 @@ func (c *copier) setUint(dst, src reflect.Value) error {
dst.SetUint(i)
return nil
}
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
return nil
@@ -565,9 +537,9 @@ func (c *copier) setFloat(dst, src reflect.Value) error {
dst.SetFloat(f)
return nil
}
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
return nil
@@ -592,13 +564,13 @@ func (c *copier) setBool(dst, src reflect.Value) error {
dst.SetBool(false)
return nil
}
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
dst.SetBool(src.Int() != 0)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
dst.SetBool(src.Uint() != 0)
default:
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
return nil
@@ -699,7 +671,7 @@ func (c *copier) copyTime(dst, src reflect.Value) error {
return c.otherToTime(dst, src)
}
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
func (c *copier) timeToOther(dst reflect.Value, src time.Time) error {
@@ -724,7 +696,7 @@ func (c *copier) timeToOther(dst reflect.Value, src time.Time) error {
}
}
return ErrNotSupported
return ErrNotSupported(reflect.TypeOf(src), dst.Type())
}
func (c *copier) otherToTime(dst reflect.Value, src reflect.Value) error {
@@ -754,7 +726,7 @@ func (c *copier) otherToTime(dst reflect.Value, src reflect.Value) error {
return c.mapToTime(dst, src)
}
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
func (c *copier) getTimeFormat(dstType reflect.Type) string {
@@ -798,7 +770,7 @@ func (c *copier) stringToTime(dst reflect.Value, timeStr string) error {
}
}
return fmt.Errorf("无法解析时间字符串: %s", timeStr)
return fmt.Errorf("cannot parse time string: %s", timeStr)
}
func (c *copier) setTimeValue(dst reflect.Value, t time.Time) error {
@@ -823,10 +795,14 @@ func (c *copier) setTimeValue(dst reflect.Value, t time.Time) error {
return nil
}
// 目标类型是自定义时间类型
if setter, ok := dst.Addr().Interface().(interface{ SetTime(time.Time) }); ok {
setter.SetTime(t)
return nil
}
return ErrNotSupported(reflect.TypeOf(t), dst.Type())
}
func (c *copier) getTimeFormats(dstType reflect.Type) []string {
var formats []string
@@ -869,7 +845,11 @@ func (c *copier) getTimeFormats(dstType reflect.Type) []string {
func (c *copier) mapToTime(dst reflect.Value, src reflect.Value) error {
if src.Type().Key().Kind() != reflect.String {
return ErrNotSupported
return ErrNotSupported(src.Type(), dst.Type())
}
if src.Len() == 0 {
return nil
}
// 从map中提取时间组件
@@ -905,9 +885,12 @@ func (c *copier) mapToTime(dst reflect.Value, src reflect.Value) error {
}
}
if year == 0 {
year = time.Now().Year()
}
t := time.Date(year, time.Month(month), day, hour, min, sec, nsec, loc)
dst.Set(reflect.ValueOf(t))
return nil
return c.setTimeValue(dst, t)
}
func getIntValue(v reflect.Value) int64 {
@@ -924,6 +907,7 @@ func getIntValue(v reflect.Value) int64 {
return i
}
}
return 0
}
@@ -940,12 +924,11 @@ func nestedAnonymousField(dst reflect.Value, fieldName string) bool {
return false
}
if !destField.CanSet() {
if destField.CanSet() {
destField.Set(reflect.New(destField.Type().Elem()))
} else {
return true
}
newValue := reflect.New(destField.Type().Elem())
destField.Set(newValue)
}
}
@@ -1029,6 +1012,10 @@ func isTimeType(t reflect.Type) bool {
return true
}
if _, ok := t.MethodByName("Time"); ok {
return true
}
return false
}
@@ -1068,14 +1055,12 @@ func getTimeValueExt(v reflect.Value) reflect.Value {
if v.IsValid() && v.CanInterface() {
// 尝试通过接口转换
if converter, ok := v.Interface().(interface{ ToTime() time.Time }); ok {
timeVal := converter.ToTime()
return reflect.ValueOf(timeVal)
return reflect.ValueOf(converter.ToTime())
}
// 尝试直接转换
if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
timeVal := v.Convert(reflect.TypeOf(time.Time{}))
return timeVal
return v.Convert(reflect.TypeOf(time.Time{}))
}
}
@@ -1084,6 +1069,9 @@ func getTimeValueExt(v reflect.Value) reflect.Value {
func indirect(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface {
if v.IsNil() {
break
}
v = v.Elem()
}

View File

@@ -94,12 +94,14 @@ func TestStructToMapCopy(t *testing.T) {
type person struct {
Name string
Age int
Birthday time.Time
Address *Address
}
src := person{
Name: "John",
Age: 30,
Birthday: time.Now(),
Address: &Address{
City: "Beijing",
},
@@ -227,7 +229,10 @@ func TestMixMapToStruct(t *testing.T) {
t.Fatal(err)
}
b, _ := json.MarshalIndent(dst, "", " ")
t.Log(dst)
t.Log(string(b))
}
func BenchmarkMapCopy(b *testing.B) {

View File

@@ -294,7 +294,7 @@ func TestDeepFileds(t *testing.T) {
var emp = Employee{}
typ := reflect.TypeOf(emp)
copier := newCopier()
copier := New()
wrapper := copier.deepFields(typ)
for _, v := range wrapper {
t.Log(v)
@@ -305,11 +305,17 @@ func BenchmarkCopy(b *testing.B) {
var emp = Employee{
Name: "John",
Age: 30,
Address: &Address{
Country: "USA",
City: "New York",
},
}
for b.Loop() {
var dst Employee
Copy(&dst, &emp)
if err := Copy(&dst, &emp); err != nil {
b.Fatal(err)
}
}
}
@@ -318,7 +324,7 @@ func BenchmarkDeepFields(b *testing.B) {
typ := reflect.TypeOf(emp)
for b.Loop() {
copier := newCopier()
copier := New()
copier.deepFields(typ)
}

View File

@@ -3,7 +3,7 @@ package copier
import "reflect"
func Copy(dst, src any, opts ...option) error {
copier := newCopier(opts...)
copier := New(opts...)
return copier.Copy(dst, src)
}
@@ -15,6 +15,6 @@ func Clone(src any, opts ...option) (any, error) {
srcType := reflect.TypeOf(src)
dst := reflect.New(srcType).Interface()
err := newCopier(opts...).Copy(dst, src)
err := New(opts...).Copy(dst, src)
return dst, err
}

View File

@@ -1,9 +1,27 @@
package copier
import "errors"
import (
"errors"
"fmt"
"reflect"
)
var (
ErrInvalidCopyDestination = errors.New("copy destination must be non-nil and addressable")
ErrInvalidCopyFrom = errors.New("copy from must be non-nil and addressable")
ErrNotSupported = errors.New("not supported")
ErrCircularReference = errors.New("circular reference detected")
ErrMaxDepthExceeded = errors.New("max depth exceeded")
)
type NotSupportedError struct {
SrcType reflect.Type
DstType reflect.Type
}
func (e *NotSupportedError) Error() string {
return fmt.Sprintf("unsupported type conversion from %s to %s", e.SrcType, e.DstType)
}
func ErrNotSupported(srcType, dstType reflect.Type) error {
return &NotSupportedError{SrcType: srcType, DstType: dstType}
}