fix map to embedded struct

This commit is contained in:
2025-09-28 17:32:36 +08:00
parent 5c4d428b5e
commit 03941068d8
3 changed files with 38 additions and 15 deletions

View File

@@ -146,16 +146,13 @@ func (c *copier) copySliceToSlice(dst, src reflect.Value, depth int) error {
func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error {
srcLen := src.Len()
// 如果目标slice为空或长度不够重新创建
if dst.IsNil() || dst.Len() < srcLen {
newSlice := reflect.MakeSlice(dst.Type(), srcLen, srcLen)
dst.Set(newSlice)
}
// 确定实际复制的元素数量
copyLen := min(dst.Len(), srcLen)
// 复制元素
for i := range copyLen {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil {
return err
@@ -191,7 +188,8 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
if sf.Field.Anonymous {
switch sf.Field.Type.Kind() {
case reflect.Struct, reflect.Pointer:
if err := c.deepCopy(dst, src.Field(sf.Index), depth+1); err != nil {
field := c.fieldByName(src, sf.Field.Name)
if err := c.deepCopy(dst, field, depth+1); err != nil {
return err
}
}
@@ -206,7 +204,7 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
name := c.getFieldName(sf.Field.Name, tag)
sField := src.Field(sf.Index)
sField := c.fieldByName(src, sf.Field.Name)
if c.opt.ignoreEmpty && sField.IsZero() {
continue
@@ -240,15 +238,29 @@ func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
func (c *copier) copyMap(dst, src reflect.Value, depth int) error {
switch dst.Kind() {
case reflect.Map:
return c.copyMap2Map(dst, src, depth)
return c.copyMapToMap(dst, src, depth)
case reflect.Struct:
return c.copyMapToStruct(dst, src, depth)
case reflect.Pointer:
if dst.IsNil() {
elemType := dst.Type().Elem()
newPtr := reflect.New(elemType)
if err := c.deepCopy(newPtr.Elem(), src, depth+1); err != nil {
return err
}
dst.Set(newPtr)
} else {
if err := c.deepCopy(dst.Elem(), src, depth+1); err != nil {
return err
}
}
return nil
default:
return ErrNotSupported
}
}
func (c *copier) copyMap2Map(dst, src reflect.Value, depth int) error {
func (c *copier) copyMapToMap(dst, src reflect.Value, depth int) error {
if dst.IsNil() {
dst.Set(reflect.MakeMapWithSize(dst.Type(), src.Len()))
}
@@ -302,7 +314,10 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
fields := c.deepFields(dst.Type())
for _, sf := range fields {
field := dst.Field(sf.Index)
if nestedAnonymousField(dst, sf.Field.Name) {
continue
}
field := dst.FieldByName(sf.Field.Name)
if !field.CanSet() {
continue
@@ -315,8 +330,13 @@ func (c *copier) copyMapToStruct(dst, src reflect.Value, depth int) error {
}
name := sf.Field.Name
var mapValue reflect.Value
if c.opt.caseSensitive {
mapValue = src.MapIndex(reflect.ValueOf(name))
} else {
mapValue = src.MapIndex(reflect.ValueOf(strings.ToLower(name)))
}
mapValue := src.MapIndex(reflect.ValueOf(name))
if !mapValue.IsValid() {
continue
}
@@ -592,7 +612,6 @@ type fieldsWrapper struct {
type fieldWrapper struct {
Field reflect.StructField
Tag *tagOption
Index int
}
var deepFieldsMap sync.Map
@@ -634,9 +653,9 @@ func (c *copier) calculateDeepFields(reflectType reflect.Type) []fieldWrapper {
if sf.Anonymous {
res = append(res, c.deepFields(sf.Type)...)
} else {
res = append(res, fieldWrapper{Field: sf, Tag: tag})
}
res = append(res, fieldWrapper{Field: sf, Tag: tag, Index: i})
}
}

View File

@@ -148,7 +148,7 @@ func TestAnonymousStructToMapCopy(t *testing.T) {
}
dst := person2{}
if err := Copy(&dst, src); err != nil {
if err := Copy(&dst, src, WithCaseSensitive()); err != nil {
t.Fatal(err)
}
@@ -174,9 +174,11 @@ func TestMapToStructCopy(t *testing.T) {
}
dst := person{}
if err := Copy(&dst, src); err != nil {
if err := Copy(&dst, src, WithCaseSensitive()); err != nil {
t.Fatal(err)
}
t.Log(dst)
}
func TestMixMapToStruct(t *testing.T) {
@@ -224,6 +226,8 @@ func TestMixMapToStruct(t *testing.T) {
if err := Copy(&dst, src); err != nil {
t.Fatal(err)
}
t.Log(dst)
}
func BenchmarkMapCopy(b *testing.B) {

View File

@@ -297,7 +297,7 @@ func TestDeepFileds(t *testing.T) {
copier := newCopier()
wrapper := copier.deepFields(typ)
for _, v := range wrapper {
t.Log(v, v.Index)
t.Log(v)
}
}