diff --git a/redis/redis.go b/redis/redis.go index c0e3fa6..01c31a8 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -1,9 +1,10 @@ package redis import ( - "context" + "sync" "time" + "github.com/charlienet/go-mixed/expr" "github.com/redis/go-redis/v9" ) @@ -15,118 +16,68 @@ const ( defaultSlowThreshold = time.Millisecond * 100 // 慢查询 ) -type Option func(r *Redis) +var ( + once sync.Once +) -type Redis struct { - addr string // 服务器地址 - prefix string // 键值前缀 - separator string // 分隔符 +type ReidsOption struct { + Addrs []string + Password string // 密码 + Prefix string + Separator string + + MaxRetries int + MinRetryBackoff time.Duration + MaxRetryBackoff time.Duration + + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + ContextTimeoutEnabled bool + + PoolSize int + PoolTimeout time.Duration + MinIdleConns int + MaxIdleConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration } -type Subscriber struct { - *redis.PubSub +type Client struct { + redis.UniversalClient } -func New(addr string, opts ...Option) *Redis { - r := &Redis{ - addr: addr, - } +func New(opt *ReidsOption) Client { + var rdb redis.UniversalClient + once.Do(func() { + rdb = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: opt.Addrs, + Password: opt.Password, - return r -} + MaxRetries: opt.MaxRetries, + MinRetryBackoff: opt.MinRetryBackoff, + MaxRetryBackoff: opt.MaxRetryBackoff, -func (s *Redis) Set(ctx context.Context, key, value string) error { - conn, err := s.getRedis() - if err != nil { - return err - } + DialTimeout: opt.DialTimeout, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - return conn.Set(ctx, s.formatKey(key), value, 0).Err() -} + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + }) -func (s *Redis) Get(ctx context.Context, key string) (string, error) { - conn, err := s.getRedis() - if err != nil { - return "", err - } - - return conn.Get(ctx, s.formatKey(key)).Result() -} - -func (s *Redis) GetSet(ctx context.Context, key, value string) (string, error) { - conn, err := s.getRedis() - if err != nil { - return "", err - } - - val, err := conn.GetSet(ctx, s.formatKey(key), value).Result() - return val, err -} - -func (s *Redis) Del(ctx context.Context, key ...string) (int, error) { - conn, err := s.getRedis() - if err != nil { - return 0, err - } - - keys := s.formatKeys(key...) - v, err := conn.Del(ctx, keys...).Result() - if err != nil { - return 0, err - } - - return int(v), err -} - -func (s *Redis) Subscribe(ctx context.Context, channel string) Subscriber { - conn, err := s.getRedis() - if err != nil { - return Subscriber{} - } - - sub := conn.Subscribe(context.Background(), channel) - - return Subscriber{sub} -} - -func (s *Redis) Publish(ctx context.Context, channel, msg string) *redis.IntCmd { - - conn, err := s.getRedis() - if err != nil { - return &redis.IntCmd{} - } - - cmd := conn.Publish(ctx, channel, msg) - - return cmd -} - -func (s *Redis) getRedis() (redis.UniversalClient, error) { - client := redis.NewUniversalClient(&redis.UniversalOptions{ - Addrs: []string{s.addr}, + if len(opt.Prefix) > 0 { + rdb.AddHook(renameKey{ + prefix: opt.Prefix, + separator: expr.Ternary(len(opt.Separator) == 0, defaultSeparator, opt.Separator), + }) + } }) - return client, nil -} - -func (s *Redis) formatKeys(keys ...string) []string { - // If no prefix is configured, this parameter is returned - if s.prefix == "" { - return keys - } - - ret := make([]string, 0, len(keys)) - for _, k := range keys { - ret = append(ret, s.formatKey(k)) - } - - return ret -} - -func (s *Redis) formatKey(key string) string { - if s.prefix == "" { - return key - } - - return s.prefix + s.separator + key + return Client{UniversalClient: rdb} } diff --git a/redis/redis_test.go b/redis/redis_test.go index bc5421d..ff73259 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -2,7 +2,9 @@ package redis import ( "context" + "fmt" "log" + "sync" "testing" "time" @@ -11,48 +13,127 @@ import ( ) func TestGetSet(t *testing.T) { - runOnRedis(t, func(client *Redis) { + runOnRedis(t, func(client Client) { ctx := context.Background() - val, err := client.GetSet(ctx, "hello", "world") + val, err := client.GetSet(ctx, "hello", "world").Result() assert.NotNil(t, err) assert.Equal(t, "", val) - val, err = client.Get(ctx, "hello") + val, err = client.Get(ctx, "hello").Result() assert.Nil(t, err) assert.Equal(t, "world", val) - val, err = client.GetSet(ctx, "hello", "newworld") + val, err = client.GetSet(ctx, "hello", "newworld").Result() assert.Nil(t, err) assert.Equal(t, "world", val) - val, err = client.Get(ctx, "hello") + val, err = client.Get(ctx, "hello").Result() assert.Nil(t, err) assert.Equal(t, "newworld", val) - ret, err := client.Del(ctx, "hello") + ret, err := client.Del(ctx, "hello").Result() assert.Nil(t, err) assert.Equal(t, 1, ret) }) } func TestRedis_SetGetDel(t *testing.T) { - runOnRedis(t, func(client *Redis) { + runOnRedis(t, func(client Client) { ctx := context.Background() - err := client.Set(ctx, "hello", "world") + _, err := client.Set(ctx, "hello", "world", 0).Result() assert.Nil(t, err) - val, err := client.Get(ctx, "hello") + val, err := client.Get(ctx, "hello").Result() assert.Nil(t, err) assert.Equal(t, "world", val) - ret, err := client.Del(ctx, "hello") + ret, err := client.Del(ctx, "hello").Result() assert.Nil(t, err) - assert.Equal(t, 1, ret) + assert.Equal(t, int64(1), ret) }) } -func runOnRedis(t *testing.T, fn func(client *Redis)) { +func TestPubSub(t *testing.T) { + runOnRedis(t, func(client Client) { + ctx := context.Background() + + c := "chat" + quit := false + + total := 0 + mu := &sync.Mutex{} + f := func(wg *sync.WaitGroup) { + wg.Add(1) + var receivedCount int = 0 + + sub := client.Subscribe(ctx, c) + defer sub.Close() + + for { + select { + case <-sub.Channel(): + receivedCount++ + // case <-quit: + + default: + if quit { + mu.Lock() + total += receivedCount + mu.Unlock() + + t.Logf("Subscriber received %d message %d", receivedCount, total) + wg.Done() + + return + } + } + } + + // for msg := range sub.Channel() { + // if strings.EqualFold(msg.Payload, "quit") { + // break + // } + + // receivedCount++ + // } + + } + + var wg = &sync.WaitGroup{} + go f(wg) + go f(wg) + go f(wg) + + for i := 0; i < 20000; i++ { + + n, err := client.Publish(ctx, c, fmt.Sprintf("hello %d", i)).Result() + if err != nil { + t.Log(err) + } + + _ = n + + // t.Logf("%d clients received the message\n", n) + } + + // for i := 0; i < 20; i++ { + // client.Publish(ctx, c, "quit") + // } + + t.Log("finished send message") + + time.Sleep(time.Second * 5) + quit = true + + wg.Wait() + + time.Sleep(time.Second * 2) + t.Logf("total received %d message", total) + }) +} + +func runOnRedis(t *testing.T, fn func(client Client)) { redis, clean, err := CreateMiniRedis() assert.Nil(t, err) @@ -61,26 +142,28 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) { fn(redis) } -func CreateMiniRedis() (r *Redis, clean func(), err error) { +func CreateMiniRedis() (r Client, clean func(), err error) { mr, err := miniredis.Run() if err != nil { - return nil, nil, err + return Client{}, nil, err } addr := mr.Addr() log.Println("mini redis run at:", addr) - return New(addr), func() { - ch := make(chan struct{}) + return New(&ReidsOption{ + Addrs: []string{addr}, + }), func() { + ch := make(chan struct{}) - go func() { - mr.Close() - close(ch) - }() + go func() { + mr.Close() + close(ch) + }() - select { - case <-ch: - case <-time.After(time.Second): - } - }, nil + select { + case <-ch: + case <-time.After(time.Second): + } + }, nil } diff --git a/redis/rename_hook.go b/redis/rename_hook.go new file mode 100644 index 0000000..383add1 --- /dev/null +++ b/redis/rename_hook.go @@ -0,0 +1,91 @@ +package redis + +import ( + "context" + "net" + "strings" + + "github.com/redis/go-redis/v9" +) + +var ( +// sequentials = sets.NewHashSet("RENAME", "RENAMENX", "MGET", "BLPOP", "BRPOP", "RPOPLPUSH", "SDIFFSTORE", "SINTER") +) + +type renameKey struct { + prefix string + separator string +} + +func (r renameKey) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return next(ctx, network, addr) + } +} + +func (r renameKey) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + + // 对多个KEY进行更名操作 + for i := 0; i < len(cmds); i++ { + r.renameKey(cmds[i]) + } + + return next(ctx, cmds) + } +} + +func (r renameKey) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + r.renameKey(cmd) + next(ctx, cmd) + + return nil + } +} + +func (r renameKey) renameKey(cmd redis.Cmder) { + if len(r.prefix) == 0 { + return + } + + args := cmd.Args() + if len(args) == 1 { + return + } + + switch strings.ToUpper(cmd.Name()) { + case "SELECT": + // 无KEY指令 + case "RENAME", "RENAMENX", "MGET", "BLPOP", "BRPOP", "RPOPLPUSH", "SDIFFSTORE", "SINTER", "SINTERSTORE", "SUNIONSTORE": + // 连续KEY + r.rename(args, createSepuence(1, len(args), 1)...) + case "MSET", "MSETNX": + // 间隔KEY,KEY位置规则1,3,5,7 + r.rename(args, createSepuence(1, len(args), 2)...) + default: + // 默认第一个参数为键值 + r.rename(args, 1) + } +} + +func (r renameKey) rename(args []any, indexes ...int) { + for _, i := range indexes { + if key, ok := args[i].(string); ok { + var builder strings.Builder + builder.WriteString(r.prefix) + builder.WriteString(r.separator) + builder.WriteString(key) + + args[i] = builder.String() + } + } +} + +func createSepuence(start, end, step int) []int { + ret := make([]int, 0, (end-start)/step+1) + for i := start; i <= end; i += step { + ret = append(ret, i) + } + return ret +}