From 5c4d428b5e74e144b675b04b67bf249c0632f731 Mon Sep 17 00:00:00 2001 From: charlie <3140647@qq.com> Date: Sun, 28 Sep 2025 16:16:30 +0800 Subject: [PATCH] optimize --- README.md | 9 + copier.go | 942 ++++++++++++++++++++++++++++++-------------- copier_map_test.go | 14 + copier_time_test.go | 87 ++++ optiongs.go | 30 ++ 5 files changed, 796 insertions(+), 286 deletions(-) create mode 100644 copier_time_test.go diff --git a/README.md b/README.md index 9f82b5c..8e511ca 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,15 @@ Golang object deep copy library +1. 支持深度复制,包括结构体、切片、数组、映射、指针、接口等。 +2. 支持通过标签(tag)来指定字段复制行为。 +3. 支持忽略空值。 +4. 支持字段名称转换(如驼峰转下划线)。 +5. 支持自定义转换函数。 +6. 支持时间类型的特殊处理。 +7. 支持循环引用的检测。 +8. 支持最大深度限制。 + - map->map - slice->slice - struct->struct diff --git a/copier.go b/copier.go index 75da9e7..acd3e2b 100644 --- a/copier.go +++ b/copier.go @@ -3,6 +3,7 @@ package copier import ( "fmt" "reflect" + "strconv" "strings" "sync" "time" @@ -10,15 +11,15 @@ import ( type copier struct { opt *options - visited map[reflect.Value]struct{} + visited map[uintptr]struct{} mu sync.Mutex } func newCopier(opts ...option) *copier { opt := getOpt(opts...) - var visited map[reflect.Value]struct{} + var visited map[uintptr]struct{} if opt.detectCircularRefs { - visited = make(map[reflect.Value]struct{}) + visited = make(map[uintptr]struct{}) } return &copier{ @@ -29,17 +30,14 @@ func newCopier(opts ...option) *copier { func (c *copier) Copy(dst, src any) error { c.mu.Lock() - defer c.mu.Unlock() defer c.reset() + defer c.mu.Unlock() if dst == nil || src == nil { return ErrInvalidCopyDestination } - var ( - srcValue = indirect(reflect.ValueOf(src)) - dstValue = indirect(reflect.ValueOf(dst)) - ) + srcValue, dstValue := indirect(reflect.ValueOf(src)), indirect(reflect.ValueOf(dst)) if !srcValue.IsValid() { return ErrInvalidCopyFrom @@ -54,7 +52,7 @@ func (c *copier) Copy(dst, src any) error { func (c *copier) reset() { if c.opt.detectCircularRefs { - c.visited = make(map[reflect.Value]struct{}) + c.visited = make(map[uintptr]struct{}) } } @@ -71,39 +69,25 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error { return nil } + // 可以直接赋值时直接赋值并返回 + srcType, dstType := src.Type(), dst.Type() + if srcType.AssignableTo(dstType) { + dst.Set(src) + return nil + } + + // 处理时间类型 + if isTimeType(srcType) || isTimeType(dstType) { + return c.copyTime(dst, src) + } + switch src.Kind() { case reflect.Slice, reflect.Array: - switch dst.Kind() { - case reflect.Slice, reflect.Array: - return c.copySlice(dst, src, depth) - default: - return ErrNotSupported - } + return c.copySlice(dst, src, depth) case reflect.Map: - switch dst.Kind() { - case reflect.Map: - return c.copyMap(dst, src, depth) - case reflect.Struct: - return c.copyMap2Struct(dst, src, depth) - case reflect.Pointer: - return c.copyPointer(dst, src, depth) - default: - return ErrNotSupported - } + return c.copyMap(dst, src, depth) case reflect.Struct: - switch dst.Kind() { - case reflect.Map: - // struct -> map - return c.copyStruct2Map(dst, src, depth) - case reflect.Struct: - // struct -> struct - return c.copyStruct(dst, src, depth) - case reflect.Array, reflect.Slice: - return c.copySlice(dst, src, depth) - default: - return c.set(dst, src) - } - + return c.copyStruct(dst, src, depth) case reflect.Pointer: return c.copyPointer(dst, src, depth) case reflect.Interface: @@ -115,17 +99,99 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error { } } -func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { - for i, n := 0, src.NumField(); i < n; i++ { - sf := src.Type().Field(i) - if sf.PkgPath != "" && !sf.Anonymous { - continue - } +func (c *copier) copySlice(dst, src reflect.Value, depth int) error { + switch dst.Kind() { + case reflect.Struct: + return c.copyStruct2Slice(dst, src, depth) + case reflect.Slice, reflect.Array: + return c.copySliceToSlice(dst, src, depth) + default: + return ErrNotSupported + } +} - if sf.Anonymous { - switch sf.Type.Kind() { +func (c *copier) copyStruct2Slice(dst, src reflect.Value, depth int) error { + if dst.Kind() != reflect.Slice { + return ErrNotSupported + } + + // 创建新的slice,长度+1 + dstLen := dst.Len() + newSlice := reflect.MakeSlice(dst.Type(), dstLen+1, dstLen+1) + + // 复制原有元素 + for i := range dstLen { + if err := c.deepCopy(newSlice.Index(i), dst.Index(i), depth+1); err != nil { + return err + } + } + + // 追加新元素(结构体) + if err := c.deepCopy(newSlice.Index(dstLen), src, depth+1); err != nil { + return err + } + + dst.Set(newSlice) + return nil +} + +func (c *copier) copySliceToSlice(dst, src reflect.Value, depth int) error { + if dst.Kind() == reflect.Array { + return c.copyToArray(dst, src, depth) + } + + return c.copyToSlice(dst, src, depth) +} + +func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error { + srcLen := src.Len() + + // 如果目标slice为空或长度不够,重新创建 + if dst.IsNil() || dst.Len() < srcLen { + newSlice := reflect.MakeSlice(dst.Type(), srcLen, srcLen) + dst.Set(newSlice) + } + + // 确定实际复制的元素数量 + copyLen := min(dst.Len(), srcLen) + + // 复制元素 + for i := range copyLen { + if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { + return err + } + } + + return nil +} + +func (c *copier) copyToArray(dst, src reflect.Value, depth int) error { + srcLen := src.Len() + dstLen := dst.Len() + + // 取较小长度 + copyLen := min(dstLen, srcLen) + + for i := range copyLen { + if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { + return err + } + } + + return nil +} + +func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { + if dst.IsNil() { + dst.Set(reflect.MakeMapWithSize(dst.Type(), src.NumField())) + } + + fields := c.deepFields(src.Type()) + for _, sf := range fields { + if sf.Field.Anonymous { + switch sf.Field.Type.Kind() { case reflect.Struct, reflect.Pointer: - if err := c.deepCopy(dst, src.Field(i), depth+1); err != nil { + if err := c.deepCopy(dst, src.Field(sf.Index), depth+1); err != nil { return err } } @@ -133,25 +199,33 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { continue } - tag := parseTag(sf.Tag.Get(c.opt.tagName)) + tag := parseTag(sf.Field.Tag.Get(c.opt.tagName)) if tag.Contains(tagIgnore) { continue } - name := c.getFieldName(sf.Name, tag) + name := c.getFieldName(sf.Field.Name, tag) - sField := src.Field(i) + sField := src.Field(sf.Index) if c.opt.ignoreEmpty && sField.IsZero() { continue } if sField.Kind() == reflect.Pointer || sField.Kind() == reflect.Struct { - var newDst = reflect.ValueOf(make(map[string]any)) - sField = indirect(sField) + var newDst reflect.Value + if isTimeType(sField.Type()) { + newDst = reflect.New(sField.Type()).Elem() + if err := c.copyTime(newDst, sField); err != nil { + return err + } + } else { + var newDst = reflect.ValueOf(make(map[string]any)) + sField = indirect(sField) - if err := c.deepCopy(newDst, sField, depth+1); err != nil { - return err + if err := c.deepCopy(newDst, sField, depth+1); err != nil { + return err + } } dst.SetMapIndex(reflect.ValueOf(name), newDst) @@ -163,48 +237,18 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error { return nil } -func (c *copier) copyMap2Struct(dst, src reflect.Value, depth int) error { - typ := dst.Type() - for i, n := 0, dst.NumField(); i < n; i++ { - field := dst.Field(i) - sf := typ.Field(i) - - if !field.CanSet() { - continue - } - - if sf.Anonymous { - if err := c.deepCopy(field, src, depth+1); err != nil { - return err - } - } - - name := sf.Name - - mapValue := src.MapIndex(reflect.ValueOf(name)) - if !mapValue.IsValid() { - continue - } - - if mapValue.Kind() == reflect.Interface { - mapValue = mapValue.Elem() - } - - if mapValue.Kind() == reflect.Map || mapValue.Kind() == reflect.Array || mapValue.Kind() == reflect.Slice { - if err := c.deepCopy(field, mapValue, depth+1); err != nil { - return err - } - } else { - if err := c.set(field, mapValue); err != nil { - return err - } - } +func (c *copier) copyMap(dst, src reflect.Value, depth int) error { + switch dst.Kind() { + case reflect.Map: + return c.copyMap2Map(dst, src, depth) + case reflect.Struct: + return c.copyMapToStruct(dst, src, depth) + default: + return ErrNotSupported } - - return nil } -func (c *copier) copyMap(dst, src reflect.Value, depth int) error { +func (c *copier) copyMap2Map(dst, src reflect.Value, depth int) error { if dst.IsNil() { dst.Set(reflect.MakeMapWithSize(dst.Type(), src.Len())) } @@ -216,15 +260,19 @@ func (c *copier) copyMap(dst, src reflect.Value, depth int) error { key := iter.Key() value := iter.Value() + if c.opt.ignoreEmpty && value.IsZero() { + continue + } + + if key.Kind() == reflect.Interface { + key = key.Elem() + } + if name, ok := key.Interface().(string); ok { fieldName := c.getFieldName(name, nil) key = reflect.ValueOf(fieldName) } - if c.opt.ignoreEmpty && value.IsZero() { - continue - } - var copitedValue reflect.Value switch dstType.Elem().Kind() { case reflect.Interface: @@ -250,87 +298,34 @@ func (c *copier) copyMap(dst, src reflect.Value, depth int) error { return nil } -func (c *copier) copySlice(dst, src reflect.Value, depth int) error { - switch src.Kind() { - case reflect.Struct: - return c.copyStruct2Slice(dst, src, depth) - case reflect.Slice, reflect.Array: - return c.copySlice2Slice(dst, src, depth) - default: - return ErrNotSupported - } -} +func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error { -func (c *copier) copyStruct2Slice(dst, src reflect.Value, depth int) error { - if dst.Kind() != reflect.Slice { - return ErrNotSupported - } + fields := c.deepFields(dst.Type()) + for _, sf := range fields { + field := dst.Field(sf.Index) - // 创建新的slice,长度+1 - dstLen := dst.Len() - newSlice := reflect.MakeSlice(dst.Type(), dstLen+1, dstLen+1) - - // 复制原有元素 - for i := 0; i < dstLen; i++ { - if err := c.deepCopy(newSlice.Index(i), dst.Index(i), depth+1); err != nil { - return err + if !field.CanSet() { + continue } - } - // 追加新元素(结构体) - if err := c.deepCopy(newSlice.Index(dstLen), src, depth+1); err != nil { - return err - } - - dst.Set(newSlice) - return nil -} - -func (c *copier) copySlice2Slice(dst, src reflect.Value, depth int) error { - if dst.Kind() == reflect.Array { - return c.copyToArray(dst, src, depth) - } - - return c.copyToSlice(dst, src, depth) -} - -func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error { - srcLen := src.Len() - - // 如果目标slice为空或长度不够,重新创建 - if dst.IsNil() || dst.Len() < srcLen { - newSlice := reflect.MakeSlice(dst.Type(), srcLen, srcLen) - dst.Set(newSlice) - } - - // 确定实际复制的元素数量 - copyLen := srcLen - if dst.Len() < srcLen { - copyLen = dst.Len() - } - - // 复制元素 - for i := 0; i < copyLen; i++ { - if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { - return err + if sf.Field.Anonymous { + if err := c.deepCopy(field, src, depth+1); err != nil { + return err + } } - } - return nil -} + name := sf.Field.Name -func (c *copier) copyToArray(dst, src reflect.Value, depth int) error { - srcLen := src.Len() - dstLen := dst.Len() + mapValue := src.MapIndex(reflect.ValueOf(name)) + if !mapValue.IsValid() { + continue + } - // 取较小长度 - copyLen := srcLen - if dstLen < srcLen { - copyLen = dstLen - } + if mapValue.Kind() == reflect.Interface { + mapValue = mapValue.Elem() + } - for i := 0; i < copyLen; i++ { - if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil { + if err := c.deepCopy(field, mapValue, depth+1); err != nil { return err } } @@ -339,6 +334,17 @@ func (c *copier) copyToArray(dst, src reflect.Value, depth int) error { } func (c *copier) copyStruct(dst, src reflect.Value, depth int) error { + switch dst.Kind() { + case reflect.Struct: + return c.copyStruct2Struct(dst, src, depth) + case reflect.Map: + return c.copyStruct2Map(dst, src, depth) + default: + return ErrNotSupported + } +} + +func (c *copier) copyStruct2Struct(dst, src reflect.Value, depth int) error { if dst.CanSet() { if src.Type().AssignableTo(dst.Type()) { dst.Set(src) @@ -379,7 +385,7 @@ func (c *copier) copyStruct(dst, src reflect.Value, depth int) error { if ok, err := c.lookupAndCopyWithConverter(dstValue, sField, name); err != nil { return err } else if ok { - return nil + continue } if err := c.deepCopy(dstValue, sField, depth+1); err != nil { @@ -395,10 +401,6 @@ func (c *copier) copyInterface(dst, src reflect.Value, depth int) error { } srcElem := src.Elem() - if !srcElem.IsValid() { - return nil - } - if dst.Kind() == reflect.Interface { newValue := reflect.New(srcElem.Type()).Elem() if err := c.deepCopy(newValue, srcElem, depth+1); err != nil { @@ -413,19 +415,14 @@ func (c *copier) copyInterface(dst, src reflect.Value, depth int) error { } func (c *copier) copyPointer(dst, src reflect.Value, depth int) error { - if src.Kind() == reflect.Pointer && src.IsNil() { + if src.IsNil() { return nil } srcElem := indirect(src) if dst.Kind() == reflect.Pointer { if dst.IsNil() { - if !dst.CanSet() { - return nil - } - - newPtr := reflect.New(dst.Type().Elem()) - dst.Set(newPtr) + dst.Set(reflect.New(dst.Type().Elem())) } return c.deepCopy(dst.Elem(), srcElem, depth+1) } @@ -438,27 +435,8 @@ func (c *copier) set(dst, src reflect.Value) error { return ErrInvalidCopyFrom } - if src.Kind() == reflect.Interface { - src = src.Elem() - } - - // if ok, err := c.lookupAndCopyWithConverter(dst, src, fieldName); err != nil { - // return err - // } else if ok { - // return nil - // } - - if dst.Kind() == reflect.Pointer { - if dst.IsNil() { - if !dst.CanSet() { - return ErrInvalidCopyDestination - } - dst.Set(reflect.New(dst.Type().Elem())) - } - - dst = dst.Elem() - } - + src = indirect(src) + dst = ensureSettable(dst) if src.Type().AssignableTo(dst.Type()) { dst.Set(src) return nil @@ -469,71 +447,136 @@ func (c *copier) set(dst, src reflect.Value) error { return nil } - if _, ok := dst.Interface().(time.Time); ok { - switch v := src.Interface().(type) { - case string: - if t, err := time.Parse(time.RFC3339, v); err == nil { - dst.Set(reflect.ValueOf(t)) - } - case int64: - dst.Set(reflect.ValueOf(time.Unix(v, 0))) - } - - return nil - } - switch dst.Kind() { case reflect.String: - switch v := src.Interface().(type) { - case string: - dst.SetString(v) - case []byte: - dst.SetString(string(v)) - case time.Time: - dst.SetString(v.Format(time.RFC3339)) - default: - dst.SetString(fmt.Sprintf("%v", v)) - } + return c.setString(dst, src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch v := src.Interface().(type) { - case int: - dst.SetInt(int64(v)) - case int8: - dst.SetInt(int64(v)) - case int16: - dst.SetInt(int64(v)) - case int32: - dst.SetInt(int64(v)) - case int64: - dst.SetInt(v) - case time.Time: - dst.SetInt(v.Unix()) - } + return c.setInt(dst, src) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - switch v := src.Interface().(type) { - case uint: - dst.SetUint(uint64(v)) - case uint8: - dst.SetUint(uint64(v)) - case uint16: - dst.SetUint(uint64(v)) - case uint32: - dst.SetUint(uint64(v)) - case uint64: - dst.SetUint(v) - } + return c.setUint(dst, src) case reflect.Float32, reflect.Float64: - switch v := src.Interface().(type) { - case float32: - dst.SetFloat(float64(v)) - case float64: - dst.SetFloat(v) - } + return c.setFloat(dst, src) case reflect.Bool: - switch v := src.Interface().(type) { - case bool: - dst.SetBool(v) + return c.setBool(dst, src) + default: + return ErrNotSupported + } +} + +func ensureSettable(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Pointer && v.IsNil() && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + return v.Elem() + } + return v +} + +func (c *copier) setString(dst, src reflect.Value) error { + switch src.Kind() { + case reflect.String: + dst.SetString(src.String()) + case reflect.Slice: + if src.Type().Elem().Kind() == reflect.Uint8 { + dst.SetString(string(src.Bytes())) + return nil } + // 其他slice类型转换为字符串表示 + dst.SetString(fmt.Sprintf("%v", src.Interface())) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetString(strconv.FormatInt(src.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetString(strconv.FormatUint(src.Uint(), 10)) + case reflect.Float32, reflect.Float64: + dst.SetString(strconv.FormatFloat(src.Float(), 'f', -1, 64)) + case reflect.Bool: + dst.SetString(strconv.FormatBool(src.Bool())) + default: + dst.SetString(fmt.Sprintf("%v", src.Interface())) + } + return nil +} + +func (c *copier) setInt(dst, src reflect.Value) error { + switch src.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetInt(src.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetInt(int64(src.Uint())) + case reflect.Float32, reflect.Float64: + dst.SetInt(int64(src.Float())) + default: + return ErrNotSupported + } + + return nil +} + +func (c *copier) setUint(dst, src reflect.Value) error { + switch src.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetUint(uint64(src.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetUint(src.Uint()) + case reflect.Float32, reflect.Float64: + dst.SetUint(uint64(src.Float())) + case reflect.String: + if i, err := strconv.ParseUint(src.String(), 10, 64); err == nil { + dst.SetUint(i) + return nil + } + return ErrNotSupported + default: + return ErrNotSupported + } + + return nil +} + +func (c *copier) setFloat(dst, src reflect.Value) error { + switch src.Kind() { + case reflect.Float32, reflect.Float64: + dst.SetFloat(src.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetFloat(float64(src.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetFloat(float64(src.Uint())) + case reflect.String: + if f, err := strconv.ParseFloat(src.String(), 64); err == nil { + dst.SetFloat(f) + return nil + } + return ErrNotSupported + default: + return ErrNotSupported + } + + return nil +} + +func (c *copier) setBool(dst, src reflect.Value) error { + switch src.Kind() { + case reflect.Bool: + dst.SetBool(src.Bool()) + case reflect.String: + if b, err := strconv.ParseBool(src.String()); err == nil { + dst.SetBool(b) + return nil + } + // 尝试常见布尔表示 + val := strings.ToLower(src.String()) + if val == "true" || val == "yes" || val == "1" || val == "on" { + dst.SetBool(true) + return nil + } + if val == "false" || val == "no" || val == "0" || val == "off" { + dst.SetBool(false) + return nil + } + return ErrNotSupported + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetBool(src.Int() != 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetBool(src.Uint() != 0) default: return ErrNotSupported } @@ -600,6 +643,271 @@ func (c *copier) calculateDeepFields(reflectType reflect.Type) []fieldWrapper { return res } +func (c *copier) lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string) (bool, error) { + if cnv, ok := c.opt.convertByName[fieldName]; ok { + return convert(dst, src, cnv) + } + + if len(c.opt.converters) > 0 { + pair := converterPair{ + SrcType: src.Type(), + DstType: dst.Type(), + } + + if cnv, ok := c.opt.converters[pair]; ok { + return convert(dst, src, cnv) + } + } + + return false, nil +} + +func (c *copier) copyTime(dst, src reflect.Value) error { + srcTime, dstTime := getTimeValueExt(src), getTimeValueExt(dst) + + // 双方都是时间类型,直接赋值(需要处理自定义类型转换) + if srcTime.IsValid() && dstTime.IsValid() { + dstTime.Set(srcTime) + } + + // 源是时间,目标是其他类型 + if srcTime.IsValid() { + return c.timeToOther(dst, srcTime.Interface().(time.Time)) + } + + // 目标是时间,源是其他类型 + if dstTime.IsValid() { + return c.otherToTime(dst, src) + } + + return ErrNotSupported +} + +func (c *copier) timeToOther(dst reflect.Value, src time.Time) error { + switch dst.Kind() { + case reflect.String: + format := c.getTimeFormat(dst.Type()) + dst.SetString(src.Format(format)) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.SetInt(src.Unix()) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.SetUint(uint64(src.Unix())) + return nil + case reflect.Float32, reflect.Float64: + dst.SetFloat(float64(src.Unix()) + float64(src.Nanosecond())/1e9) + return nil + case reflect.Struct: + if dst.Type() == reflect.TypeOf(time.Time{}) { + dst.Set(reflect.ValueOf(src)) + return nil + } + } + + return ErrNotSupported +} + +func (c *copier) otherToTime(dst reflect.Value, src reflect.Value) error { + src = indirect(src) + switch src.Kind() { + case reflect.String: + return c.stringToTime(dst, src.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + dst.Set(reflect.ValueOf(time.Unix(src.Int(), 0))) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + dst.Set(reflect.ValueOf(time.Unix(int64(src.Uint()), 0))) + return nil + case reflect.Float32, reflect.Float64: + sec := src.Float() + nsec := int64((sec - float64(int64(sec))) * 1e9) + dst.Set(reflect.ValueOf(time.Unix(int64(sec), nsec))) + return nil + case reflect.Struct: + if src.Type() == reflect.TypeOf(time.Time{}) { + dst.Set(src) + return nil + } + + case reflect.Map: + // 支持从map构造时间,如:map[string]int{"year": 2023, "month": 1, "day": 1} + return c.mapToTime(dst, src) + } + + return ErrNotSupported +} + +func (c *copier) getTimeFormat(dstType reflect.Type) string { + if format, ok := c.opt.timeFormats[dstType]; ok { + return format + } + + if dstType.Kind() == reflect.String { + if format, ok := c.opt.timeFormats[reflect.TypeOf("")]; ok { + return format + } + } + + if format, ok := c.opt.timeFormats[reflect.TypeOf(nil)]; ok { + return format + } + + return time.RFC3339 +} + +func (c *copier) stringToTime(dst reflect.Value, timeStr string) error { + // 首先尝试数字字符串(时间戳) + if sec, err := strconv.ParseInt(timeStr, 10, 64); err == nil { + dst.Set(reflect.ValueOf(time.Unix(sec, 0))) + return nil + } + + // 尝试浮点数时间戳 + if sec, err := strconv.ParseFloat(timeStr, 64); err == nil { + nsec := int64((sec - float64(int64(sec))) * 1e9) + dst.Set(reflect.ValueOf(time.Unix(int64(sec), nsec))) + return nil + } + + // 尝试各种时间格式 + formats := c.getTimeFormats(dst.Type()) + for _, format := range formats { + if t, err := time.Parse(format, timeStr); err == nil { + c.setTimeValue(dst, t) + return nil + } + } + + return fmt.Errorf("无法解析时间字符串: %s", timeStr) +} + +func (c *copier) setTimeValue(dst reflect.Value, t time.Time) error { + if !dst.CanSet() { + return ErrInvalidCopyDestination + } + + // 处理指针类型 + if dst.Kind() == reflect.Pointer { + if dst.IsNil() { + if !dst.CanSet() { + return ErrInvalidCopyDestination + } + dst.Set(reflect.New(dst.Type().Elem())) + } + dst = dst.Elem() + } + + // 目标类型是标准 time.Time + if dst.Type() == reflect.TypeOf(time.Time{}) { + dst.Set(reflect.ValueOf(t)) + return nil + } + + // 目标类型是自定义时间类型 + return nil +} + +func (c *copier) getTimeFormats(dstType reflect.Type) []string { + var formats []string + + // 添加特定类型的格式 + if format, ok := c.opt.timeFormats[dstType]; ok { + formats = append(formats, format) + } + + // 添加字符串类型的格式(如果目标类型是字符串) + if dstType.Kind() == reflect.String { + if format, ok := c.opt.timeFormats[reflect.TypeOf("")]; ok { + formats = append(formats, format) + } + } + + // 添加默认格式 + if format, ok := c.opt.timeFormats[reflect.TypeOf(nil)]; ok { + formats = append(formats, format) + } + + // 添加内置的常见格式 + builtinFormats := []string{ + time.RFC3339, + time.RFC3339Nano, + "2006-01-02 15:04:05", + "2006-01-02", + "15:04:05", + time.RFC1123, + time.RFC822, + time.ANSIC, + "2006/01/02", + "2006/01/02 15:04:05", + "02 Jan 2006", + "02 Jan 2006 15:04:05", + } + + formats = append(formats, builtinFormats...) + return formats +} + +func (c *copier) mapToTime(dst reflect.Value, src reflect.Value) error { + if src.Type().Key().Kind() != reflect.String { + return ErrNotSupported + } + + // 从map中提取时间组件 + year, month, day, hour, min, sec, nsec := 0, 1, 1, 0, 0, 0, 0 + loc := time.UTC + + iter := src.MapRange() + for iter.Next() { + key := iter.Key().String() + value := iter.Value() + + switch key { + case "year", "Year": + year = int(getIntValue(value)) + case "month", "Month": + month = int(getIntValue(value)) + case "day", "Day": + day = int(getIntValue(value)) + case "hour", "Hour": + hour = int(getIntValue(value)) + case "minute", "Minute", "min", "Min": + min = int(getIntValue(value)) + case "second", "Second", "sec", "Sec": + sec = int(getIntValue(value)) + case "nanosecond", "Nanosecond", "nsec", "Nsec": + nsec = int(getIntValue(value)) + case "location", "Location", "loc", "Loc": + if locStr, ok := value.Interface().(string); ok { + if location, err := time.LoadLocation(locStr); err == nil { + loc = location + } + } + } + } + + t := time.Date(year, time.Month(month), day, hour, min, sec, nsec, loc) + dst.Set(reflect.ValueOf(t)) + return nil +} + +func getIntValue(v reflect.Value) int64 { + v = indirect(v) + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(v.Uint()) + case reflect.Float32, reflect.Float64: + return int64(v.Float()) + case reflect.String: + if i, err := strconv.ParseInt(v.String(), 10, 64); err == nil { + return i + } + } + return 0 +} + func nestedAnonymousField(dst reflect.Value, fieldName string) bool { if f, ok := dst.Type().FieldByName(fieldName); ok { if len(f.Index) > 1 { @@ -625,25 +933,6 @@ func nestedAnonymousField(dst reflect.Value, fieldName string) bool { return false } -func (c *copier) lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string) (bool, error) { - if cnv, ok := c.opt.convertByName[fieldName]; ok { - return convert(dst, src, cnv) - } - - if len(c.opt.converters) > 0 { - pair := converterPair{ - SrcType: src.Type(), - DstType: dst.Type(), - } - - if cnv, ok := c.opt.converters[pair]; ok { - return convert(dst, src, cnv) - } - } - - return false, nil -} - func convert(dst, src reflect.Value, fn convertFunc) (bool, error) { result, err := fn(src.Interface()) if err != nil { @@ -684,21 +973,102 @@ func (c *copier) isCircularRefs(v reflect.Value) bool { return false } - addr := v.Addr() - if _, exists := c.visited[addr]; exists { + ptr := c.getValuePointer(v) + if _, exists := c.visited[ptr]; exists { return true } - c.visited[addr] = struct{}{} + c.visited[ptr] = struct{}{} return false } -func indirect(reflectValue reflect.Value) reflect.Value { - for reflectValue.Kind() == reflect.Pointer { - reflectValue = reflectValue.Elem() +func (c *copier) getValuePointer(v reflect.Value) uintptr { + switch v.Kind() { + case reflect.Pointer, reflect.Slice, reflect.Map, reflect.Func, reflect.Chan: + return v.Pointer() + case reflect.Struct: + if v.CanAddr() { + return v.Addr().Pointer() + } } - return reflectValue + return 0 +} + +func isTimeType(t reflect.Type) bool { + if t == nil { + return false + } + + // 处理指针类型 + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + + // 直接比较标准 time.Time 类型 + if t == reflect.TypeOf(time.Time{}) { + return true + } + + return false +} + +func getTimeValue(v reflect.Value) reflect.Value { + if !v.IsValid() { + return reflect.Value{} + } + + // 标准 time.Time 类型 + if v.Type() == reflect.TypeOf(time.Time{}) { + return v + } + + // 检查是否为时间类型 + if isTimeType(v.Type()) { + return v + } + + // 指针类型处理 + if v.Kind() == reflect.Pointer && !v.IsNil() { + elem := v.Elem() + if isTimeType(elem.Type()) { + return elem + } + } + + return reflect.Value{} +} + +func getTimeValueExt(v reflect.Value) reflect.Value { + standardTime := getTimeValue(v) + if standardTime.IsValid() { + return standardTime + } + + // 如果无法直接识别,尝试转换 + if v.IsValid() && v.CanInterface() { + // 尝试通过接口转换 + if converter, ok := v.Interface().(interface{ ToTime() time.Time }); ok { + timeVal := converter.ToTime() + return reflect.ValueOf(timeVal) + } + + // 尝试直接转换 + if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + timeVal := v.Convert(reflect.TypeOf(time.Time{})) + return timeVal + } + } + + return reflect.Value{} +} + +func indirect(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface { + v = v.Elem() + } + + return v } func indirectType(reflectType reflect.Type) (_ reflect.Type, isPtr bool) { diff --git a/copier_map_test.go b/copier_map_test.go index e0d721a..1fb1713 100644 --- a/copier_map_test.go +++ b/copier_map_test.go @@ -275,3 +275,17 @@ func BenchmarkMapCopy(b *testing.B) { } } + +func BenchmarkStruct2Map(b *testing.B) { + var src = Person{ + Name: "John", + Age: 30, + } + + for b.Loop() { + var dst map[string]any + if err := Copy(&dst, src); err != nil { + b.Fatal(err) + } + } +} diff --git a/copier_time_test.go b/copier_time_test.go new file mode 100644 index 0000000..3e82623 --- /dev/null +++ b/copier_time_test.go @@ -0,0 +1,87 @@ +package copier_test + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "git.charlienet.top/go/copier" +) + +type MyString string +type MyTime time.Time + +func TestTime(t *testing.T) { + // 示例1:为内置字符串类型设置时间格式 + now := time.Now() + var str string + copier.Copy(&str, now, copier.WithStringTimeFormat("2006-01-02 15:04:05")) + fmt.Printf("时间转字符串: %s\n", str) + + // 示例2:为自定义字符串类型设置时间格式 + var myStr MyString + copier.Copy(&myStr, now, copier.WithTimeFormat(MyString(""), "2006/01/02")) + fmt.Printf("时间转自定义字符串: %s\n", myStr) + + // 示例3:设置默认时间格式 + var str1 string + copier.Copy(&str1, now, copier.WithTimeFormat(MyString(""), "2006/01/02")) + fmt.Printf("默认格式: %s\n", str1) + + // 示例5:字符串转时间 + timeStr := "2023-12-25 10:30:00" + var parsedTime time.Time + copier.Copy(&parsedTime, timeStr, + copier.WithStringTimeFormat("2006-01-02 15:04:05"), + copier.WithTimeFormat(MyString(""), "2006/01/02"), + copier.WithDefaultTimeFormat(time.RFC3339)) + fmt.Printf("字符串转时间: %v\n", parsedTime) +} + +func TestTimeStruct(t *testing.T) { + type User struct { + Name string + CreatedAt *time.Time + } + + src := map[string]any{ + "Name": "John", + "CreatedAt": "2023-01-01T10:00:00Z", + } + + var user User + if err := copier.Copy(&user, src, copier.WithStringTimeFormat(time.RFC3339)); err != nil { + fmt.Printf("错误: %v\n", err) + } else { + fmt.Printf("用户: %+v\n", user) + b, _ := json.Marshal(user) + t.Log(string(b)) + } + + // user.CreatedAt +} + +func TestFieldTag(t *testing.T) { + type User struct { + Name string + CreatedAt time.Time `copier:"format=2006-01-02 15:04:05"` + } + + type User2 struct { + Name string + CreatedAt string + } + + var src = User{ + Name: "John", + CreatedAt: time.Now(), + } + + var dst User2 + if err := copier.Copy(&dst, src); err != nil { + fmt.Printf("错误: %v\n", err) + } else { + fmt.Printf("用户: %+v\n", dst) + } +} diff --git a/optiongs.go b/optiongs.go index 0442ef7..9bb5f10 100644 --- a/optiongs.go +++ b/optiongs.go @@ -19,6 +19,7 @@ type options struct { converters map[converterPair]convertFunc // 根据源和目标类型处理的类型转换器v convertByName map[string]convertFunc // 根据名称处理的类型转换器 nameConverter func(string) string // 字段名转换函数 + timeFormats map[reflect.Type]string // 时间格式 } type option func(*options) @@ -156,6 +157,35 @@ func WithNameFn(fn func(string) string) option { } } +func WithTimeFormat(targetType any, format string) option { + return func(o *options) { + if o.timeFormats == nil { + o.timeFormats = make(map[reflect.Type]string) + } + + t := reflect.TypeOf(targetType) + if t.Kind() == reflect.Pointer { + t = t.Elem() + } + + o.timeFormats[t] = format + } +} + +func WithStringTimeFormat(format string) option { + return WithTimeFormat("", format) +} + +func WithDefaultTimeFormat(format string) option { + return func(o *options) { + if o.timeFormats == nil { + o.timeFormats = make(map[reflect.Type]string) + } + + o.timeFormats[reflect.TypeOf(nil)] = format + } +} + // WithTagName 添加标签名 func WithTagName(tagName string) option { return func(o *options) {