195 lines
4.5 KiB
Go
195 lines
4.5 KiB
Go
package copier
|
|
|
|
import "reflect"
|
|
|
|
// noDepthLimited is the sentinel maxDepth value meaning that no depth limit
// is applied.
const noDepthLimited = -1
|
|
|
|
// convertFunc converts a source value into a destination value, or reports
// an error.
type convertFunc func(src any) (dst any, err error)
|
|
|
|
type options struct {
|
|
tagName string // 标签名
|
|
maxDepth int // 最大复制深度
|
|
ignoreEmpty bool // 复制时忽略空字段
|
|
caseSensitive bool // 复制时大小写敏感
|
|
must bool // 只复制具有must标识的字段
|
|
detectCircularRefs bool // 检测循环引用
|
|
fieldNameMapping map[string]string // 字段名转映射
|
|
converters map[converterPair]convertFunc // 根据源和目标类型处理的类型转换器v
|
|
convertByName map[string]convertFunc // 根据名称处理的类型转换器
|
|
nameConverter func(string) string // 字段名转换函数
|
|
timeFormats map[reflect.Type]string // 时间格式
|
|
}
|
|
|
|
// option mutates an options value; the With* constructors return one.
type option func(o *options)
|
|
|
|
// TypeConverter describes a custom conversion registered via WithConverters.
type TypeConverter struct {
	// FieldName is a name associated with the converter.
	// NOTE(review): WithConverters does not read this field — confirm
	// whether it is consumed elsewhere.
	FieldName string
	// SrcType and DstType are sample values; WithConverters keys the
	// converter by their dynamic types (via reflect.TypeOf).
	SrcType, DstType any
	// Fn performs the conversion from a source value to a destination value.
	Fn func(src any) (dst any, err error)
}
|
|
|
|
// converterPair is the comparable key used for the options.converters map.
type converterPair struct {
	// Name is an optional name component of the key.
	// NOTE(review): WithConverters leaves this empty — confirm intended use.
	Name string
	// SrcType and DstType are the source and destination types of the
	// conversion.
	SrcType, DstType reflect.Type
}
|
|
|
|
func getOpt(opts ...option) *options {
|
|
opt := &options{
|
|
maxDepth: noDepthLimited,
|
|
tagName: defaultTag,
|
|
}
|
|
|
|
for _, o := range opts {
|
|
o(opt)
|
|
}
|
|
|
|
return opt
|
|
}
|
|
|
|
func (opt *options) NameConvert(name string) string {
|
|
if opt.nameConverter != nil {
|
|
name = opt.nameConverter(name)
|
|
}
|
|
|
|
if toname, ok := opt.fieldNameMapping[name]; ok {
|
|
name = toname
|
|
}
|
|
|
|
return name
|
|
}
|
|
|
|
func (opt *options) ExceedMaxDepth(depth int) bool {
|
|
return opt.maxDepth != noDepthLimited && depth > opt.maxDepth
|
|
}
|
|
|
|
// WithMaxDepth 设置最大复制深度
|
|
func WithMaxDepth(depth int) option {
|
|
return func(o *options) {
|
|
o.maxDepth = depth
|
|
}
|
|
}
|
|
|
|
// WithIgnoreEmpty 复制时忽略空字段标识
|
|
func WithIgnoreEmpty() option {
|
|
return func(o *options) {
|
|
o.ignoreEmpty = true
|
|
}
|
|
}
|
|
|
|
// WithCaseInsensitive 复制时大小写不敏感标识
|
|
func WithCaseSensitive() option {
|
|
return func(o *options) {
|
|
o.caseSensitive = true
|
|
}
|
|
}
|
|
|
|
// WithMust 添加必须复制标识
|
|
func WithMust() option {
|
|
return func(o *options) {
|
|
o.must = true
|
|
}
|
|
}
|
|
|
|
// WithDetectCircularRefs 添加检测循环引用标识
|
|
func WithDetectCircularRefs() option {
|
|
return func(o *options) {
|
|
o.detectCircularRefs = true
|
|
}
|
|
}
|
|
|
|
// WithNameConvertByName 添加根据名称处理的类型转换器
|
|
func WithTypeConvertByName(name string, f func(src any) (dst any, err error)) option {
|
|
return func(o *options) {
|
|
if o.convertByName == nil {
|
|
o.convertByName = make(map[string]convertFunc)
|
|
}
|
|
|
|
o.convertByName[name] = f
|
|
}
|
|
}
|
|
|
|
// WithConvertByNames 添加根据名称处理的类型转换器
|
|
func WithTypeConvertByNames(nameConvert map[string]func(src any) (dst any, err error)) option {
|
|
return func(o *options) {
|
|
if o.convertByName == nil {
|
|
o.convertByName = make(map[string]convertFunc)
|
|
}
|
|
|
|
for name, nameConvert := range nameConvert {
|
|
o.convertByName[name] = nameConvert
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithTypeConvert 添加根据源和目标类型处理的类型转换器
|
|
func WithConverters(converters ...TypeConverter) option {
|
|
return func(o *options) {
|
|
if len(converters) > 0 && o.converters == nil {
|
|
o.converters = make(map[converterPair]convertFunc)
|
|
}
|
|
|
|
for i := range converters {
|
|
pair := converterPair{
|
|
SrcType: reflect.TypeOf(converters[i].SrcType),
|
|
DstType: reflect.TypeOf(converters[i].DstType),
|
|
}
|
|
|
|
o.converters[pair] = converters[i].Fn
|
|
}
|
|
}
|
|
}
|
|
|
|
// WithNameMapping 添加字段名映射
|
|
func WithNameMapping(mappings map[string]string) option {
|
|
return func(o *options) {
|
|
o.fieldNameMapping = mappings
|
|
}
|
|
}
|
|
|
|
// WithNameConverter 添加字段名转换函数
|
|
func WithNameFn(fn func(string) string) option {
|
|
return func(o *options) {
|
|
o.nameConverter = fn
|
|
}
|
|
}
|
|
|
|
func WithTimeFormat(targetType any, format string) option {
|
|
return func(o *options) {
|
|
if o.timeFormats == nil {
|
|
o.timeFormats = make(map[reflect.Type]string)
|
|
}
|
|
|
|
t := reflect.TypeOf(targetType)
|
|
if t.Kind() == reflect.Pointer {
|
|
t = t.Elem()
|
|
}
|
|
|
|
o.timeFormats[t] = format
|
|
}
|
|
}
|
|
|
|
func WithStringTimeFormat(format string) option {
|
|
return WithTimeFormat("", format)
|
|
}
|
|
|
|
func WithDefaultTimeFormat(format string) option {
|
|
return func(o *options) {
|
|
if o.timeFormats == nil {
|
|
o.timeFormats = make(map[reflect.Type]string)
|
|
}
|
|
|
|
o.timeFormats[reflect.TypeOf(nil)] = format
|
|
}
|
|
}
|
|
|
|
// WithTagName 添加标签名
|
|
func WithTagName(tagName string) option {
|
|
return func(o *options) {
|
|
o.tagName = tagName
|
|
}
|
|
}
|