1
0
mirror of https://github.com/charlienet/go-mixed.git synced 2025-07-17 16:12:42 +08:00

添加分布式锁实现-使用Redis

This commit is contained in:
2023-10-12 15:25:13 +08:00
parent 69690da6b4
commit e83db7daee
5 changed files with 307 additions and 81 deletions

View File

@ -1,17 +1,14 @@
package bloom
import (
"log"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/charlienet/go-mixed/redis"
"github.com/stretchr/testify/assert"
"github.com/charlienet/go-mixed/tests"
)
func TestRedisStore(t *testing.T) {
runOnRedis(t, func(client redis.Client) {
tests.RunOnRedis(t, func(client redis.Client) {
store := newRedisStore(client, "abcdef", 10000)
err := store.Set(1, 2, 3, 9, 1223)
if err != nil {
@ -23,38 +20,3 @@ func TestRedisStore(t *testing.T) {
t.Log(store.Test(4, 5, 8))
})
}
func runOnRedis(t *testing.T, fn func(client redis.Client)) {
redis, clean, err := CreateMiniRedis()
assert.Nil(t, err)
defer clean()
fn(redis)
}
func CreateMiniRedis() (r redis.Client, clean func(), err error) {
mr, err := miniredis.Run()
if err != nil {
return nil, nil, err
}
addr := mr.Addr()
log.Println("mini redis run at:", addr)
return redis.New(&redis.ReidsOption{
Addrs: []string{addr},
}), func() {
ch := make(chan struct{})
go func() {
mr.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

View File

@ -0,0 +1,166 @@
package locker
import (
"context"
"errors"
"strings"
"sync"
"time"
"github.com/charlienet/go-mixed/rand"
"github.com/charlienet/go-mixed/redis"
goredis "github.com/redis/go-redis/v9"
)
const (
// 加锁(可重入)
lockCmd = `if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
return "OK"
else
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
end`
// 解锁
delCmd = `if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return '0'
end`
// 延期
incrCmd = `
if redis.call('get', KEYS[1]) == ARGV[1] then
return redis.call('expire', KEYS[1], ARGV[2])
else
return '0'
end`
)
const (
defaultExpire = time.Second * 10
retryInterval = time.Millisecond * 10
)
var (
once sync.Once
ErrContextCancel = errors.New("context cancel")
)
type distributedlock struct {
clients []redis.Client // redis 客户端
ctx context.Context //
key string // 资源键
rand string // 随机值
unlocked bool // 是否已解锁
expire time.Duration // 过期时间
}
func NewDistributedLocker(ctx context.Context, key string, clients ...redis.Client) *distributedlock {
expire := defaultExpire
if deadline, ok := ctx.Deadline(); ok {
expire = deadline.Sub(time.Now())
}
locker := &distributedlock{
ctx: ctx,
clients: clients,
key: key,
rand: rand.Hex.Generate(24),
expire: expire,
}
return locker
}
func (locker *distributedlock) Lock() error {
for {
select {
case <-locker.ctx.Done():
return ErrContextCancel
default:
if locker.TryLock() {
return nil
}
}
time.Sleep(retryInterval)
}
}
func (locker *distributedlock) TryLock() bool {
results := locker.Eval(locker.ctx, lockCmd, []string{locker.key}, locker.rand, locker.expire.Milliseconds())
if !isSuccess(results) {
locker.Unlock()
return false
}
locker.expandLockTime()
return true
}
func (locker *distributedlock) Unlock() {
locker.Eval(locker.ctx, delCmd, []string{locker.key}, locker.rand)
locker.unlocked = true
}
func (l *distributedlock) expandLockTime() {
once.Do(func() {
go func() {
for {
time.Sleep(l.expire / 3)
if l.unlocked {
break
}
l.resetExpire()
}
}()
})
}
func (locker *distributedlock) resetExpire() {
locker.Eval(locker.ctx,
incrCmd,
[]string{locker.key},
locker.rand,
locker.expire.Seconds())
}
func (locker *distributedlock) Eval(ctx context.Context, cmd string, keys []string, args ...any) []*goredis.Cmd {
results := make([]*goredis.Cmd, 0, len(locker.clients))
var wg sync.WaitGroup
wg.Add(len(locker.clients))
for _, rdb := range locker.clients {
go func(rdb redis.Client) {
defer wg.Done()
results = append(results, rdb.Eval(ctx, cmd, keys, args...))
}(rdb)
}
wg.Wait()
return results
}
func isSuccess(results []*goredis.Cmd) bool {
successCount := 0
for _, ret := range results {
resp, err := ret.Result()
if err != nil || resp == nil {
return false
}
reply, ok := resp.(string)
if ok && strings.EqualFold(reply, "OK") {
successCount++
}
}
return successCount >= len(results)/2+1
}

View File

@ -0,0 +1,87 @@
package locker
import (
"context"
"log"
"sync"
"testing"
"time"
"github.com/charlienet/go-mixed/redis"
"github.com/charlienet/go-mixed/tests"
)
func TestDistributedLock(t *testing.T) {
tests.RunOnRedis(t, func(rdb redis.Client) {
lock := NewDistributedLocker(context.Background(), "lock_test", rdb)
lock.Lock()
lock.Unlock()
})
}
func TestConcurrence(t *testing.T) {
tests.RunOnRedis(t, func(rdb redis.Client) {
count := 5
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
locker := NewDistributedLocker(context.Background(), "lock_test", rdb)
for n := 0; n < 5; n++ {
locker.Lock()
t.Logf("协程%d获取到锁", i)
time.Sleep(time.Second)
t.Logf("协程%d释放锁", i)
locker.Unlock()
}
}(i)
}
wg.Wait()
log.Println("所有任务完成")
})
}
func TestTwoLocker(t *testing.T) {
tests.RunOnRedis(t, func(rdb redis.Client) {
l1 := NewDistributedLocker(context.Background(), "lock_test", rdb)
l2 := NewDistributedLocker(context.Background(), "lock_test", rdb)
go func() {
l1.Lock()
println("l1 获取锁")
}()
go func() {
l2.Lock()
println("l2 获取锁")
}()
time.Sleep(time.Second * 20)
l1.Unlock()
l2.Unlock()
})
}
func TestDistributediTryLock(t *testing.T) {
tests.RunOnRedis(t, func(client redis.Client) {
lock := NewDistributedLocker(context.Background(), "lock_test", client)
l := lock.TryLock()
t.Log("尝试加锁结果:", l)
time.Sleep(time.Second * 20)
lock.Unlock()
})
}
func TestLocker(t *testing.T) {
}

View File

@ -1,19 +1,19 @@
package redis
package redis_test
import (
"context"
"fmt"
"log"
"sync"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/charlienet/go-mixed/redis"
"github.com/charlienet/go-mixed/tests"
"github.com/stretchr/testify/assert"
)
func TestGetSet(t *testing.T) {
runOnRedis(t, func(client Client) {
tests.RunOnRedis(t, func(client redis.Client) {
ctx := context.Background()
val, err := client.GetSet(ctx, "hello", "world").Result()
@ -39,7 +39,7 @@ func TestGetSet(t *testing.T) {
}
func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client Client) {
tests.RunOnRedis(t, func(client redis.Client) {
ctx := context.Background()
_, err := client.Set(ctx, "hello", "world", 0).Result()
@ -55,7 +55,7 @@ func TestRedis_SetGetDel(t *testing.T) {
}
func TestPubSub(t *testing.T) {
runOnRedis(t, func(client Client) {
tests.RunOnRedis(t, func(client redis.Client) {
ctx := context.Background()
c := "chat"
@ -132,38 +132,3 @@ func TestPubSub(t *testing.T) {
t.Logf("total received %d message", total)
})
}
func runOnRedis(t *testing.T, fn func(client Client)) {
redis, clean, err := CreateMiniRedis()
assert.Nil(t, err)
defer clean()
fn(redis)
}
func CreateMiniRedis() (r Client, clean func(), err error) {
mr, err := miniredis.Run()
if err != nil {
return nil, nil, err
}
addr := mr.Addr()
log.Println("mini redis run at:", addr)
return New(&ReidsOption{
Addrs: []string{addr},
}), func() {
ch := make(chan struct{})
go func() {
mr.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}

46
tests/redis.go Normal file
View File

@ -0,0 +1,46 @@
package tests
import (
"log"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/charlienet/go-mixed/redis"
"github.com/stretchr/testify/assert"
)
func RunOnRedis(t *testing.T, fn func(client redis.Client)) {
redis, clean, err := createMiniRedis()
assert.Nil(t, err)
defer clean()
fn(redis)
}
func createMiniRedis() (r redis.Client, clean func(), err error) {
mr, err := miniredis.Run()
if err != nil {
return nil, nil, err
}
addr := mr.Addr()
log.Println("mini redis run at:", addr)
return redis.New(&redis.ReidsOption{
Addrs: []string{addr},
}), func() {
ch := make(chan struct{})
go func() {
mr.Close()
close(ch)
}()
select {
case <-ch:
case <-time.After(time.Second):
}
}, nil
}