1
0
mirror of https://github.com/charlienet/go-mixed.git synced 2025-07-18 00:22:41 +08:00

update redis client

This commit is contained in:
2023-10-12 14:37:10 +08:00
parent 165fc91f9b
commit 95ad0941a8
3 changed files with 254 additions and 129 deletions

View File

@ -1,9 +1,10 @@
package redis package redis
import ( import (
"context" "sync"
"time" "time"
"github.com/charlienet/go-mixed/expr"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@ -15,118 +16,68 @@ const (
defaultSlowThreshold = time.Millisecond * 100 // 慢查询 defaultSlowThreshold = time.Millisecond * 100 // 慢查询
) )
type Option func(r *Redis) var (
once sync.Once
)
type Redis struct { type ReidsOption struct {
addr string // 服务器地址 Addrs []string
prefix string // 键值前缀 Password string // 密码
separator string // 分隔符 Prefix string
Separator string
MaxRetries int
MinRetryBackoff time.Duration
MaxRetryBackoff time.Duration
DialTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
ContextTimeoutEnabled bool
PoolSize int
PoolTimeout time.Duration
MinIdleConns int
MaxIdleConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
} }
type Subscriber struct { type Client struct {
*redis.PubSub redis.UniversalClient
} }
func New(addr string, opts ...Option) *Redis { func New(opt *ReidsOption) Client {
r := &Redis{ var rdb redis.UniversalClient
addr: addr, once.Do(func() {
} rdb = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: opt.Addrs,
Password: opt.Password,
return r MaxRetries: opt.MaxRetries,
} MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff,
func (s *Redis) Set(ctx context.Context, key, value string) error { DialTimeout: opt.DialTimeout,
conn, err := s.getRedis() ReadTimeout: opt.ReadTimeout,
if err != nil { WriteTimeout: opt.WriteTimeout,
return err ContextTimeoutEnabled: opt.ContextTimeoutEnabled,
}
return conn.Set(ctx, s.formatKey(key), value, 0).Err() PoolSize: opt.PoolSize,
} PoolTimeout: opt.PoolTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
})
func (s *Redis) Get(ctx context.Context, key string) (string, error) { if len(opt.Prefix) > 0 {
conn, err := s.getRedis() rdb.AddHook(renameKey{
if err != nil { prefix: opt.Prefix,
return "", err separator: expr.Ternary(len(opt.Separator) == 0, defaultSeparator, opt.Separator),
} })
}
return conn.Get(ctx, s.formatKey(key)).Result()
}
func (s *Redis) GetSet(ctx context.Context, key, value string) (string, error) {
conn, err := s.getRedis()
if err != nil {
return "", err
}
val, err := conn.GetSet(ctx, s.formatKey(key), value).Result()
return val, err
}
func (s *Redis) Del(ctx context.Context, key ...string) (int, error) {
conn, err := s.getRedis()
if err != nil {
return 0, err
}
keys := s.formatKeys(key...)
v, err := conn.Del(ctx, keys...).Result()
if err != nil {
return 0, err
}
return int(v), err
}
func (s *Redis) Subscribe(ctx context.Context, channel string) Subscriber {
conn, err := s.getRedis()
if err != nil {
return Subscriber{}
}
sub := conn.Subscribe(context.Background(), channel)
return Subscriber{sub}
}
func (s *Redis) Publish(ctx context.Context, channel, msg string) *redis.IntCmd {
conn, err := s.getRedis()
if err != nil {
return &redis.IntCmd{}
}
cmd := conn.Publish(ctx, channel, msg)
return cmd
}
func (s *Redis) getRedis() (redis.UniversalClient, error) {
client := redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: []string{s.addr},
}) })
return client, nil return Client{UniversalClient: rdb}
}
func (s *Redis) formatKeys(keys ...string) []string {
// If no prefix is configured, this parameter is returned
if s.prefix == "" {
return keys
}
ret := make([]string, 0, len(keys))
for _, k := range keys {
ret = append(ret, s.formatKey(k))
}
return ret
}
func (s *Redis) formatKey(key string) string {
if s.prefix == "" {
return key
}
return s.prefix + s.separator + key
} }

View File

