This commit is contained in:
2025-09-28 09:15:42 +08:00
parent a6e0c5e740
commit 3482b9b6e5

251
copier.go
View File

@@ -11,6 +11,7 @@ import (
type copier struct {
opt *options
visited map[reflect.Value]struct{}
mu sync.Mutex
}
func newCopier(opts ...option) *copier {
@@ -27,6 +28,8 @@ func newCopier(opts ...option) *copier {
}
func (c *copier) Copy(dst, src any) error {
c.mu.Lock()
defer c.mu.Unlock()
defer c.reset()
if dst == nil || src == nil {
@@ -60,13 +63,16 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
return nil
}
if src.IsZero() {
if !src.IsValid() || src.IsZero() {
return nil
}
if c.isCircularRefs(src) {
return nil
}
switch src.Kind() {
case reflect.Slice, reflect.Array:
// slice -> slice
switch dst.Kind() {
case reflect.Slice, reflect.Array:
return c.copySlice(dst, src, depth)
@@ -76,21 +82,11 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
case reflect.Map:
switch dst.Kind() {
case reflect.Map:
// map -> map
return c.copyMap(dst, src, depth)
case reflect.Struct:
// map -> struct
return c.copyMap2Struct(dst, src, depth)
case reflect.Pointer:
elemType := dst.Type().Elem()
newPtr := reflect.New(elemType)
if err := c.deepCopy(newPtr.Elem(), src, depth+1); err != nil {
return err
}
dst.Set(newPtr)
return nil
return c.copyPointer(dst, src, depth)
default:
return ErrNotSupported
}
@@ -103,70 +99,15 @@ func (c *copier) deepCopy(dst, src reflect.Value, depth int) error {
// struct -> struct
return c.copyStruct(dst, src, depth)
case reflect.Array, reflect.Slice:
if dst.IsNil() {
dstType := dst.Type().Elem()
newDst := reflect.MakeSlice(reflect.SliceOf(dstType), 1, 1)
dst.Set(newDst)
if err := c.deepCopy(dst.Index(0), src, depth+1); err != nil {
return err
}
} else {
len := dst.Len()
cap := dst.Cap()
newSlice := reflect.MakeSlice(reflect.SliceOf(dst.Type().Elem()), len+1, cap+10)
reflect.Copy(newSlice, dst)
newDst := newSlice.Index(len)
if err := c.deepCopy(newDst, src, depth+1); err != nil {
return err
}
dst.Set(newSlice)
}
return nil
return c.copySlice(dst, src, depth)
default:
return c.set(dst, src)
}
case reflect.Pointer:
if dst.Kind() == reflect.Pointer && dst.IsNil() {
if !dst.CanSet() {
return nil
}
p := reflect.New(dst.Type().Elem())
dst.Set(p)
}
if src.Kind() == reflect.Pointer {
src = indirect(src)
}
if dst.Kind() == reflect.Pointer {
dst = dst.Elem()
}
return c.deepCopy(dst, src, depth)
return c.copyPointer(dst, src, depth)
case reflect.Interface:
if src.IsNil() {
return nil
}
if src.Kind() != dst.Kind() {
return c.set(dst, src)
}
src = src.Elem()
newDst := reflect.New(src.Type().Elem())
if err := c.deepCopy(newDst, src, depth); err != nil {
return err
}
dst.Set(newDst)
return nil
return c.copyInterface(dst, src, depth)
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
return ErrNotSupported
default:
@@ -310,32 +251,87 @@ func (c *copier) copyMap(dst, src reflect.Value, depth int) error {
}
func (c *copier) copySlice(dst, src reflect.Value, depth int) error {
var len int
if src.Kind() == reflect.Struct {
len = 0
} else {
len = src.Len()
switch src.Kind() {
case reflect.Struct:
return c.copyStruct2Slice(dst, src, depth)
case reflect.Slice, reflect.Array:
return c.copySlice2Slice(dst, src, depth)
default:
return ErrNotSupported
}
}
func (c *copier) copyStruct2Slice(dst, src reflect.Value, depth int) error {
if dst.Kind() != reflect.Slice {
return ErrNotSupported
}
if dst.Len() > 0 && dst.Len() < src.Len() {
len = dst.Len()
}
// 创建新的slice长度+1
dstLen := dst.Len()
newSlice := reflect.MakeSlice(dst.Type(), dstLen+1, dstLen+1)
if dst.Kind() == reflect.Slice && dst.Len() == 0 && len > 0 {
dstType := dst.Type().Elem()
newDst := reflect.MakeSlice(reflect.SliceOf(dstType), len, len)
dst.Set(newDst)
}
if src.Kind() == reflect.Struct {
if err := c.deepCopy(dst.Index(0), src, depth); err != nil {
// 复制原有元素
for i := 0; i < dstLen; i++ {
if err := c.deepCopy(newSlice.Index(i), dst.Index(i), depth+1); err != nil {
return err
}
} else {
for i := 0; i < len; i++ {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth); err != nil {
return err
}
}
// 追加新元素(结构体)
if err := c.deepCopy(newSlice.Index(dstLen), src, depth+1); err != nil {
return err
}
dst.Set(newSlice)
return nil
}
func (c *copier) copySlice2Slice(dst, src reflect.Value, depth int) error {
if dst.Kind() == reflect.Array {
return c.copyToArray(dst, src, depth)
}
return c.copyToSlice(dst, src, depth)
}
func (c *copier) copyToSlice(dst, src reflect.Value, depth int) error {
srcLen := src.Len()
// 如果目标slice为空或长度不够重新创建
if dst.IsNil() || dst.Len() < srcLen {
newSlice := reflect.MakeSlice(dst.Type(), srcLen, srcLen)
dst.Set(newSlice)
}
// 确定实际复制的元素数量
copyLen := srcLen
if dst.Len() < srcLen {
copyLen = dst.Len()
}
// 复制元素
for i := 0; i < copyLen; i++ {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil {
return err
}
}
return nil
}
func (c *copier) copyToArray(dst, src reflect.Value, depth int) error {
srcLen := src.Len()
dstLen := dst.Len()
// 取较小长度
copyLen := srcLen
if dstLen < srcLen {
copyLen = dstLen
}
for i := 0; i < copyLen; i++ {
if err := c.deepCopy(dst.Index(i), src.Index(i), depth+1); err != nil {
return err
}
}
@@ -344,10 +340,17 @@ func (c *copier) copySlice(dst, src reflect.Value, depth int) error {
func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
if dst.CanSet() {
if _, ok := src.Interface().(time.Time); ok {
if src.Type().AssignableTo(dst.Type()) {
dst.Set(src)
return nil
}
if _, ok := src.Interface().(time.Time); ok {
if dst.Type().AssignableTo(src.Type()) {
dst.Set(src)
return nil
}
}
}
typ := src.Type()
@@ -364,7 +367,7 @@ func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
continue
}
if c.isCricularRefs(dstValue) {
if c.isCircularRefs(dstValue) {
continue
}
@@ -386,6 +389,50 @@ func (c *copier) copyStruct(dst, src reflect.Value, depth int) error {
return nil
}
func (c *copier) copyInterface(dst, src reflect.Value, depth int) error {
if src.IsNil() {
return nil
}
srcElem := src.Elem()
if !srcElem.IsValid() {
return nil
}
if dst.Kind() == reflect.Interface {
newValue := reflect.New(srcElem.Type()).Elem()
if err := c.deepCopy(newValue, srcElem, depth+1); err != nil {
return err
}
dst.Set(newValue)
return nil
}
return c.deepCopy(dst, srcElem, depth+1)
}
func (c *copier) copyPointer(dst, src reflect.Value, depth int) error {
if src.Kind() == reflect.Pointer && src.IsNil() {
return nil
}
srcElem := indirect(src)
if dst.Kind() == reflect.Pointer {
if dst.IsNil() {
if !dst.CanSet() {
return nil
}
newPtr := reflect.New(dst.Type().Elem())
dst.Set(newPtr)
}
return c.deepCopy(dst.Elem(), srcElem, depth+1)
}
return c.deepCopy(dst, srcElem, depth+1)
}
func (c *copier) set(dst, src reflect.Value) error {
if !src.IsValid() {
return ErrInvalidCopyFrom
@@ -632,15 +679,17 @@ func (c *copier) ExceedMaxDepth(depth int) bool {
return c.opt.maxDepth > 0 && depth >= c.opt.maxDepth
}
func (c *copier) isCricularRefs(reflectValue reflect.Value) bool {
if c.opt.detectCircularRefs {
_, ok := c.visited[reflectValue.Addr()]
if !ok {
c.visited[reflectValue.Addr()] = struct{}{}
}
return ok
func (c *copier) isCircularRefs(v reflect.Value) bool {
if !c.opt.detectCircularRefs || !v.IsValid() || !v.CanAddr() {
return false
}
addr := v.Addr()
if _, exists := c.visited[addr]; exists {
return true
}
c.visited[addr] = struct{}{}
return false
}