This commit is contained in:
2025-09-26 18:22:17 +08:00
parent 61f7bccdf1
commit a6e0c5e740
6 changed files with 292 additions and 126 deletions

263
copier.go
View File

@@ -6,16 +6,33 @@ import (
"strings"
"sync"
"time"
"github.com/modern-go/reflect2"
)
func Copy(dst, src any, opts ...option) error {
opt := getOpt(opts...)
return copy(dst, src, opt)
type copier struct {
opt *options
visited map[reflect.Value]struct{}
}
func copy(dst, src any, opt *options) error {
func newCopier(opts ...option) *copier {
opt := getOpt(opts...)
var visited map[reflect.Value]struct{}
if opt.detectCircularRefs {
visited = make(map[reflect.Value]struct{})
}
return &copier{
visited: visited,
opt: opt,
}
}
func (c *copier) Copy(dst, src any) error {
defer c.reset()
if dst == nil || src == nil {
return ErrInvalidCopyDestination
}
var (
srcValue = indirect(reflect.ValueOf(src))
dstValue = indirect(reflect.ValueOf(dst))
@@ -29,36 +46,46 @@ func copy(dst, src any, opt *options) error {
return ErrInvalidCopyDestination
}
return deepCopy(dstValue, srcValue, 0, "", opt)
return c.deepCopy(dstValue, srcValue, 0)
}
func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options) error {
func (c *copier) reset() {
if c.opt.detectCircularRefs {
c.visited = make(map[reflect.Value]struct{})
}
}
func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
if c.ExceedMaxDepth(depth) {
return nil
}
if src.IsZero() {
return nil
}
if opt.ExceedMaxDepth(depth) {
return nil
}
// fmt.Println("do deep copy src:", src.Kind().String(), "dst:", dst.Kind().String())
switch src.Kind() {
case reflect.Slice, reflect.Array:
// slice -> slice
return copySlice(dst, src, depth, fieldName, opt)
switch dst.Kind() {
case reflect.Slice, reflect.Array:
return c.copySlice(dst, src, depth)
default:
return ErrNotSupported
}
case reflect.Map:
switch dst.Kind() {
case reflect.Map:
// map -> map
return copyMap(dst, src, depth, fieldName, opt)
return c.copyMap(dst, src, depth)
case reflect.Struct:
// map -> struct
return copyMap2Struct(dst, src, depth, fieldName, opt)
return c.copyMap2Struct(dst, src, depth)
case reflect.Pointer:
elemType := dst.Type().Elem()
newPtr := reflect.New(elemType)
if err := deepCopy(newPtr.Elem(), src, depth+1, fieldName, opt); err != nil {
if err := c.deepCopy(newPtr.Elem(), src, depth+1); err != nil {
return err
}
@@ -71,17 +98,17 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options)
switch dst.Kind() {
case reflect.Map:
// struct -> map
return copyStruct2Map(dst, src, depth, fieldName, opt)
return c.copyStruct2Map(dst, src, depth)
case reflect.Struct:
// struct -> struct
return copyStruct(dst, src, depth, opt)
return c.copyStruct(dst, src, depth)
case reflect.Array, reflect.Slice:
if dst.IsNil() {
dstType := dst.Type().Elem()
newDst := reflect.MakeSlice(reflect.SliceOf(dstType), 1, 1)
dst.Set(newDst)
if err := deepCopy(dst.Index(0), src, depth+1, fieldName, opt); err != nil {
if err := c.deepCopy(dst.Index(0), src, depth+1); err != nil {
return err
}
@@ -92,7 +119,7 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options)
reflect.Copy(newSlice, dst)
newDst := newSlice.Index(len)
if err := deepCopy(newDst, src, depth+1, fieldName, opt); err != nil {
if err := c.deepCopy(newDst, src, depth+1); err != nil {
return err
}
dst.Set(newSlice)
@@ -100,14 +127,9 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options)
return nil
default:
return set(dst, src, fieldName, opt)
return c.set(dst, src)
}
case reflect.Func:
if dst.Kind() == reflect.Func && dst.IsNil() {
dst.Set(src)
}
return nil
case reflect.Pointer:
if dst.Kind() == reflect.Pointer && dst.IsNil() {
if !dst.CanSet() {
@@ -125,7 +147,7 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options)
dst = dst.Elem()
}
return deepCopy(dst, src, depth, fieldName, opt)
return c.deepCopy(dst, src, depth)
case reflect.Interface:
if src.IsNil() {
@@ -133,26 +155,26 @@ func deepCopy(dst, src reflect.Value, depth int, fieldName string, opt *options)
}
if src.Kind() != dst.Kind() {
return set(dst, src, fieldName, opt)
return c.set(dst, src)
}
src = src.Elem()
newDst := reflect.New(src.Type().Elem())
if err := deepCopy(newDst, src, depth, fieldName, opt); err != nil {
if err := c.deepCopy(newDst, src, depth); err != nil {
return err
}
dst.Set(newDst)
return nil
case reflect.Chan:
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
return ErrNotSupported
default:
return set(dst, src, fieldName, opt)
return c.set(dst, src)
}
}
func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *options) error {
func (c *copier) copyStruct2Map(dst, src reflect.Value, depth int) error {
for i, n := 0, src.NumField(); i < n; i++ {
sf := src.Type().Field(i)
if sf.PkgPath != "" && !sf.Anonymous {
@@ -162,7 +184,7 @@ func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *op
if sf.Anonymous {
switch sf.Type.Kind() {
case reflect.Struct, reflect.Pointer:
if err := deepCopy(dst, src.Field(i), depth+1, fieldName, opt); err != nil {
if err := c.deepCopy(dst, src.Field(i), depth+1); err != nil {
return err
}
}
@@ -170,16 +192,16 @@ func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *op
continue
}
tag := parseTag(sf.Tag.Get(opt.tagName))
tag := parseTag(sf.Tag.Get(c.opt.tagName))
if tag.Contains(tagIgnore) {
continue
}
name := getFieldName(sf.Name, tag, opt)
name := c.getFieldName(sf.Name, tag)
sField := src.Field(i)
if opt.ignoreEmpty && sField.IsZero() {
if c.opt.ignoreEmpty && sField.IsZero() {
continue
}
@@ -187,7 +209,7 @@ func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *op
var newDst = reflect.ValueOf(make(map[string]any))
sField = indirect(sField)
if err := deepCopy(newDst, sField, depth+1, name, opt); err != nil {
if err := c.deepCopy(newDst, sField, depth+1); err != nil {
return err
}
@@ -200,11 +222,9 @@ func copyStruct2Map(dst, src reflect.Value, depth int, fieldName string, opt *op
return nil
}
func copyMap2Struct(dst, src reflect.Value, depth int, _ string, opt *options) error {
// 循环结构然后从map取值
func (c *copier) copyMap2Struct(dst, src reflect.Value, depth int) error {
typ := dst.Type()
for i, n := 0, src.NumField(); i < n; i++ {
for i, n := 0, dst.NumField(); i < n; i++ {
field := dst.Field(i)
sf := typ.Field(i)
@@ -213,15 +233,12 @@ func copyMap2Struct(dst, src reflect.Value, depth int, _ string, opt *options) e
}
if sf.Anonymous {
if err := deepCopy(field, src, depth+1, sf.Name, opt); err != nil {
if err := c.deepCopy(field, src, depth+1); err != nil {
return err
}
}
name := getFieldNameFromJsonTag(sf)
if name == "-" {
continue
}
name := sf.Name
mapValue := src.MapIndex(reflect.ValueOf(name))
if !mapValue.IsValid() {
@@ -233,11 +250,11 @@ func copyMap2Struct(dst, src reflect.Value, depth int, _ string, opt *options) e
}
if mapValue.Kind() == reflect.Map || mapValue.Kind() == reflect.Array || mapValue.Kind() == reflect.Slice {
if err := deepCopy(field, mapValue, depth+1, name, opt); err != nil {
if err := c.deepCopy(field, mapValue, depth+1); err != nil {
return err
}
} else {
if err := set(field, mapValue, name, opt); err != nil {
if err := c.set(field, mapValue); err != nil {
return err
}
}
@@ -246,7 +263,7 @@ func copyMap2Struct(dst, src reflect.Value, depth int, _ string, opt *options) e
return nil
}
func copyMap(dst, src reflect.Value, depth int, fieldName string, opt *options) error {
func (c *copier) copyMap(dst, src reflect.Value, depth int) error {
if dst.IsNil() {
dst.Set(reflect.MakeMapWithSize(dst.Type(), src.Len()))
}
@@ -259,11 +276,11 @@ func copyMap(dst, src reflect.Value, depth int, fieldName string, opt *options)
value := iter.Value()
if name, ok := key.Interface().(string); ok {
fieldName = opt.NameConvert(name)
fieldName := c.getFieldName(name, nil)
key = reflect.ValueOf(fieldName)
}
if opt.ignoreEmpty && value.IsZero() {
if c.opt.ignoreEmpty && value.IsZero() {
continue
}
@@ -277,12 +294,12 @@ func copyMap(dst, src reflect.Value, depth int, fieldName string, opt *options)
copitedValue = reflect.New(value.Type()).Elem()
}
set(copitedValue, value, fieldName, opt)
c.set(copitedValue, value)
default:
copitedValue = reflect.New(dstType.Elem()).Elem()
}
if err := deepCopy(copitedValue, value, depth+1, fieldName, opt); err != nil {
if err := c.deepCopy(copitedValue, value, depth+1); err != nil {
return err
}
@@ -292,28 +309,40 @@ func copyMap(dst, src reflect.Value, depth int, fieldName string, opt *options)
return nil
}
func copySlice(dst, src reflect.Value, depth int, fieldName string, opt *options) error {
len := src.Len()
func (c *copier) copySlice(dst, src reflect.Value, depth int) error {
var len int
if src.Kind() == reflect.Struct {
len = 0
} else {
len = src.Len()
}
if dst.Len() > 0 && dst.Len() < src.Len() {
len = dst.Len()
}
if dst.Kind() == reflect.Slice && dst.Len() == 0 && src.Len() > 0 {
if dst.Kind() == reflect.Slice && dst.Len() == 0 && len > 0 {
dstType := dst.Type().Elem()
newDst := reflect.MakeSlice(reflect.SliceOf(dstType), len, len)
dst.Set(newDst)
}
for i := 0; i < len; i++ {
if err := deepCopy(dst.Index(i), src.Index(i), depth, fieldName, opt); err != nil {
if src.Kind() == reflect.Struct {
if err := c.deepCopy(dst.Index(0), src, depth); err != nil {
return err
}
} else {
for i := 0; i < len; i++ {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth); err != nil {
return err
}
}
}
return nil
}
func copyStruct(dst, src reflect.Value, depth int, opt *options) error {
func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
if dst.CanSet() {
if _, ok := src.Interface().(time.Time); ok {
dst.Set(src)
@@ -322,41 +351,42 @@ func copyStruct(dst, src reflect.Value, depth int, opt *options) error {
}
typ := src.Type()
fields := deepFields(typ)
fields := c.deepFields(typ)
for _, sf := range fields {
tag := parseTag(sf.Tag.Get(opt.tagName))
if tag.Contains(tagIgnore) {
continue
}
name := getFieldName(sf.Name, tag, opt)
name := c.getFieldName(sf.Field.Name, sf.Tag)
if nestedAnonymousField(dst, name) {
continue
}
typ := reflect2.TypeOf(dst)
dstPtr := reflect2.PtrOf(dst)
typ.(reflect2.StructType).Field(1).UnsafeGet(dstPtr)
dstValue := fieldByName(dst, name, opt)
dstValue := c.fieldByName(dst, name)
if !dstValue.IsValid() {
continue
}
sField := src.FieldByName(sf.Name)
if opt.ignoreEmpty && sField.IsZero() {
if c.isCricularRefs(dstValue) {
continue
}
if err := deepCopy(dstValue, sField, depth+1, name, opt); err != nil {
sField := src.FieldByName(sf.Field.Name)
if c.opt.ignoreEmpty && sField.IsZero() {
continue
}
if ok, err := c.lookupAndCopyWithConverter(dstValue, sField, name); err != nil {
return err
} else if ok {
return nil
}
if err := c.deepCopy(dstValue, sField, depth+1); err != nil {
return err
}
}
return nil
}
func set(dst, src reflect.Value, fieldName string, opt *options) error {
func (c *copier) set(dst, src reflect.Value) error {
if !src.IsValid() {
return ErrInvalidCopyFrom
}
@@ -365,11 +395,11 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error {
src = src.Elem()
}
if ok, err := lookupAndCopyWithConverter(dst, src, fieldName, opt); err != nil {
return err
} else if ok {
return nil
}
// if ok, err := c.lookupAndCopyWithConverter(dst, src, fieldName); err != nil {
// return err
// } else if ok {
// return nil
// }
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
@@ -465,17 +495,23 @@ func set(dst, src reflect.Value, fieldName string, opt *options) error {
}
type fieldsWrapper struct {
Fields []reflect.StructField
Fields []fieldWrapper
once sync.Once
}
type fieldWrapper struct {
Field reflect.StructField
Tag *tagOption
Index int
}
var deepFieldsMap sync.Map
func deepFields(reflectType reflect.Type) []reflect.StructField {
func (c *copier) deepFields(reflectType reflect.Type) []fieldWrapper {
if wrapper, ok := deepFieldsMap.Load(reflectType); ok {
w := wrapper.(*fieldsWrapper)
w.once.Do(func() {
w.Fields = calculateDeepFields(reflectType)
w.Fields = c.calculateDeepFields(reflectType)
})
return w.Fields
}
@@ -484,7 +520,7 @@ func deepFields(reflectType reflect.Type) []reflect.StructField {
w := wrapper.(*fieldsWrapper)
if !loaded {
w.once.Do(func() {
w.Fields = calculateDeepFields(reflectType)
w.Fields = c.calculateDeepFields(reflectType)
})
} else {
w.once.Do(func() {})
@@ -493,10 +529,10 @@ func deepFields(reflectType reflect.Type) []reflect.StructField {
return w.Fields
}
func calculateDeepFields(reflectType reflect.Type) []reflect.StructField {
func (c *copier) calculateDeepFields(reflectType reflect.Type) []fieldWrapper {
reflectType, _ = indirectType(reflectType)
num := reflectType.NumField()
res := make([]reflect.StructField, 0, num)
res := make([]fieldWrapper, 0, num)
if reflectType.Kind() == reflect.Struct {
for i := range num {
sf := reflectType.Field(i)
@@ -504,11 +540,13 @@ func calculateDeepFields(reflectType reflect.Type) []reflect.StructField {
continue
}
tag := parseTag(sf.Tag.Get(c.opt.tagName))
if sf.Anonymous {
res = append(res, deepFields(sf.Type)...)
res = append(res, c.deepFields(sf.Type)...)
}
res = append(res, sf)
res = append(res, fieldWrapper{Field: sf, Tag: tag, Index: i})
}
}
@@ -540,18 +578,18 @@ func nestedAnonymousField(dst reflect.Value, fieldName string) bool {
return false
}
func lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string, opt *options) (bool, error) {
if cnv, ok := opt.convertByName[fieldName]; ok {
func (c *copier) lookupAndCopyWithConverter(dst, src reflect.Value, fieldName string) (bool, error) {
if cnv, ok := c.opt.convertByName[fieldName]; ok {
return convert(dst, src, cnv)
}
if len(opt.converters) > 0 {
if len(c.opt.converters) > 0 {
pair := converterPair{
SrcType: src.Type(),
DstType: dst.Type(),
}
if cnv, ok := opt.converters[pair]; ok {
if cnv, ok := c.opt.converters[pair]; ok {
return convert(dst, src, cnv)
}
}
@@ -575,36 +613,37 @@ func convert(dst, src reflect.Value, fn convertFunc) (bool, error) {
}
func getFieldNameFromJsonTag(field reflect.StructField) string {
if tag := field.Tag.Get("json"); tag != "" {
if commaIndex := strings.Index(tag, ","); commaIndex != -1 {
name := tag[:commaIndex]
if name != "" {
return name
}
} else if tag != "" {
return tag
}
}
return field.Name
}
func getFieldName(name string, tag *tagOption, opt *options) string {
func (c *copier) getFieldName(name string, tag *tagOption) string {
if tag != nil && tag.Contains(tagToName) {
return tag.toname
}
return opt.NameConvert(name)
return c.opt.NameConvert(name)
}
func fieldByName(v reflect.Value, name string, opt *options) reflect.Value {
if opt != nil && opt.caseSensitive {
func (c *copier) fieldByName(v reflect.Value, name string) reflect.Value {
if c.opt != nil && c.opt.caseSensitive {
return v.FieldByName(name)
}
return v.FieldByNameFunc(func(s string) bool { return strings.EqualFold(s, name) })
}
func (c *copier) ExceedMaxDepth(depth int) bool {
return c.opt.maxDepth > 0 && depth >= c.opt.maxDepth
}
func (c *copier) isCricularRefs(reflectValue reflect.Value) bool {
if c.opt.detectCircularRefs {
_, ok := c.visited[reflectValue.Addr()]
if !ok {
c.visited[reflectValue.Addr()] = struct{}{}
}
return ok
}
return false
}
func indirect(reflectValue reflect.Value) reflect.Value {
for reflectValue.Kind() == reflect.Pointer {
reflectValue = reflectValue.Elem()

View File

@@ -75,3 +75,42 @@ func TestAndSlice(t *testing.T) {
t.Errorf("Expected %v, got %v", []string{"developer", "golang", "backend"}, dst)
}
}
func TestStructAppendSlice(t *testing.T) {
type person struct {
Name string
Age int
}
src := person{Name: "Alice", Age: 20}
dst := []person{
{Name: "Charlie", Age: 40},
}
if err := Copy(&dst, src); err != nil {
t.Fatal(err)
}
if len(dst) != 2 {
t.Errorf("Expected %d elements, got %d", 2, len(dst))
}
t.Log(dst)
}
func TestAppendSlice(t *testing.T) {
src := []string{"developer", "golang", "backend"}
dst := []any{"do"}
if err := Copy(&dst, src); err != nil {
t.Fatal(err)
}
t.Log(dst)
if !slices.Equal([]any{"do", "developer", "golang", "backend"}, dst) {
t.Errorf("Expected %v, got %v", []string{"do", "developer", "golang", "backend"}, dst)
}
}

View File

@@ -2,6 +2,7 @@ package copier
import (
"fmt"
"reflect"
"testing"
"time"
)
@@ -22,6 +23,12 @@ type Person2 struct {
Age int
}
type Employee struct {
Name string
Age int
Address *Address
}
func TestCopySameStruct(t *testing.T) {
src := &Person{
Name: "John",
@@ -282,3 +289,37 @@ func TestTypeConvert(t *testing.T) {
}
t.Log("src,dst:", src, dst)
}
func TestDeepFileds(t *testing.T) {
var emp = Employee{}
typ := reflect.TypeOf(emp)
copier := newCopier()
wrapper := copier.deepFields(typ)
for _, v := range wrapper {
t.Log(v, v.Index)
}
}
func BenchmarkCopy(b *testing.B) {
var emp = Employee{
Name: "John",
Age: 30,
}
for b.Loop() {
var dst Employee
Copy(&dst, &emp)
}
}
func BenchmarkDeepFields(b *testing.B) {
var emp = Employee{}
typ := reflect.TypeOf(emp)
for b.Loop() {
copier := newCopier()
copier.deepFields(typ)
}
}

20
deepcopier.go Normal file
View File

@@ -0,0 +1,20 @@
package copier
import "reflect"
func Copy(dst, src any, opts ...option) error {
copier := newCopier(opts...)
return copier.Copy(dst, src)
}
func Clone(src any, opts ...option) (any, error) {
if src == nil {
return nil, nil
}
srcType := reflect.TypeOf(src)
dst := reflect.New(srcType).Interface()
err := newCopier(opts...).Copy(dst, src)
return dst, err
}

2
go.mod
View File

@@ -1,3 +1,3 @@
module git.charlienet.top/go/copier
go 1.25
go 1.24

View File

@@ -14,10 +14,11 @@ type options struct {
ignoreEmpty bool // 复制时忽略空字段
caseSensitive bool // 复制时大小写敏感
must bool // 只复制具有must标识的字段
detectCircularRefs bool // 检测循环引用
fieldNameMapping map[string]string // 字段名转映射
nameConverter func(string) string // 字段名转换函数
converters map[converterPair]convertFunc // 根据源和目标类型处理的类型转换器v
convertByName map[string]convertFunc // 根据名称处理的类型转换器
nameConverter func(string) string // 字段名转换函数
}
type option func(*options)
@@ -39,8 +40,6 @@ func getOpt(opts ...option) *options {
opt := &options{
maxDepth: noDepthLimited,
tagName: defaultTag,
convertByName: make(map[string]convertFunc),
converters: make(map[converterPair]convertFunc),
}
for _, o := range opts {
@@ -94,16 +93,44 @@ func WithMust() option {
}
}
// WithDetectCircularRefs 添加检测循环引用标识
func WithDetectCircularRefs() option {
return func(o *options) {
o.detectCircularRefs = true
}
}
// WithNameConvertByName 添加根据名称处理的类型转换器
func WithTypeConvertByName(name string, f func(src any) (dst any, err error)) option {
return func(o *options) {
if o.convertByName == nil {
o.convertByName = make(map[string]convertFunc)
}
o.convertByName[name] = f
}
}
// WithConvertByNames 添加根据名称处理的类型转换器
func WithTypeConvertByNames(nameConvert map[string]func(src any) (dst any, err error)) option {
return func(o *options) {
if o.convertByName == nil {
o.convertByName = make(map[string]convertFunc)
}
for name, nameConvert := range nameConvert {
o.convertByName[name] = nameConvert
}
}
}
// WithTypeConvert 添加根据源和目标类型处理的类型转换器
func WithConverters(converters ...TypeConverter) option {
return func(o *options) {
if len(converters) > 0 && o.converters == nil {
o.converters = make(map[converterPair]convertFunc)
}
for i := range converters {
pair := converterPair{
SrcType: reflect.TypeOf(converters[i].SrcType),