@ -2,7 +2,9 @@ package redis
import ( import (
"context" "context"
"fmt"
"log" "log"
"sync"
"testing" "testing"
"time" "time"
@ -11,48 +13,127 @@ import (
) )
func TestGetSet(t *testing.T) { func TestGetSet(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client Client) {
ctx := context.Background() ctx := context.Background()
val, err := client.GetSet(ctx, "hello", "world") val, err := client.GetSet(ctx, "hello", "world").Result()
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "", val) assert.Equal(t, "", val)
val, err = client.Get(ctx, "hello") val, err = client.Get(ctx, "hello").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "world", val) assert.Equal(t, "world", val)
val, err = client.GetSet(ctx, "hello", "newworld") val, err = client.GetSet(ctx, "hello", "newworld").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "world", val) assert.Equal(t, "world", val)
val, err = client.Get(ctx, "hello") val, err = client.Get(ctx, "hello").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "newworld", val) assert.Equal(t, "newworld", val)
ret, err := client.Del(ctx, "hello") ret, err := client.Del(ctx, "hello").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, ret) assert.Equal(t, 1, ret)
}) })
} }
func TestRedis_SetGetDel(t *testing.T) { func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client Client) {
ctx := context.Background() ctx := context.Background()
err := client.Set(ctx, "hello", "world") _, err := client.Set(ctx, "hello", "world", 0).Result()
assert.Nil(t, err) assert.Nil(t, err)
val, err := client.Get(ctx, "hello") val, err := client.Get(ctx, "hello").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "world", val) assert.Equal(t, "world", val)
ret, err := client.Del(ctx, "hello") ret, err := client.Del(ctx, "hello").Result()
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, 1, ret) assert.Equal(t, int64(1), ret)
}) })
} }
func runOnRedis(t *testing.T, fn func(client *Redis)) { func TestPubSub(t *testing.T) {
runOnRedis(t, func(client Client) {
ctx := context.Background()
c := "chat"
quit := false
total := 0
mu := &sync.Mutex{}
f := func(wg *sync.WaitGroup) {
wg.Add(1)
var receivedCount int = 0
sub := client.Subscribe(ctx, c)
defer sub.Close()
for {
select {
case <-sub.Channel():
receivedCount++
// case <-quit:
default:
if quit {
mu.Lock()
total += receivedCount
mu.Unlock()
t.Logf("Subscriber received %d message %d", receivedCount, total)
wg.Done()
return
}
}
}
// for msg := range sub.Channel() {
// if strings.EqualFold(msg.Payload, "quit") {
// break
// }
// receivedCount++
// }
}
var wg = &sync.WaitGroup{}
go f(wg)
go f(wg)
go f(wg)
for i := 0; i < 20000; i++ {
n, err := client.Publish(ctx, c, fmt.Sprintf("hello %d", i)).Result()
if err != nil {
t.Log(err)
}
_ = n
// t.Logf("%d clients received the message\n", n)
}
// for i := 0; i < 20; i++ {
// client.Publish(ctx, c, "quit")
// }
t.Log("finished send message")
time.Sleep(time.Second * 5)
quit = true
wg.Wait()
time.Sleep(time.Second * 2)
t.Logf("total received %d message", total)
})
}
func runOnRedis(t *testing.T, fn func(client Client)) {
redis, clean, err := CreateMiniRedis() redis, clean, err := CreateMiniRedis()
assert.Nil(t, err) assert.Nil(t, err)
@ -61,26 +142,28 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
fn(redis) fn(redis)
} }
func CreateMiniRedis() (r *Redis, clean func(), err error) { func CreateMiniRedis() (r Client, clean func(), err error) {
mr, err := miniredis.Run() mr, err := miniredis.Run()
if err != nil { if err != nil {
return nil, nil, err return Client{}, nil, err
} }
addr := mr.Addr() addr := mr.Addr()
log.Println("mini redis run at:", addr) log.Println("mini redis run at:", addr)
return New(addr), func() { return New(&ReidsOption{
ch := make(chan struct{}) Addrs: []string{addr},
}), func() {
ch := make(chan struct{})
go func() { go func() {
mr.Close() mr.Close()
close(ch) close(ch)
}() }()
select { select {
case <-ch: case <-ch:
case <-time.After(time.Second): case <-time.After(time.Second):
} }
}, nil }, nil
} }

91
redis/rename_hook.go Normal file
View File

@ -0,0 +1,91 @@
package redis
import (
"context"
"net"
"strings"
"github.com/redis/go-redis/v9"
)
var (
// sequentials = sets.NewHashSet("RENAME", "RENAMENX", "MGET", "BLPOP", "BRPOP", "RPOPLPUSH", "SDIFFSTORE", "SINTER")
)
type renameKey struct {
prefix string
separator string
}
func (r renameKey) DialHook(next redis.DialHook) redis.DialHook {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return next(ctx, network, addr)
}
}
func (r renameKey) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
return func(ctx context.Context, cmds []redis.Cmder) error {
// 对多个KEY进行更名操作
for i := 0; i < len(cmds); i++ {
r.renameKey(cmds[i])
}
return next(ctx, cmds)
}
}
func (r renameKey) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
return func(ctx context.Context, cmd redis.Cmder) error {
r.renameKey(cmd)
next(ctx, cmd)
return nil
}
}
func (r renameKey) renameKey(cmd redis.Cmder) {
if len(r.prefix) == 0 {
return
}
args := cmd.Args()
if len(args) == 1 {
return
}
switch strings.ToUpper(cmd.Name()) {
case "SELECT":
// 无KEY指令
case "RENAME", "RENAMENX", "MGET", "BLPOP", "BRPOP", "RPOPLPUSH", "SDIFFSTORE", "SINTER", "SINTERSTORE", "SUNIONSTORE":
// 连续KEY
r.rename(args, createSepuence(1, len(args), 1)...)
case "MSET", "MSETNX":
// 间隔KEYKEY位置规则1,3,5,7
r.rename(args, createSepuence(1, len(args), 2)...)
default:
// 默认第一个参数为键值
r.rename(args, 1)
}
}
func (r renameKey) rename(args []any, indexes ...int) {
for _, i := range indexes {
if key, ok := args[i].(string); ok {
var builder strings.Builder
builder.WriteString(r.prefix)
builder.WriteString(r.separator)
builder.WriteString(key)
args[i] = builder.String()
}
}
}
func createSepuence(start, end, step int) []int {
ret := make([]int, 0, (end-start)/step+1)
for i := start; i <= end; i += step {
ret = append(ret, i)
}
return ret
}