diff --git a/copier.go b/copier.go index 56ab653..76e7856 100644 --- a/copier.go +++ b/copier.go @@ -15,7 +15,7 @@ type copier struct { mu sync.Mutex } -func newCopier(opts ...option) *copier { +func New(opts ...option) *copier { opt := getOpt(opts...) var visited map[uintptr]struct{} if opt.detectCircularRefs { @@ -52,29 +52,30 @@ func (c *copier) Copy(dst, src any) error { func (c *copier) reset() { if c.opt.detectCircularRefs { - c.visited = make(map[uintptr]struct{}) + clear(c.visited) } } func (c *copier) deepCopy(dst, src reflect.Value, depth int) error { if c.ExceedMaxDepth(depth) { - return nil + return ErrMaxDepthExceeded } + // 复制源为空值,直接返回 if !src.IsValid() || src.IsZero() { return nil } if c.isCircularRefs(src) { - return nil + return ErrCircularReference } // 可以直接赋值时直接赋值并返回 srcType, dstType := src.Type(), dst.Type() - if srcType.AssignableTo(dstType) { - dst.Set(src) - return nil - } + // if srcType.AssignableTo(dstType) { + // dst.Set(src) + // return nil + // } // 处理时间类型 if isTimeType(srcType) || isTimeType(dstType) { @@ -83,7 +84,7 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error { switch src.Kind() { case reflect.Slice, reflect.Array: - return c.copySlice(dst, src, depth) + return c.copySliceOrArray(dst, src, depth) case reflect.Map: return c.copyMap(dst, src, depth) case reflect.Struct: @@ -93,26 +94,26 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error { case reflect.Interface: return c.copyInterface(dst, src, depth) case reflect.Chan, reflect.Func, reflect.UnsafePointer: - return ErrNotSupported + return ErrNotSupported(srcType, dstType) default: - return c.set(dst, src) + return c.copyBasic(dst, src) } } -func (c *copier) copySlice(dst, src reflect.Value, depth int) error { +func (c *copier) copySliceOrArray(dst, src reflect.Value, depth int) error { switch dst.Kind() { case reflect.Struct: - return c.copyStruct2Slice(dst, src, depth) + return c.copyStructToSlice(dst, src, depth) case reflect.Slice, reflect.Array: return c.copySliceToSlice(dst, src, depth) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } } -func (c *copier) copyStruct2Slice(dst, src reflect.Value, depth int) error { +func (c *copier) copyStructToSlice(dst, src reflect.Value, depth int) error { if dst.Kind() != reflect.Slice { - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } // 创建新的slice,长度+1 @@ -152,7 +153,6 @@ func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error { } copyLen := min(dst.Len(), srcLen) - for i := range copyLen { if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { return err @@ -168,7 +168,6 @@ func (c *copier) copyToArray(dst, src reflect.Value, depth int) error { // 取较小长度 copyLen := min(dstLen, srcLen) - for i := range copyLen { if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { return err @@ -178,7 +177,7 @@ func (c *copier) copyToArray(dst, src reflect.Value, depth int) error { return nil } -func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { +func (c *copier) copyStructToMap(dst, src reflect.Value, depth int) error { if dst.IsNil() { dst.Set(reflect.MakeMapWithSize(dst.Type(), src.NumField())) } @@ -210,26 +209,23 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { continue } - if sField.Kind() == reflect.Pointer || sField.Kind() == reflect.Struct { - var newDst reflect.Value - if isTimeType(sField.Type()) { - newDst = reflect.New(sField.Type()).Elem() - if err := c.copyTime(newDst, sField); err != nil { - return err - } - } else { - var newDst = reflect.ValueOf(make(map[string]any)) - sField = indirect(sField) - - if err := c.deepCopy(newDst, sField, depth+1); err != nil { - return err - } + var copiedValue reflect.Value + if isTimeType(sField.Type()) { + copiedValue = reflect.New(sField.Type()).Elem() + if err := c.copyTime(copiedValue, sField); err != nil { + return err + } + } else if sField.Kind() == reflect.Struct { + copiedValue = reflect.ValueOf(make(map[string]any)) + sField = indirect(sField) + if err := c.deepCopy(copiedValue, sField, depth+1); err != nil { + return err } - - dst.SetMapIndex(reflect.ValueOf(name), newDst) } else { - dst.SetMapIndex(reflect.ValueOf(name), sField) + copiedValue = sField } + + dst.SetMapIndex(reflect.ValueOf(name), copiedValue) } return nil @@ -256,7 +252,7 @@ func (c *copier) copyMap(dst, src reflect.Value, depth int) error { } return nil default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } } @@ -266,8 +262,8 @@ func (c *copier) copyMapToMap(dst, src reflect.Value, depth int) error { } dstType, _ := indirectType(dst.Type()) - iter := src.MapRange() + for iter.Next() { key := iter.Key() value := iter.Value() @@ -285,40 +281,34 @@ func (c *copier) copyMapToMap(dst, src reflect.Value, depth int) error { key = reflect.ValueOf(fieldName) } - var copitedValue reflect.Value - switch dstType.Elem().Kind() { - case reflect.Interface: - switch value.Kind() { - case reflect.Interface: - copitedValue = reflect.New(value.Elem().Type()).Elem() - default: - copitedValue = reflect.New(value.Type()).Elem() - } - - c.set(copitedValue, value) - default: - copitedValue = reflect.New(dstType.Elem()).Elem() - } - - if err := c.deepCopy(copitedValue, value, depth+1); err != nil { + copiedValue := c.prepareMapValue(dstType.Elem(), value) + if err := c.deepCopy(copiedValue, value, depth+1); err != nil { return err } - - dst.SetMapIndex(key, copitedValue) + dst.SetMapIndex(key, copiedValue) } return nil } -func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error { +func (c *copier) prepareMapValue(elemType reflect.Type, value reflect.Value) reflect.Value { + if elemType.Kind() == reflect.Interface { + if value.Kind() == reflect.Interface { + return reflect.New(value.Elem().Type()).Elem() + } + return reflect.New(value.Type()).Elem() + } + return reflect.New(elemType).Elem() +} +func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error { fields := c.deepFields(dst.Type()) for _, sf := range fields { if nestedAnonymousField(dst, sf.Field.Name) { continue } - field := dst.FieldByName(sf.Field.Name) + field := dst.FieldByName(sf.Field.Name) if !field.CanSet() { continue } @@ -334,6 +324,7 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error { if c.opt.caseSensitive { mapValue = src.MapIndex(reflect.ValueOf(name)) } else { + // fix: 忽略大小写查询 mapValue = src.MapIndex(reflect.ValueOf(strings.ToLower(name))) } @@ -356,31 +347,21 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error { func (c *copier) copyStruct(dst, src reflect.Value, depth int) error { switch dst.Kind() { case reflect.Struct: - return c.copyStruct2Struct(dst, src, depth) + return c.copyStructToStruct(dst, src, depth) case reflect.Map: - return c.copyStruct2Map(dst, src, depth) + return c.copyStructToMap(dst, src, depth) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } } -func (c *copier) copyStruct2Struct(dst, src reflect.Value, depth int) error { - if dst.CanSet() { - if src.Type().AssignableTo(dst.Type()) { - dst.Set(src) - return nil - } - - if _, ok := src.Interface().(time.Time); ok { - if dst.Type().AssignableTo(src.Type()) { - dst.Set(src) - return nil - } - } +func (c *copier) copyStructToStruct(dst, src reflect.Value, depth int) error { + if _, ok := src.Interface().(time.Time); ok && src.Type().AssignableTo(dst.Type()) { + dst.Set(src) + return nil } - typ := src.Type() - fields := c.deepFields(typ) + fields := c.deepFields(src.Type()) for _, sf := range fields { name := c.getFieldName(sf.Field.Name, sf.Tag) @@ -450,7 +431,7 @@ func (c *copier) copyPointer(dst, src reflect.Value, depth int) error { return c.deepCopy(dst, srcElem, depth+1) } -func (c *copier) set(dst, src reflect.Value) error { +func (c *copier) copyBasic(dst, src reflect.Value) error { if !src.IsValid() { return ErrInvalidCopyFrom } @@ -479,7 +460,7 @@ func (c *copier) set(dst, src reflect.Value) error { case reflect.Bool: return c.setBool(dst, src) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } } @@ -500,16 +481,7 @@ func (c *copier) setString(dst, src reflect.Value) error { dst.SetString(string(src.Bytes())) return nil } - // 其他slice类型转换为字符串表示 - dst.SetString(fmt.Sprintf("%v", src.Interface())) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - dst.SetString(strconv.FormatInt(src.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - dst.SetString(strconv.FormatUint(src.Uint(), 10)) - case reflect.Float32, reflect.Float64: - dst.SetString(strconv.FormatFloat(src.Float(), 'f', -1, 64)) - case reflect.Bool: - dst.SetString(strconv.FormatBool(src.Bool())) + fallthrough default: dst.SetString(fmt.Sprintf("%v", src.Interface())) } @@ -525,7 +497,7 @@ func (c *copier) setInt(dst, src reflect.Value) error { case reflect.Float32, reflect.Float64: dst.SetInt(int64(src.Float())) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } return nil @@ -544,9 +516,9 @@ func (c *copier) setUint(dst, src reflect.Value) error { dst.SetUint(i) return nil } - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } return nil @@ -565,9 +537,9 @@ func (c *copier) setFloat(dst, src reflect.Value) error { dst.SetFloat(f) return nil } - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } return nil @@ -592,13 +564,13 @@ func (c *copier) setBool(dst, src reflect.Value) error { dst.SetBool(false) return nil } - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: dst.SetBool(src.Int() != 0) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: dst.SetBool(src.Uint() != 0) default: - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } return nil @@ -699,7 +671,7 @@ func (c *copier) copyTime(dst, src reflect.Value) error { return c.otherToTime(dst, src) } - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } func (c *copier) timeToOther(dst reflect.Value, src time.Time) error { @@ -724,7 +696,7 @@ func (c *copier) timeToOther(dst reflect.Value, src time.Time) error { } } - return ErrNotSupported + return ErrNotSupported(reflect.TypeOf(src), dst.Type()) } func (c *copier) otherToTime(dst reflect.Value, src reflect.Value) error { @@ -754,7 +726,7 @@ func (c *copier) otherToTime(dst reflect.Value, src reflect.Value) error { return c.mapToTime(dst, src) } - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) } func (c *copier) getTimeFormat(dstType reflect.Type) string { @@ -798,7 +770,7 @@ func (c *copier) stringToTime(dst reflect.Value, timeStr string) error { } } - return fmt.Errorf("无法解析时间字符串: %s", timeStr) + return fmt.Errorf("cannot parse time string: %s", timeStr) } func (c *copier) setTimeValue(dst reflect.Value, t time.Time) error { @@ -823,8 +795,12 @@ func (c *copier) setTimeValue(dst reflect.Value, t time.Time) error { return nil } - // 目标类型是自定义时间类型 - return nil + if setter, ok := dst.Addr().Interface().(interface{ SetTime(time.Time) }); ok { + setter.SetTime(t) + return nil + } + + return ErrNotSupported(reflect.TypeOf(t), dst.Type()) } func (c *copier) getTimeFormats(dstType reflect.Type) []string { @@ -869,7 +845,11 @@ func (c *copier) getTimeFormats(dstType reflect.Type) []string { func (c *copier) mapToTime(dst reflect.Value, src reflect.Value) error { if src.Type().Key().Kind() != reflect.String { - return ErrNotSupported + return ErrNotSupported(src.Type(), dst.Type()) + } + + if src.Len() == 0 { + return nil } // 从map中提取时间组件 @@ -905,9 +885,12 @@ func (c *copier) mapToTime(dst reflect.Value, src reflect.Value) error { } } + if year == 0 { + year = time.Now().Year() + } + t := time.Date(year, time.Month(month), day, hour, min, sec, nsec, loc) - dst.Set(reflect.ValueOf(t)) - return nil + return c.setTimeValue(dst, t) } func getIntValue(v reflect.Value) int64 { @@ -924,6 +907,7 @@ func getIntValue(v reflect.Value) int64 { return i } } + return 0 } @@ -940,12 +924,11 @@ func nestedAnonymousField(dst reflect.Value, fieldName string) bool { return false } - if !destField.CanSet() { + if destField.CanSet() { + destField.Set(reflect.New(destField.Type().Elem())) + } else { return true } - - newValue := reflect.New(destField.Type().Elem()) - destField.Set(newValue) } } @@ -1029,6 +1012,10 @@ func isTimeType(t reflect.Type) bool { return true } + if _, ok := t.MethodByName("Time"); ok { + return true + } + return false } @@ -1068,14 +1055,12 @@ func getTimeValueExt(v reflect.Value) reflect.Value { if v.IsValid() && v.CanInterface() { // 尝试通过接口转换 if converter, ok := v.Interface().(interface{ ToTime() time.Time }); ok { - timeVal := converter.ToTime() - return reflect.ValueOf(timeVal) + return reflect.ValueOf(converter.ToTime()) } // 尝试直接转换 if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { - timeVal := v.Convert(reflect.TypeOf(time.Time{})) - return timeVal + return v.Convert(reflect.TypeOf(time.Time{})) } } @@ -1084,6 +1069,9 @@ func getTimeValueExt(v reflect.Value) reflect.Value { func indirect(v reflect.Value) reflect.Value { for v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface { + if v.IsNil() { + break + } v = v.Elem() } diff --git a/copier_map_test.go b/copier_map_test.go index 94f0360..195e984 100644 --- a/copier_map_test.go +++ b/copier_map_test.go @@ -92,14 +92,16 @@ func TestDiffStructMapCopy(t *testing.T) { func TestStructToMapCopy(t *testing.T) { type person struct { - Name string - Age int - Address *Address + Name string + Age int + Birthday time.Time + Address *Address } src := person{ - Name: "John", - Age: 30, + Name: "John", + Age: 30, + Birthday: time.Now(), Address: &Address{ City: "Beijing", }, @@ -227,7 +229,10 @@ func TestMixMapToStruct(t *testing.T) { t.Fatal(err) } + b, _ := json.MarshalIndent(dst, "", " ") + t.Log(dst) + t.Log(string(b)) } func BenchmarkMapCopy(b *testing.B) { diff --git a/copier_test.go b/copier_test.go index a5011eb..e0848d1 100644 --- a/copier_test.go +++ b/copier_test.go @@ -294,7 +294,7 @@ func TestDeepFileds(t *testing.T) { var emp = Employee{} typ := reflect.TypeOf(emp) - copier := newCopier() + copier := New() wrapper := copier.deepFields(typ) for _, v := range wrapper { t.Log(v) @@ -305,11 +305,17 @@ func BenchmarkCopy(b *testing.B) { var emp = Employee{ Name: "John", Age: 30, + Address: &Address{ + Country: "USA", + City: "New York", + }, } for b.Loop() { var dst Employee - Copy(&dst, &emp) + if err := Copy(&dst, &emp); err != nil { + b.Fatal(err) + } } } @@ -318,7 +324,7 @@ func BenchmarkDeepFields(b *testing.B) { typ := reflect.TypeOf(emp) for b.Loop() { - copier := newCopier() + copier := New() copier.deepFields(typ) } diff --git a/deepcopier.go b/deepcopier.go index 2eefecd..01c7fe6 100644 --- a/deepcopier.go +++ b/deepcopier.go @@ -3,7 +3,7 @@ package copier import "reflect" func Copy(dst, src any, opts ...option) error { - copier := newCopier(opts...) + copier := New(opts...) return copier.Copy(dst, src) } @@ -15,6 +15,6 @@ func Clone(src any, opts ...option) (any, error) { srcType := reflect.TypeOf(src) dst := reflect.New(srcType).Interface() - err := newCopier(opts...).Copy(dst, src) + err := New(opts...).Copy(dst, src) return dst, err } diff --git a/errors.go b/errors.go index ee8bae5..2370536 100644 --- a/errors.go +++ b/errors.go @@ -1,9 +1,27 @@ package copier -import "errors" +import ( + "errors" + "fmt" + "reflect" +) var ( ErrInvalidCopyDestination = errors.New("copy destination must be non-nil and addressable") ErrInvalidCopyFrom = errors.New("copy from must be non-nil and addressable") - ErrNotSupported = errors.New("not supported") + ErrCircularReference = errors.New("circular reference detected") + ErrMaxDepthExceeded = errors.New("max depth exceeded") ) + +type NotSupportedError struct { + SrcType reflect.Type + DstType reflect.Type +} + +func (e *NotSupportedError) Error() string { + return fmt.Sprintf("unsupported type conversion from %s to %s", e.SrcType, e.DstType) +} + +func ErrNotSupported(srcType, dstType reflect.Type) error { + return &NotSupportedError{SrcType: srcType, DstType: dstType} +}