This commit is contained in:
2025-09-30 18:41:21 +08:00
parent 2b90888901
commit 362dc000b1
7 changed files with 344 additions and 79 deletions

117
copier.go
View File

@@ -3,6 +3,7 @@ package copier
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -72,16 +73,14 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
// 可以直接赋值时直接赋值并返回 // 可以直接赋值时直接赋值并返回
srcType, dstType := src.Type(), dst.Type() srcType, dstType := src.Type(), dst.Type()
// if srcType.AssignableTo(dstType) {
// dst.Set(src)
// return nil
// }
// 处理时间类型 // 处理时间类型
if isTimeType(srcType) || isTimeType(dstType) { if isTimeType(srcType) || isTimeType(dstType) {
return c.copyTime(dst, src) return c.copyTime(dst, src)
} }
// fmt.Println("srcType:", srcType, src.Kind().String(), "dstType:", dstType, dst.Kind().String())
switch src.Kind() { switch src.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
return c.copySliceOrArray(dst, src, depth) return c.copySliceOrArray(dst, src, depth)
@@ -102,8 +101,6 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
func (c *copier) copySliceOrArray(dst, src reflect.Value, depth int) error { func (c *copier) copySliceOrArray(dst, src reflect.Value, depth int) error {
switch dst.Kind() { switch dst.Kind() {
case reflect.Struct:
return c.copyStructToSlice(dst, src, depth)
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
return c.copySliceToSlice(dst, src, depth) return c.copySliceToSlice(dst, src, depth)
default: default:
@@ -350,6 +347,8 @@ func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
return c.copyStructToStruct(dst, src, depth) return c.copyStructToStruct(dst, src, depth)
case reflect.Map: case reflect.Map:
return c.copyStructToMap(dst, src, depth) return c.copyStructToMap(dst, src, depth)
case reflect.Slice:
return c.copyStructToSlice(dst, src, depth)
default: default:
return ErrNotSupported(src.Type(), dst.Type()) return ErrNotSupported(src.Type(), dst.Type())
} }
@@ -361,15 +360,42 @@ func (c *copier) copyStructToStruct(dst, src reflect.Value, depth int) error {
return nil return nil
} }
fields := c.deepFields(src.Type()) srcFields := c.deepFields(src.Type())
for _, sf := range fields { dstFields := c.deepFields(dst.Type())
for _, sf := range dstFields {
_ = sf
}
for _, sf := range srcFields {
name := c.getFieldName(sf.Field.Name, sf.Tag) name := c.getFieldName(sf.Field.Name, sf.Tag)
sField := src.FieldByName(sf.Field.Name)
// 是否忽略该字段
if c.ignore(sField, name, name, sf.Tag) {
continue
}
if nestedAnonymousField(dst, name) { if nestedAnonymousField(dst, name) {
continue continue
} }
dstValue := c.fieldByName(dst, name) dstValue := c.fieldByName(dst, name)
if dstValue.IsValid() {
} else {
var toMethod reflect.Value
if dst.CanAddr() {
toMethod = dst.Addr().MethodByName(name)
} else {
toMethod = dst.MethodByName(name)
}
if toMethod.IsValid() && toMethod.Type().NumIn() == 1 && toMethod.Type().In(0) == sField.Type() {
toMethod.Call([]reflect.Value{sField})
}
}
if !dstValue.IsValid() { if !dstValue.IsValid() {
continue continue
} }
@@ -378,11 +404,6 @@ func (c *copier) copyStructToStruct(dst, src reflect.Value, depth int) error {
continue continue
} }
sField := src.FieldByName(sf.Field.Name)
if c.opt.ignoreEmpty && sField.IsZero() {
continue
}
if ok, err := c.lookupAndCopyWithConverter(dstValue, sField, name); err != nil { if ok, err := c.lookupAndCopyWithConverter(dstValue, sField, name); err != nil {
return err return err
} else if ok { } else if ok {
@@ -393,6 +414,7 @@ func (c *copier) copyStructToStruct(dst, src reflect.Value, depth int) error {
return err return err
} }
} }
return nil return nil
} }
@@ -576,62 +598,25 @@ func (c *copier) setBool(dst, src reflect.Value) error {
return nil return nil
} }
type fieldsWrapper struct { func (c *copier) ignore(field reflect.Value, srcName, dstName string, tag *tagOption) bool {
Fields []fieldWrapper // 忽略空值
once sync.Once if tag != nil && tag.Contains(tagIgnore) {
} return true
type fieldWrapper struct {
Field reflect.StructField
Tag *tagOption
}
var deepFieldsMap sync.Map
func (c *copier) deepFields(reflectType reflect.Type) []fieldWrapper {
if wrapper, ok := deepFieldsMap.Load(reflectType); ok {
w := wrapper.(*fieldsWrapper)
w.once.Do(func() {
w.Fields = c.calculateDeepFields(reflectType)
})
return w.Fields
} }
wrapper, loaded := deepFieldsMap.LoadOrStore(reflectType, &fieldsWrapper{}) // 忽略空值
w := wrapper.(*fieldsWrapper) if c.opt.ignoreEmpty && field.IsZero() {
if !loaded { return true
w.once.Do(func() {
w.Fields = c.calculateDeepFields(reflectType)
})
} else {
w.once.Do(func() {})
} }
return w.Fields // 忽略指定字段
} if c.opt.skipFields != nil {
if slices.Contains(c.opt.skipFields, srcName) || slices.Contains(c.opt.skipFields, dstName) {
func (c *copier) calculateDeepFields(reflectType reflect.Type) []fieldWrapper { return true
reflectType, _ = indirectType(reflectType)
num := reflectType.NumField()
res := make([]fieldWrapper, 0, num)
if reflectType.Kind() == reflect.Struct {
for i := range num {
sf := reflectType.Field(i)
if sf.PkgPath != "" && !sf.Anonymous {
continue
}
tag := parseTag(sf.Tag.Get(c.opt.tagName))
if sf.Anonymous {
res = append(res, c.deepFields(sf.Type)...)
} else {
res = append(res, fieldWrapper{Field: sf, Tag: tag})
}
} }
} }
return res return false
} }
func (c *copier) lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string) (bool, error) { func (c *copier) lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string) (bool, error) {
@@ -952,13 +937,17 @@ func convert(dst, src reflect.Value, fn convertFunc) (bool, error) {
} }
func (c *copier) getFieldName(name string, tag *tagOption) string { func (c *copier) getFieldName(name string, tag *tagOption) string {
if tag != nil && tag.Contains(tagToName) { if tag != nil && tag.Contains(tagName) {
return tag.toname return tag.name
} }
return c.opt.NameConvert(name) return c.opt.NameConvert(name)
} }
func (c *copier) getFieldName2(fieldName string, tag *tagOption) (srcFieldName string, dstFieldName string) {
return fieldName, c.getFieldName(fieldName, tag)
}
func (c *copier) fieldByName(v reflect.Value, name string) reflect.Value { func (c *copier) fieldByName(v reflect.Value, name string) reflect.Value {
if c.opt != nil && c.opt.caseSensitive { if c.opt != nil && c.opt.caseSensitive {
return v.FieldByName(name) return v.FieldByName(name)

View File

@@ -194,7 +194,7 @@ func TestMixMapToStruct(t *testing.T) {
Name string `json:"name"` Name string `json:"name"`
Age int `json:"age"` Age int `json:"age"`
Height float64 `json:"height"` Height float64 `json:"height"`
IsStudent bool `json:"is_student"` IsStudent bool `json:"is_student" copier:"-"`
Birthday time.Time `json:"birthday"` Birthday time.Time `json:"birthday"`
Tags []string `json:"tags"` Tags []string `json:"tags"`
Scores []int `json:"scores"` Scores []int `json:"scores"`

View File

@@ -2,6 +2,7 @@ package copier
import ( import (
"fmt" "fmt"
"reflect"
"slices" "slices"
"testing" "testing"
) )
@@ -114,3 +115,55 @@ func TestAppendSlice(t *testing.T) {
} }
} }
func TestCopySliceOfDifferentTypes(t *testing.T) {
var ss []string
var is []int
if err := Copy(&ss, is); err != nil {
t.Error(err)
}
var anotherSs []string
if !reflect.DeepEqual(ss, anotherSs) {
t.Errorf("Copy nil slice to nil slice should get nil slice")
}
t.Log(ss)
}
func TestCopyFromStructToSlice(t *testing.T) {
user := User{Name: "Jinzhu", Age: 18, Role: "Admin", Notes: []string{"hello world"}}
employees := []Employee{}
if err := Copy(employees, &user); err != nil && len(employees) != 0 {
t.Errorf("Copy to unaddressable value should get error")
}
if Copy(&employees, &user); len(employees) != 1 {
t.Errorf("Should only have one elem when copy struct to slice")
} else {
checkEmployee(employees[0], user, t, "Copy From Struct To Slice Ptr")
}
t.Log("ssss", employees)
employees2 := &[]Employee{}
if Copy(&employees2, user); len(*employees2) != 1 {
t.Errorf("Should only have one elem when copy struct to slice")
} else {
checkEmployee((*employees2)[0], user, t, "Copy From Struct To Double Slice Ptr")
}
employees3 := []*Employee{}
if Copy(&employees3, user); len(employees3) != 1 {
t.Errorf("Should only have one elem when copy struct to slice")
} else {
checkEmployee(*(employees3[0]), user, t, "Copy From Struct To Ptr Slice Ptr")
}
employees4 := &[]*Employee{}
if Copy(&employees4, user); len(*employees4) != 1 {
t.Errorf("Should only have one elem when copy struct to slice")
} else {
checkEmployee(*((*employees4)[0]), user, t, "Copy From Struct To Double Ptr Slice Ptr")
}
}

View File

@@ -23,10 +23,74 @@ type Person2 struct {
Age int Age int
} }
type User struct {
Name string
Birthday *time.Time
Nickname string
Role string
Age int32
FakeAge *int32
Notes []string
flags []byte
}
func (user User) DoubleAge() int32 {
return 2 * user.Age
}
type Employee struct { type Employee struct {
Name string _User *User
Age int Name string
Address *Address Birthday *time.Time
NickName *string
Age int64
FakeAge int
EmployeID int64
DoubleAge int32
SuperRule string
Notes []*string
flags []byte
}
func checkEmployee(employee Employee, user User, t *testing.T, testCase string) {
if employee.Name != user.Name {
t.Errorf("%v: Name haven't been copied correctly.", testCase)
}
if employee.NickName == nil || *employee.NickName != user.Nickname {
t.Errorf("%v: NickName haven't been copied correctly.", testCase)
}
if employee.Birthday == nil && user.Birthday != nil {
t.Errorf("%v: Birthday haven't been copied correctly.", testCase)
}
if employee.Birthday != nil && user.Birthday == nil {
t.Errorf("%v: Birthday haven't been copied correctly.", testCase)
}
if employee.Birthday != nil && user.Birthday != nil &&
!employee.Birthday.Equal(*(user.Birthday)) {
t.Errorf("%v: Birthday haven't been copied correctly.", testCase)
}
if employee.Age != int64(user.Age) {
t.Errorf("%v: Age haven't been copied correctly.", testCase)
}
if user.FakeAge != nil && employee.FakeAge != int(*user.FakeAge) {
t.Errorf("%v: FakeAge haven't been copied correctly.", testCase)
}
if employee.DoubleAge != user.DoubleAge() {
t.Errorf("%v: Copy from method doesn't work", testCase)
}
if employee.SuperRule != "Super "+user.Role {
t.Errorf("%v: Copy to method doesn't work", testCase)
}
if len(employee.Notes) != len(user.Notes) {
t.Fatalf("%v: Copy from slice doesn't work, employee notes len: %v, user: %v", testCase, len(employee.Notes), len(user.Notes))
}
for idx, note := range user.Notes {
if note != *employee.Notes[idx] {
t.Fatalf("%v: Copy from slice doesn't work, notes idx: %v employee: %v user: %v", testCase, idx, *employee.Notes[idx], note)
}
}
} }
func TestCopySameStruct(t *testing.T) { func TestCopySameStruct(t *testing.T) {
@@ -305,10 +369,6 @@ func BenchmarkCopy(b *testing.B) {
var emp = Employee{ var emp = Employee{
Name: "John", Name: "John",
Age: 30, Age: 30,
Address: &Address{
Country: "USA",
City: "New York",
},
} }
for b.Loop() { for b.Loop() {
@@ -329,3 +389,47 @@ func BenchmarkDeepFields(b *testing.B) {
} }
} }
func TestEmbeddedAndBase(t *testing.T) {
type Base struct {
BaseField1 int
BaseField2 int
User *User
}
type Embed struct {
EmbedField1 int
EmbedField2 int
Base
}
base := Base{}
embedded := Embed{}
embedded.BaseField1 = 1
embedded.BaseField2 = 2
embedded.EmbedField1 = 3
embedded.EmbedField2 = 4
user := User{
Name: "testName",
}
embedded.User = &user
Copy(&base, &embedded)
if base.BaseField1 != 1 || base.User.Name != "testName" {
t.Error("Embedded fields not copied")
}
base.BaseField1 = 11
base.BaseField2 = 12
user1 := User{
Name: "testName1",
}
base.User = &user1
Copy(&embedded, &base)
if embedded.BaseField1 != 11 || embedded.User.Name != "testName1" {
t.Error("base fields not copied")
}
}

79
deep_field.go Normal file
View File

@@ -0,0 +1,79 @@
package copier
import (
"reflect"
"sync"
)
type fieldsWrapper struct {
Fields []fieldWrapper
once sync.Once
}
type fieldWrapper struct {
Field reflect.StructField
Name string
Format string
Tag *tagOption
}
// type deepFieldCopier struct {
// srcFields []fieldWrapper
// dstFields []fieldWrapper
// }
var deepFieldsMap sync.Map
func (c *copier) deepFields(reflectType reflect.Type) []fieldWrapper {
if wrapper, ok := deepFieldsMap.Load(reflectType); ok {
w := wrapper.(*fieldsWrapper)
w.once.Do(func() {
w.Fields = c.calculateDeepFields(reflectType)
})
return w.Fields
}
wrapper, loaded := deepFieldsMap.LoadOrStore(reflectType, &fieldsWrapper{})
w := wrapper.(*fieldsWrapper)
if !loaded {
w.once.Do(func() {
w.Fields = c.calculateDeepFields(reflectType)
})
} else {
w.once.Do(func() {})
}
return w.Fields
}
func (c *copier) calculateDeepFields(reflectType reflect.Type) []fieldWrapper {
reflectType, _ = indirectType(reflectType)
num := reflectType.NumField()
res := make([]fieldWrapper, 0, num)
if reflectType.Kind() == reflect.Struct {
for i := range num {
sf := reflectType.Field(i)
if sf.PkgPath != "" && !sf.Anonymous {
continue
}
tag := parseTag(sf.Tag.Get(c.opt.tagName))
if tag.Contains(tagIgnore) {
continue
}
if sf.Anonymous {
res = append(res, c.deepFields(sf.Type)...)
} else {
name := sf.Name
if tag.Contains(tagName) {
name = tag.name
}
res = append(res, fieldWrapper{Field: sf, Tag: tag, Name: name})
}
}
}
return res
}

23
tags.go
View File

@@ -11,15 +11,17 @@ type tagt uint8
const ( const (
tagMust tagt = 1 << iota tagMust tagt = 1 << iota
tagIgnore tagIgnore
tagToName tagName
tagFormat
) )
// 标签 copier 取值 must、ignore、toname 等 // 标签 copier 取值 must、ignore、toname 等
// copier:"must,toname=xxx" 必须复制,并重命名为 xxx // copier:"must,name=xxx" 必须复制,并重命名为 xxx
type tagOption struct { type tagOption struct {
flg tagt flg tagt
toname string name string
format string
} }
func parseTag(tag string) *tagOption { func parseTag(tag string) *tagOption {
@@ -28,15 +30,22 @@ func parseTag(tag string) *tagOption {
flg := tagt(0) flg := tagt(0)
for t := range strings.SplitSeq(tag, ",") { for t := range strings.SplitSeq(tag, ",") {
tag, value, found := strings.Cut(t, "=") tag, value, found := strings.Cut(t, "=")
switch tag { lower := strings.ToLower(tag)
switch lower {
case "-", "ignore": case "-", "ignore":
flg |= tagIgnore flg |= tagIgnore
case "must": case "must":
flg |= tagMust flg |= tagMust
case "toname": case "name":
flg |= tagToName flg |= tagName
if found { if found {
opt.toname = value opt.name = value
}
case "format":
flg |= tagFormat
if found {
opt.format = value
} }
} }
} }

31
tags_test.go Normal file
View File

@@ -0,0 +1,31 @@
package copier
import (
"reflect"
"testing"
"time"
)
func TestTags(t *testing.T) {
type TagTest struct {
Tag time.Time `copier:"must,name=xxx,format=2006-01-02 15:04:05" json:"-"`
}
tag := parseTag("must,name=xxx,format=2006-01-02 15:04:05")
if !tag.Contains(tagFormat) {
t.Error("tag format should not be in tag")
}
typ := reflect.TypeOf(TagTest{})
t.Log(typ.NumMethod())
for i, n := 0, typ.NumField(); i < n; i++ {
sf := typ.Field(i)
str := sf.Tag.Get("copier")
tag = parseTag(str)
t.Log(tag)
}
}