Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions cacheaside.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"iter"
"log/slog"
"maps"
"strconv"
"strings"
Expand All @@ -23,27 +24,38 @@ type lockEntry struct {
cancel context.CancelFunc
}

type Logger interface {
Error(msg string, args ...any)
}

type CacheAside struct {
client rueidis.Client
locks syncx.Map[string, *lockEntry]
lockTTL time.Duration
logger Logger
}

type CacheAsideOption struct {
// LockTTL is the maximum time a lock can be held, and also the timeout for waiting
// on locks when handling lost Redis invalidation messages. Defaults to 10 seconds.
LockTTL time.Duration
ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error)
// Logger for logging non-fatal errors. Defaults to slog.Default().
Logger Logger
}

func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*CacheAside, error) {
var err error
if caOption.LockTTL == 0 {
caOption.LockTTL = 10 * time.Second
}
if caOption.Logger == nil {
caOption.Logger = slog.Default()
}

rca := &CacheAside{
lockTTL: caOption.LockTTL,
logger: caOption.Logger,
}
clientOption.OnInvalidations = rca.onInvalidate
if caOption.ClientBuilder != nil {
Expand Down Expand Up @@ -79,37 +91,37 @@ var (
)

func (rca *CacheAside) register(key string) <-chan struct{} {
// Try to load existing entry first
if entry, loaded := rca.locks.Load(key); loaded {
// Check if the context is still active (not cancelled/timed out)
select {
case <-entry.ctx.Done():
// Context is done - clean it up and create a new one
rca.locks.Delete(key)
default:
// Context is still active - use it
return entry.ctx.Done()
}
}

retry:
// Create new entry with context that auto-cancels after lockTTL
ctx, cancel := context.WithTimeout(context.Background(), rca.lockTTL)

entry := &lockEntry{
newEntry := &lockEntry{
ctx: ctx,
cancel: cancel,
}

// Store or get existing entry atomically
actual, _ := rca.locks.LoadOrStore(key, entry)
actual, loaded := rca.locks.LoadOrStore(key, newEntry)

// If another goroutine stored first, cancel our context and use theirs
if actual != entry {
cancel()
return actual.ctx.Done()
// If we successfully stored, return our context
if !loaded {
return ctx.Done()
}

return ctx.Done()
// Another goroutine stored first, cancel our context
cancel()

// Check if their context is still active (not cancelled/timed out)
select {
case <-actual.ctx.Done():
// Context is done - try to atomically delete it and retry
// If CompareAndDelete fails, another goroutine already replaced it
rca.locks.CompareAndDelete(key, actual)
goto retry
default:
// Context is still active - use it
return actual.ctx.Done()
}
}

func (rca *CacheAside) Get(
Expand Down Expand Up @@ -196,7 +208,9 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key
toCtx, cancel := context.WithTimeout(context.Background(), rca.lockTTL)
defer cancel()
// Best effort unlock - errors are non-fatal as lock will expire
_ = rca.unlock(toCtx, key, lockVal)
if err := rca.unlock(toCtx, key, lockVal); err != nil {
rca.logger.Error("failed to unlock key", "key", key, "error", err)
}
}
}()
if val, err = fn(ctx, key); err == nil {
Expand Down Expand Up @@ -482,7 +496,12 @@ func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]stri
go func() {
defer wg.Done()
// Best effort unlock - errors are non-fatal as locks will expire
_ = delKeyLua.ExecMulti(ctx, rca.client, stmts...)
resps := delKeyLua.ExecMulti(ctx, rca.client, stmts...)
for _, resp := range resps {
if err := resp.Error(); err != nil {
rca.logger.Error("failed to unlock key in batch", "error", err)
}
}
}()
}
wg.Wait()
Expand Down
Loading