diff --git a/copier.go b/copier.go index df28d15..b5cb489 100644 --- a/copier.go +++ b/copier.go @@ -4,7 +4,10 @@ import ( "fmt" "reflect" "strings" + "sync" "time" + + "github.com/modern-go/reflect2" ) func Copy(dst, src any, opts ...option) error { @@ -72,6 +75,30 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options) case reflect.Struct: // struct -> struct return copyStruct(dst, src, depth, opt) + case reflect.Array, reflect.Slice: + if dst.IsNil() { + dstType := dst.Type().Elem() + newDst := reflect.MakeSlice(reflect.SliceOf(dstType), 1, 1) + dst.Set(newDst) + + if err := deepCopy(dst.Index(0), src, depth+1, fieldName, opt); err != nil { + return err + } + + } else { + len := dst.Len() + cap := dst.Cap() + newSlice := reflect.MakeSlice(reflect.SliceOf(dst.Type().Elem()), len+1, cap+10) + reflect.Copy(newSlice, dst) + + newDst := newSlice.Index(len) + if err := deepCopy(newDst, src, depth+1, fieldName, opt); err != nil { + return err + } + dst.Set(newSlice) + } + + return nil default: return set(dst, src, fieldName, opt) } @@ -148,7 +175,7 @@ func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *op continue } - name := toName(sf.Name, tag, opt) + name := getFieldName(sf.Name, tag, opt) sField := src.Field(i) @@ -191,7 +218,7 @@ func copyMap2Struct(dst, src reflect.Value, depth int, _ string, opt *options) e } } - name := getFieldName(sf) + name := getFieldNameFromJsonTag(sf) if name == "-" { continue } @@ -295,25 +322,29 @@ func copyStruct(dst, src reflect.Value, depth int, opt *options) error { } typ := src.Type() - for i, n := 0, src.NumField(); i < n; i++ { - sf := typ.Field(i) - if sf.PkgPath != "" && !sf.Anonymous { - continue - } - + fields := deepFields(typ) + for _, sf := range fields { tag := parseTag(sf.Tag.Get(opt.tagName)) if tag.Contains(tagIgnore) { continue } - name := toName(sf.Name, tag, opt) + name := getFieldName(sf.Name, tag, opt) + + if nestedAnonymousField(dst, name) { + continue + } + + typ := reflect2.TypeOf(dst) + dstPtr := reflect2.PtrOf(dst) + typ.(reflect2.StructType).Field(1).UnsafeGet(dstPtr) dstValue := fieldByName(dst, name, opt) if !dstValue.IsValid() { continue } - sField := src.Field(i) + sField := src.FieldByName(sf.Name) if opt.ignoreEmpty && sField.IsZero() { continue } @@ -381,6 +412,8 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error { dst.SetString(v) case []byte: dst.SetString(string(v)) + case time.Time: + dst.SetString(v.Format(time.RFC3339)) default: dst.SetString(fmt.Sprintf("%v", v)) } @@ -395,6 +428,9 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error { case int32: dst.SetInt(int64(v)) case int64: + dst.SetInt(v) + case time.Time: + dst.SetInt(v.Unix()) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: switch v := src.Interface().(type) { @@ -406,6 +442,8 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error { dst.SetUint(uint64(v)) case uint32: dst.SetUint(uint64(v)) + case uint64: + dst.SetUint(v) } case reflect.Float32, reflect.Float64: switch v := src.Interface().(type) { @@ -426,6 +464,82 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error { return nil } +type fieldsWrapper struct { + Fields []reflect.StructField + once sync.Once +} + +var deepFieldsMap sync.Map + +func deepFields(reflectType reflect.Type) []reflect.StructField { + if wrapper, ok := deepFieldsMap.Load(reflectType); ok { + w := wrapper.(*fieldsWrapper) + w.once.Do(func() { + w.Fields = calculateDeepFields(reflectType) + }) + return w.Fields + } + + wrapper, loaded := deepFieldsMap.LoadOrStore(reflectType, &fieldsWrapper{}) + w := wrapper.(*fieldsWrapper) + if !loaded { + w.once.Do(func() { + w.Fields = calculateDeepFields(reflectType) + }) + } else { + w.once.Do(func() {}) + } + + return w.Fields +} + +func calculateDeepFields(reflectType reflect.Type) []reflect.StructField { + reflectType, _ = indirectType(reflectType) + num := reflectType.NumField() + res := make([]reflect.StructField, 0, num) + if reflectType.Kind() == reflect.Struct { + for i := range num { + sf := reflectType.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { + continue + } + + if sf.Anonymous { + res = append(res, deepFields(sf.Type)...) + } + + res = append(res, sf) + } + } + + return res +} + +func nestedAnonymousField(dst reflect.Value, fieldName string) bool { + if f, ok := dst.Type().FieldByName(fieldName); ok { + if len(f.Index) > 1 { + destField := dst.Field(f.Index[0]) + + if destField.Kind() != reflect.Pointer { + return false + } + + if !destField.IsNil() { + return false + } + + if !destField.CanSet() { + return true + } + + newValue := reflect.New(destField.Type().Elem()) + destField.Set(newValue) + } + } + + return false +} + func lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string, opt *options) (bool, error) { if cnv, ok := opt.convertByName[fieldName]; ok { return convert(dst, src, cnv) @@ -461,7 +575,7 @@ func convert(dst, src reflect.Value, fn convertFunc) (bool, error) { } -func getFieldName(field reflect.StructField) string { +func getFieldNameFromJsonTag(field reflect.StructField) string { if tag := field.Tag.Get("json"); tag != "" { if commaIndex := strings.Index(tag, ","); commaIndex != -1 { name := tag[:commaIndex] @@ -476,7 +590,7 @@ func getFieldName(field reflect.StructField) string { return field.Name } -func toName(name string, tag *tagOption, opt *options) string { +func getFieldName(name string, tag *tagOption, opt *options) string { if tag != nil && tag.Contains(tagToName) { return tag.toname } @@ -485,7 +599,7 @@ func toName(name string, tag *tagOption, opt *options) string { } func fieldByName(v reflect.Value, name string, opt *options) reflect.Value { - if opt.caseSensitive { + if opt != nil && opt.caseSensitive { return v.FieldByName(name) } return v.FieldByNameFunc(func(s string) bool { return strings.EqualFold(s, name) })