diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 819d625..0c0a0bd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,9 +16,14 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.22.x' + go-version-file: 'go.mod' - name: Install dependencies run: go mod vendor + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v4 + with: + version: latest + args: --timeout=5m - name: Test with Go run: go test -json > TestResults.json - name: Upload Go test results diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..a2df677 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,162 @@ +# golangci-lint configuration +# Documentation: https://golangci-lint.run/usage/configuration/ + +run: + timeout: 5m + tests: true + modules-download-mode: readonly + +# Output configuration +output: + formats: + - format: colored-line-number + print-issued-lines: true + print-linter-name: true + sort-results: true + +linters: + disable-all: true + enable: + # Enabled by default + - errcheck # Checks for unchecked errors + - gosimple # Simplify code + - govet # Reports suspicious constructs + - ineffassign # Detects ineffectual assignments + - staticcheck # Staticcheck is a go vet on steroids + - unused # Checks for unused constants, variables, functions and types + + # Additional recommended linters + - gofmt # Checks whether code was gofmt-ed + - goimports # Check import statements are formatted according to goimport command + - misspell # Finds commonly misspelled English words + - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go + - gocyclo # Computes cyclomatic complexities + - goconst # Finds repeated strings that could be replaced by a constant + - gosec # Inspects source code for security problems + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - nakedret # Finds naked returns in functions greater than a specified length + - gocognit # Computes cognitive complexities + - godot # Check if comments end in a period + - whitespace # Detection of leading and trailing whitespace + - gci # Controls Go package import order and makes it deterministic + +linters-settings: + goimports: + # Use goimports as the formatter + local-prefixes: github.com/dcbickfo/redcache + + gofmt: + # Simplify code: gofmt with `-s` option + simplify: true + + gocyclo: + # Minimal cyclomatic complexity to report + min-complexity: 15 + + gocognit: + # Minimal cognitive complexity to report + min-complexity: 15 + + goconst: + # Minimal length of string constant + min-len: 3 + # Minimum occurrences to report + min-occurrences: 3 + + gosec: + # Exclude some checks + excludes: + - G104 # Audit errors not checked (we use errcheck for this) + + revive: + confidence: 0.8 + rules: + - name: blank-imports + - name: context-as-argument + - name: context-keys-type + - name: dot-imports + - name: error-return + - name: error-strings + - name: error-naming + - name: exported + - name: if-return + - name: increment-decrement + - name: var-naming + - name: var-declaration + - name: package-comments + - name: range + - name: receiver-naming + - name: time-naming + - name: unexported-return + - name: indent-error-flow + - name: errorf + - name: empty-block + - name: superfluous-else + - name: unused-parameter + - name: unreachable-code + - name: redefines-builtin-id + + nakedret: + # Make an issue if func has more lines of code than this setting and has naked return + max-func-lines: 
30 + + unparam: + # Check exported functions + check-exported: false + + whitespace: + multi-if: false + multi-func: false + + gci: + # Section configuration to compare against + sections: + - standard # Standard section: captures all standard packages + - default # Default section: contains all imports that could not be matched to another section type + - prefix(github.com/dcbickfo/redcache) # Custom section: groups all imports with the specified Prefix + +issues: + # Excluding configuration per-path, per-linter, per-text and per-source + exclude-rules: + # Exclude some linters from running on tests files + - path: _test\.go + linters: + - gocyclo + - gocognit + - errcheck + - gosec + - unparam + - revive + - goconst + - godot + - whitespace + - gci + + # Exclude known issues in vendor + - path: vendor/ + linters: + - all + + # Ignore "new" parameter name shadowing built-in + - text: "redefines-builtin-id" + linters: + - revive + + # Ignore integer overflow in CRC16 - this is intentional and safe + - text: "G115.*integer overflow" + path: internal/cmdx/slot.go + linters: + - gosec + + # Maximum issues count per one linter + max-issues-per-linter: 50 + + # Maximum count of issues with the same text + max-same-issues: 3 + + # Show only new issues + new: false + + # Fix found issues (if it's supported by the linter) + fix: false diff --git a/README.md b/README.md index 184d916..b9a22b5 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ package main import ( "context" "database/sql" - "github.com/google/go-cmp/cmp/internal/value" + "log" "time" "github.com/redis/rueidis" @@ -20,15 +20,15 @@ import ( func main() { if err := run(); err != nil { log.Fatal(err) - } + } } func run() error { - var db sql.DB + var db *sql.DB // initialize db client, err := redcache.NewRedCacheAside( rueidis.ClientOption{ - InitAddress: addr, + InitAddress: []string{"127.0.0.1:6379"}, }, redcache.CacheAsideOption{ LockTTL: time.Second * 1, @@ -37,31 +37,33 @@ func run() error { if err != nil { return err } - + repo := Repository{ client: client, - db: &db, + db: db, } - - val, err := Repository.GetByID(context.Background(), "key") + + val, err := repo.GetByID(context.Background(), "key") if err != nil { - return err + return err } - - vals, err := Repository.GetByIDs(context.Background(), map[string]string{"key1": "val1", "key2": "val2"}) + + vals, err := repo.GetByIDs(context.Background(), []string{"key1", "key2"}) if err != nil { return err } + _, _ = val, vals + return nil } type Repository struct { - client redcache.CacheAside + client *redcache.CacheAside db *sql.DB } func (r Repository) GetByID(ctx context.Context, key string) (string, error) { val, err := r.client.Get(ctx, time.Minute, key, func(ctx context.Context, key string) (val string, err error) { - if err = r.db.QueryRowContext(ctx, "SELECT val FROM mytab WHERE id = ?", key).Scan(&val); err == sql.ErrNoRows { + if err = r.db.QueryRowContext(ctx, "SELECT val FROM mytab WHERE id = ?", key).Scan(&val); err == sql.ErrNoRows { val = "NULL" // cache null to avoid penetration. err = nil // clear err in case of sql.ErrNoRows. } @@ -72,33 +74,37 @@ func (r Repository) GetByID(ctx context.Context, key string) (string, error) { } else if val == "NULL" { val = "" err = sql.ErrNoRows - } - // ...
+ } + return val, err } -func (r Repository) GetByIDs(ctx context.Context, key []string) (map[string]string, error) { - val, err := r.client.GetMulti(ctx, time.Minute, key, func(ctx context.Context, key []string) (val map[string]string, err error) { - rows := db.QueryContext(ctx, "SELECT id, val FROM mytab WHERE id = ?", key) +func (r Repository) GetByIDs(ctx context.Context, keys []string) (map[string]string, error) { + val, err := r.client.GetMulti(ctx, time.Minute, keys, func(ctx context.Context, keys []string) (val map[string]string, err error) { + val = make(map[string]string) + rows, err := r.db.QueryContext(ctx, "SELECT id, val FROM mytab WHERE id IN (?)", keys) // expand IN (?) to one placeholder per key when running against a real driver + if err != nil { + return nil, err + } defer rows.Close() for rows.Next() { var id, rowVal string if err = rows.Scan(&id, &rowVal); err != nil { - return + return nil, err } val[id] = rowVal } - if len(val) != len(key) { - for _, k := range key { + if len(val) != len(keys) { + for _, k := range keys { if _, ok := val[k]; !ok { val[k] = "NULL" // cache null to avoid penetration. } } } - return + return val, nil }) if err != nil { return nil, err - } + } // handle any NULL vals if desired // ... diff --git a/cacheaside.go b/cacheaside.go index ab5eebd..536a2ad 100644 --- a/cacheaside.go +++ b/cacheaside.go @@ -1,8 +1,67 @@ +// Package redcache provides a cache-aside implementation for Redis with distributed locking. +// +// This library builds on the rueidis Redis client to provide: +// - Cache-aside pattern with automatic cache population +// - Distributed locking to prevent thundering herd +// - Client-side caching to reduce Redis round trips +// - Redis cluster support with slot-aware batching +// - Automatic cleanup of expired lock entries +// +// # Basic Usage +// +// client, err := redcache.NewRedCacheAside( +// rueidis.ClientOption{InitAddress: []string{"localhost:6379"}}, +// redcache.CacheAsideOption{LockTTL: 10 * time.Second}, +// ) +// if err != nil { +// return err +// } +// defer client.Client().Close() +// +// // Get a single value with automatic cache population +// value, err := client.Get(ctx, time.Minute, "user:123", func(ctx context.Context, key string) (string, error) { +// return fetchFromDatabase(ctx, key) +// }) +// +// // Get multiple values with batched cache population +// values, err := client.GetMulti(ctx, time.Minute, []string{"user:1", "user:2"}, func(ctx context.Context, keys []string) (map[string]string, error) { +// return fetchMultipleFromDatabase(ctx, keys) +// }) +// +// # Distributed Locking +// +// The library ensures that only one goroutine (across all instances of your application) +// executes the callback function for a given key at a time. Other goroutines will wait +// for the lock to be released and then return the cached value. +// +// Locks are implemented using Redis SET NX with a configurable TTL. Lock values use +// UUIDv7 for uniqueness and are prefixed (default: "__redcache:lock:") to avoid +// collisions with application data. +// +// # Context and Timeouts +// +// All operations respect context cancellation.
The LockTTL option controls: +// - Maximum time a lock can be held before automatic expiration +// - Timeout for waiting on locks when handling invalidation messages +// - Context timeout for cleanup operations +// +// Use context deadlines to control overall operation timeout: +// +// ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +// defer cancel() +// value, err := client.Get(ctx, time.Minute, key, callback) +// +// # Client-Side Caching +// +// The library uses rueidis client-side caching with Redis invalidation messages. +// When a key is modified in Redis, invalidation messages automatically clear the +// local cache, ensuring consistency across distributed instances. package redcache import ( "context" "errors" + "fmt" "iter" "log/slog" "maps" @@ -11,12 +70,13 @@ import ( "sync" "time" - "github.com/dcbickfo/redcache/internal/cmdx" - "github.com/dcbickfo/redcache/internal/mapsx" - "github.com/dcbickfo/redcache/internal/syncx" "github.com/google/uuid" "github.com/redis/rueidis" "golang.org/x/sync/errgroup" + + "github.com/dcbickfo/redcache/internal/cmdx" + "github.com/dcbickfo/redcache/internal/mapsx" + "github.com/dcbickfo/redcache/internal/syncx" ) type lockEntry struct { @@ -24,15 +84,23 @@ type lockEntry struct { cancel context.CancelFunc } +// Logger defines the logging interface used by CacheAside. +// Implementations must be safe for concurrent use and should handle log levels internally. type Logger interface { + // Error logs error messages. Should be used for unexpected failures or critical issues. Error(msg string, args ...any) + // Debug logs detailed diagnostic information useful for development and troubleshooting. + // Call Debug to record verbose output about internal state, cache operations, or lock handling. + // Debug messages should not include sensitive information and may be omitted in production. + Debug(msg string, args ...any) } type CacheAside struct { - client rueidis.Client - locks syncx.Map[string, *lockEntry] - lockTTL time.Duration - logger Logger + client rueidis.Client + locks syncx.Map[string, *lockEntry] + lockTTL time.Duration + logger Logger + lockPrefix string } type CacheAsideOption struct { @@ -40,24 +108,45 @@ type CacheAsideOption struct { // on locks when handling lost Redis invalidation messages. Defaults to 10 seconds. LockTTL time.Duration ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) - // Logger for logging non-fatal errors. Defaults to slog.Default(). + // Logger for logging errors and debug information. Defaults to slog.Default(). + // The logger should handle log levels internally (e.g., only log Debug if level is enabled). Logger Logger + // LockPrefix for distributed locks. Defaults to "__redcache:lock:". + // Choose a prefix unlikely to conflict with your data keys. 
+ LockPrefix string } func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*CacheAside, error) { - var err error + // Validate client options + if len(clientOption.InitAddress) == 0 { + return nil, errors.New("at least one Redis address must be provided in InitAddress") + } + + // Validate and set defaults for cache aside options + if caOption.LockTTL < 0 { + return nil, errors.New("LockTTL must not be negative") + } + if caOption.LockTTL > 0 && caOption.LockTTL < 100*time.Millisecond { + return nil, errors.New("LockTTL should be at least 100ms to avoid excessive lock churn") + } if caOption.LockTTL == 0 { caOption.LockTTL = 10 * time.Second } if caOption.Logger == nil { caOption.Logger = slog.Default() } + if caOption.LockPrefix == "" { + caOption.LockPrefix = "__redcache:lock:" + } rca := &CacheAside{ - lockTTL: caOption.LockTTL, - logger: caOption.Logger, + lockTTL: caOption.LockTTL, + logger: caOption.Logger, + lockPrefix: caOption.LockPrefix, } clientOption.OnInvalidations = rca.onInvalidate + + var err error if caOption.ClientBuilder != nil { rca.client, err = caOption.ClientBuilder(clientOption) } else { @@ -69,13 +158,20 @@ func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOpti return rca, nil } +// Client returns the underlying rueidis.Client for advanced operations. +// Most users should not need direct client access. Use with caution as +// direct operations bypass the cache-aside pattern and distributed locking. func (rca *CacheAside) Client() rueidis.Client { return rca.client } func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { for _, m := range messages { - key, _ := m.ToString() + key, err := m.ToString() + if err != nil { + rca.logger.Error("failed to parse invalidation message", "error", err) + continue + } entry, loaded := rca.locks.LoadAndDelete(key) if loaded { entry.cancel() // Cancel context, which closes the channel @@ -83,8 +179,6 @@ func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { } } -const prefix = "redcache:" - var ( delKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) setKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) @@ -103,21 +197,35 @@ retry: // Store or get existing entry atomically actual, loaded := rca.locks.LoadOrStore(key, newEntry) - // If we successfully stored, return our context + // If we successfully stored, schedule automatic cleanup on expiration if !loaded { + // Use context.AfterFunc to clean up expired entry without blocking goroutine + context.AfterFunc(ctx, func() { + rca.locks.CompareAndDelete(key, newEntry) + }) return ctx.Done() } - // Another goroutine stored first, cancel our context + // Another goroutine stored first, cancel our context to prevent leak cancel() // Check if their context is still active (not cancelled/timed out) select { case <-actual.ctx.Done(): // Context is done - try to atomically delete it and retry - // If CompareAndDelete fails, another goroutine already replaced it - rca.locks.CompareAndDelete(key, actual) - goto retry + if rca.locks.CompareAndDelete(key, actual) { + // We successfully deleted the expired entry, retry + goto retry + } + // CompareAndDelete failed - another goroutine modified it + // Load the new entry and use it + newEntry, loaded := rca.locks.Load(key) + if !loaded { + // Entry was deleted by another goroutine, 
retry registration + goto retry + } + // Use the new entry's context + return newEntry.ctx.Done() default: // Context is still active - use it return actual.ctx.Done() @@ -182,18 +290,30 @@ func (rca *CacheAside) DelMulti(ctx context.Context, keys ...string) error { return nil } -var errNotFound = errors.New("not found") -var errLockFailed = errors.New("lock failed") +var ( + errNotFound = errors.New("not found") + errLockFailed = errors.New("lock failed") +) + +// ErrLockLost indicates the distributed lock was lost or expired before the value could be set. +// This can occur if the lock TTL expires during callback execution or if Redis invalidates the lock. +var ErrLockLost = errors.New("lock was lost or expired before value could be set") func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string) (string, error) { resp := rca.client.DoCache(ctx, rca.client.B().Get().Key(key).Cache(), ttl) val, err := resp.ToString() - if rueidis.IsRedisNil(err) || strings.HasPrefix(val, prefix) { // no response or is a lock value + if rueidis.IsRedisNil(err) || strings.HasPrefix(val, rca.lockPrefix) { // no response or is a lock value + if rueidis.IsRedisNil(err) { + rca.logger.Debug("cache miss - key not found", "key", key) + } else { + rca.logger.Debug("cache miss - lock value found", "key", key) + } return "", errNotFound } if err != nil { return "", err } + rca.logger.Debug("cache hit", "key", key) return val, nil } @@ -205,7 +325,9 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key } defer func() { if !setVal { - toCtx, cancel := context.WithTimeout(context.Background(), rca.lockTTL) + // Use context.WithoutCancel to preserve tracing/request context while allowing cleanup + cleanupCtx := context.WithoutCancel(ctx) + toCtx, cancel := context.WithTimeout(cleanupCtx, rca.lockTTL) defer cancel() // Best effort unlock - errors are non-fatal as lock will expire if err := rca.unlock(toCtx, key, lockVal); err != nil { @@ -226,27 +348,28 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key func (rca *CacheAside) tryLock(ctx context.Context, key string) (string, error) { uuidv7, err := uuid.NewV7() if err != nil { - return "", err + return "", fmt.Errorf("failed to generate lock UUID for key %q: %w", key, err) } - lockVal := prefix + uuidv7.String() + lockVal := rca.lockPrefix + uuidv7.String() err = rca.client.Do(ctx, rca.client.B().Set().Key(key).Value(lockVal).Nx().Get().Px(rca.lockTTL).Build()).Error() if !rueidis.IsRedisNil(err) { - return "", errLockFailed + rca.logger.Debug("lock contention - failed to acquire lock", "key", key) + return "", fmt.Errorf("failed to acquire lock for key %q: %w", key, errLockFailed) } + rca.logger.Debug("lock acquired", "key", key, "lockVal", lockVal) return lockVal, nil } func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key string, valLock valAndLock) (string, error) { - err := setKeyLua.Exec(ctx, rca.client, []string{key}, []string{valLock.lockVal, valLock.val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error() - if err != nil { if !rueidis.IsRedisNil(err) { - return "", err + return "", fmt.Errorf("failed to set value for key %q: %w", key, err) } - return "", errors.New("set failed") + rca.logger.Debug("lock lost during set operation", "key", key) + return "", fmt.Errorf("lock lost for key %q: %w", key, ErrLockLost) } - + rca.logger.Debug("value set successfully", "key", key) return valLock.val, nil } @@ -260,7 +383,6 @@ func (rca *CacheAside) GetMulti( keys 
[]string, fn func(ctx context.Context, key []string) (val map[string]string, err error), ) (map[string]string, error) { - res := make(map[string]string, len(keys)) waitLock := make(map[string]<-chan struct{}, len(keys)) @@ -329,9 +451,9 @@ func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys if err != nil && rueidis.IsRedisNil(err) { continue } else if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get key %q: %w", keys[i], err) } - if !strings.HasPrefix(val, prefix) { + if !strings.HasPrefix(val, rca.lockPrefix) { res[keys[i]] = val continue } @@ -345,7 +467,6 @@ func (rca *CacheAside) trySetMultiKeyFn( keys []string, fn func(ctx context.Context, key []string) (val map[string]string, err error), ) (map[string]string, error) { - res := make(map[string]string) lockVals, err := rca.tryLockMulti(ctx, keys) @@ -361,7 +482,9 @@ func (rca *CacheAside) trySetMultiKeyFn( } } if len(toUnlock) > 0 { - toCtx, cancel := context.WithTimeout(context.Background(), rca.lockTTL) + // Use context.WithoutCancel to preserve tracing/request context while allowing cleanup + cleanupCtx := context.WithoutCancel(ctx) + toCtx, cancel := context.WithTimeout(cleanupCtx, rca.lockTTL) defer cancel() rca.unlockMulti(toCtx, toUnlock) } @@ -393,7 +516,6 @@ func (rca *CacheAside) trySetMultiKeyFn( } return res, err - } func (rca *CacheAside) tryLockMulti(ctx context.Context, keys []string) (map[string]string, error) { @@ -404,13 +526,16 @@ func (rca *CacheAside) tryLockMulti(ctx context.Context, keys []string) (map[str if err != nil { return nil, err } - lockVals[k] = prefix + uuidv7.String() + lockVals[k] = rca.lockPrefix + uuidv7.String() cmds = append(cmds, rca.client.B().Set().Key(k).Value(lockVals[k]).Nx().Get().Px(rca.lockTTL).Build()) } resps := rca.client.DoMulti(ctx, cmds...) for i, r := range resps { err := r.Error() if !rueidis.IsRedisNil(err) { + if err != nil { + rca.logger.Error("failed to acquire lock", "key", keys[i], "error", err) + } delete(lockVals, keys[i]) } } @@ -422,23 +547,18 @@ type valAndLock struct { lockVal string } -func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, keyValLock map[string]valAndLock) ([]string, error) { - type keyOrderAndSet struct { - keyOrder []string - setStmts []rueidis.LuaExec - } +type keyOrderAndSet struct { + keyOrder []string + setStmts []rueidis.LuaExec +} +// groupBySlot groups keys by their Redis cluster slot for efficient batching. +func groupBySlot(keyValLock map[string]valAndLock, ttl time.Duration) map[uint16]keyOrderAndSet { stmts := make(map[uint16]keyOrderAndSet) for k, vl := range keyValLock { slot := cmdx.Slot(k) - kos, ok := stmts[slot] - if !ok { - kos = keyOrderAndSet{ - keyOrder: make([]string, 0), - setStmts: make([]rueidis.LuaExec, 0), - } - } + kos := stmts[slot] kos.keyOrder = append(kos.keyOrder, k) kos.setStmts = append(kos.setStmts, rueidis.LuaExec{ Keys: []string{k}, @@ -447,10 +567,15 @@ func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, stmts[slot] = kos } - out := make([]string, 0) + return stmts +} + +// executeSetStatements executes Lua set statements in parallel, grouped by slot. 
+func (rca *CacheAside) executeSetStatements(ctx context.Context, stmts map[uint16]keyOrderAndSet) ([]string, error) { keyByStmt := make([][]string, len(stmts)) i := 0 eg, ctx := errgroup.WithContext(ctx) + for _, kos := range stmts { ii := i eg.Go(func() error { @@ -467,17 +592,25 @@ func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, } return nil }) - i += 1 + i++ } + if err := eg.Wait(); err != nil { return nil, err } + + out := make([]string, 0) for _, keys := range keyByStmt { out = append(out, keys...) } return out, nil } +func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, keyValLock map[string]valAndLock) ([]string, error) { + stmts := groupBySlot(keyValLock, ttl) + return rca.executeSetStatements(ctx, stmts) +} + func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]string) { if len(lockVals) == 0 { return diff --git a/cacheaside_test.go b/cacheaside_test.go index cf04b56..54d3bec 100644 --- a/cacheaside_test.go +++ b/cacheaside_test.go @@ -9,14 +9,14 @@ import ( "testing" "time" - "github.com/dcbickfo/redcache" - - "github.com/dcbickfo/redcache/internal/mapsx" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/redis/rueidis" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache" + "github.com/dcbickfo/redcache/internal/mapsx" ) var addr = []string{"127.0.0.1:6379"} @@ -191,7 +191,7 @@ func TestCacheAside_GetMulti_PartLock(t *testing.T) { } innerClient := client.Client() - lockVal := "redcache:" + uuid.New().String() + lockVal := "__redcache:lock:" + uuid.New().String() err := innerClient.Do(ctx, innerClient.B().Set().Key(keys[0]).Value(lockVal).Nx().Get().Px(time.Millisecond*100).Build()).Error() require.True(t, rueidis.IsRedisNil(err)) @@ -632,7 +632,7 @@ func TestCacheAside_GetParentContextCancellation(t *testing.T) { // Set a lock on the key so Get will wait innerClient := client.Client() - lockVal := "redcache:" + uuid.New().String() + lockVal := "__redcache:lock:" + uuid.New().String() err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() require.True(t, rueidis.IsRedisNil(err)) @@ -651,3 +651,191 @@ func TestCacheAside_GetParentContextCancellation(t *testing.T) { require.Error(t, err) require.ErrorIs(t, err, context.Canceled) } + +// TestConcurrentRegisterRace tests the register() method under high contention +// to ensure the CompareAndDelete race condition fix works correctly +func TestConcurrentRegisterRace(t *testing.T) { + // Use minimum allowed lock TTL to force lock expirations during concurrent access + client, err := redcache.NewRedCacheAside( + rueidis.ClientOption{ + InitAddress: addr, + }, + redcache.CacheAsideOption{ + LockTTL: 100 * time.Millisecond, + }, + ) + require.NoError(t, err) + defer client.Client().Close() + + ctx := context.Background() + key := "key:" + uuid.New().String() + val := "val:" + uuid.New().String() + + callCount := 0 + var mu sync.Mutex + cb := func(ctx context.Context, key string) (string, error) { + mu.Lock() + callCount++ + mu.Unlock() + // Very short sleep to keep test fast while still triggering some lock expirations + time.Sleep(5 * time.Millisecond) + return val, nil + } + + // Run concurrent goroutines to stress test the register race condition fix + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(4) + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + 
assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + } + wg.Wait() + + // The callback should be called, but we might get multiple calls due to lock expiration + mu.Lock() + defer mu.Unlock() + assert.Greater(t, callCount, 0, "callback should be called at least once") +} + +// TestConcurrentGetSameKeySingleClient tests that multiple goroutines getting +// the same key from a single client instance only triggers one callback when locks don't expire +func TestConcurrentGetSameKeySingleClient(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + + ctx := context.Background() + key := "key:" + uuid.New().String() + val := "val:" + uuid.New().String() + + callCount := 0 + var mu sync.Mutex + + cb := func(ctx context.Context, key string) (string, error) { + mu.Lock() + callCount++ + mu.Unlock() + return val, nil + } + + // Run multiple iterations with concurrent goroutines, matching existing test pattern + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(4) + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + go func() { + defer wg.Done() + res, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + assert.Equal(t, val, res) + }() + } + + wg.Wait() + + // Callback should only be called once due to distributed locking + mu.Lock() + defer mu.Unlock() + assert.Equal(t, 1, callCount, "callback should only be called once") +} + +// TestConcurrentInvalidation tests that cache invalidation works correctly +// when multiple goroutines are accessing the same keys +func TestConcurrentInvalidation(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + + ctx := context.Background() + key := "key:" + uuid.New().String() + + callCount := 0 + var mu sync.Mutex + cb := func(ctx context.Context, key string) (string, error) { + mu.Lock() + callCount++ + mu.Unlock() + return "value", nil + } + + // Populate cache + _, err := client.Get(ctx, time.Second*10, key, cb) + require.NoError(t, err) + + mu.Lock() + initialCount := callCount + mu.Unlock() + + // Delete the key + err = client.Del(ctx, key) + require.NoError(t, err) + + // Run multiple iterations with concurrent reads after deletion, matching existing test pattern + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(4) + go func() { + defer wg.Done() + _, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + }() + go func() { + defer wg.Done() + _, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + }() + go func() { + defer wg.Done() + _, err := client.Get(ctx, time.Second*10, key, cb) + assert.NoError(t, err) + }() + go func() { + defer wg.Done() + _, err := client.Get(ctx, time.Second*10, key, cb) + 
assert.NoError(t, err) + }() + } + wg.Wait() + + // Callback should have been invoked at least once more due to invalidation + mu.Lock() + defer mu.Unlock() + assert.Greater(t, callCount, initialCount, "callbacks should be invoked after invalidation") +} diff --git a/internal/cmdx/slot.go b/internal/cmdx/slot.go index 1351291..2c5ee85 100644 --- a/internal/cmdx/slot.go +++ b/internal/cmdx/slot.go @@ -2,6 +2,11 @@ package cmdx // https://redis.io/topics/cluster-spec +const ( + // RedisClusterSlots is the maximum slot number in a Redis cluster (16384 total slots, numbered 0-16383). + RedisClusterSlots = 16383 +) + func Slot(key string) uint16 { var s, e int for ; s < len(key); s++ { @@ -10,7 +15,7 @@ func Slot(key string) uint16 { } } if s == len(key) { - return crc16(key) & 16383 + return crc16(key) & RedisClusterSlots } for e = s + 1; e < len(key); e++ { if key[e] == '}' { @@ -18,9 +23,9 @@ func Slot(key string) uint16 { } } if e == len(key) || e == s+1 { - return crc16(key) & 16383 + return crc16(key) & RedisClusterSlots } - return crc16(key[s+1:e]) & 16383 + return crc16(key[s+1:e]) & RedisClusterSlots } /* diff --git a/internal/cmdx/slot_test.go b/internal/cmdx/slot_test.go new file mode 100644 index 0000000..54ca2b0 --- /dev/null +++ b/internal/cmdx/slot_test.go @@ -0,0 +1,182 @@ +package cmdx_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dcbickfo/redcache/internal/cmdx" +) + +func TestSlot(t *testing.T) { + tests := []struct { + name string + key string + expected uint16 + }{ + // Basic keys - verified against Redis cluster spec + { + name: "simple key", + key: "key", + expected: 12539, + }, + { + name: "numeric key", + key: "123", + expected: 5970, + }, + { + name: "empty key", + key: "", + expected: 0, + }, + // Hash tags - only the content between { and } is hashed + { + name: "hash tag simple", + key: "{user:1000}:profile", + expected: cmdx.Slot("user:1000"), + }, + { + name: "hash tag at start", + key: "{tag}key", + expected: cmdx.Slot("tag"), + }, + { + name: "hash tag at end", + key: "key{tag}", + expected: cmdx.Slot("tag"), + }, + { + name: "hash tag in middle", + key: "prefix{tag}suffix", + expected: cmdx.Slot("tag"), + }, + // Edge cases with braces + { + name: "empty hash tag", + key: "key{}value", + expected: cmdx.Slot("key{}value"), // Empty tags are ignored + }, + { + name: "no closing brace", + key: "key{value", + expected: cmdx.Slot("key{value"), // No closing brace, whole key hashed + }, + { + name: "only opening brace", + key: "{key", + expected: cmdx.Slot("{key"), + }, + { + name: "only closing brace", + key: "key}", + expected: cmdx.Slot("key}"), + }, + { + name: "multiple hash tags - first wins", + key: "{tag1}{tag2}", + expected: cmdx.Slot("tag1"), + }, + { + name: "nested braces", + key: "{{nested}}", + expected: cmdx.Slot("{nested"), // First { to first } + }, + // Common patterns - these should be deterministic + { + name: "user pattern", + key: "user:1000", + expected: 1649, // Verified against Redis CLUSTER KEYSLOT + }, + { + name: "session pattern", + key: "session:abc123", + expected: 11692, // Verified against Redis CLUSTER KEYSLOT + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cmdx.Slot(tt.key) + assert.Equalf(t, tt.expected, result, "Slot(%q) = %d, want %d", tt.key, result, tt.expected) + }) + } +} + +func TestSlot_Consistency(t *testing.T) { + // Test that the same key always produces the same slot + key := "test:key:123" + slot1 := cmdx.Slot(key) + slot2 := 
cmdx.Slot(key) + assert.Equal(t, slot1, slot2, "Slot function should be deterministic") +} + +func TestSlot_Distribution(t *testing.T) { + // Test that slots are distributed across the valid range + keys := []string{ + "key1", "key2", "key3", "key4", "key5", + "user:1", "user:2", "user:3", "user:4", "user:5", + "session:a", "session:b", "session:c", "session:d", "session:e", + } + + slots := make(map[uint16]bool) + for _, key := range keys { + slot := cmdx.Slot(key) + assert.LessOrEqualf(t, slot, uint16(16383), "Slot for key %q should be <= 16383", key) + slots[slot] = true + } + + // With 15 different keys, we should have some distribution (not all the same slot) + assert.Greater(t, len(slots), 1, "Keys should distribute across multiple slots") +} + +func TestSlot_HashTagCollision(t *testing.T) { + // Keys with the same hash tag should go to the same slot + keys := []string{ + "{user:1000}:profile", + "{user:1000}:settings", + "{user:1000}:preferences", + } + + expectedSlot := cmdx.Slot("user:1000") + for _, key := range keys { + slot := cmdx.Slot(key) + assert.Equalf(t, expectedSlot, slot, "Key %q with hash tag should map to slot %d", key, expectedSlot) + } +} + +func TestSlot_BoundaryValues(t *testing.T) { + tests := []struct { + name string + key string + }{ + {"single char", "a"}, + {"special chars", "!@#$%^&*()"}, + {"unicode", "你好世界"}, + {"long key", string(make([]byte, 1000))}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slot := cmdx.Slot(tt.key) + assert.LessOrEqualf(t, slot, uint16(16383), "Slot should be within valid range") + }) + } +} + +func BenchmarkSlot(b *testing.B) { + keys := []string{ + "simple", + "user:1000", + "{tag}key", + "prefix{tag}suffix", + } + + for _, key := range keys { + b.Run(key, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = cmdx.Slot(key) + } + }) + } +} diff --git a/internal/mapsx/maps_test.go b/internal/mapsx/maps_test.go index fe79bc0..ae40851 100644 --- a/internal/mapsx/maps_test.go +++ b/internal/mapsx/maps_test.go @@ -3,8 +3,9 @@ package mapsx_test import ( "testing" - "github.com/dcbickfo/redcache/internal/mapsx" "github.com/stretchr/testify/assert" + + "github.com/dcbickfo/redcache/internal/mapsx" ) func TestKeys(t *testing.T) { diff --git a/internal/syncx/map_test.go b/internal/syncx/map_test.go index dfa86d4..dc71d25 100644 --- a/internal/syncx/map_test.go +++ b/internal/syncx/map_test.go @@ -3,8 +3,9 @@ package syncx_test import ( "testing" - "github.com/dcbickfo/redcache/internal/syncx" "github.com/stretchr/testify/assert" + + "github.com/dcbickfo/redcache/internal/syncx" ) func TestMap_CompareAndDelete(t *testing.T) { diff --git a/internal/syncx/wait_test.go b/internal/syncx/wait_test.go index 56d8bc6..89da67a 100644 --- a/internal/syncx/wait_test.go +++ b/internal/syncx/wait_test.go @@ -6,8 +6,9 @@ import ( "testing" "time" - "github.com/dcbickfo/redcache/internal/syncx" "github.com/stretchr/testify/assert" + + "github.com/dcbickfo/redcache/internal/syncx" ) func delayedSend[T any](ch chan T, val T, delay time.Duration) {
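A minimal configuration sketch for the options this change introduces (LockPrefix and a Debug-capable Logger), assuming a local Redis at 127.0.0.1:6379; the __myapp:lock: prefix and the callback body are illustrative placeholders rather than part of the library.

package main

import (
	"context"
	"log/slog"
	"os"
	"time"

	"github.com/redis/rueidis"

	"github.com/dcbickfo/redcache"
)

func main() {
	// *slog.Logger satisfies the Logger interface: it provides Error and Debug.
	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))

	client, err := redcache.NewRedCacheAside(
		rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}},
		redcache.CacheAsideOption{
			LockTTL:    10 * time.Second, // zero uses the 10s default; non-zero values below 100ms are rejected
			LockPrefix: "__myapp:lock:",  // keep the prefix distinct from data keys
			Logger:     logger,           // Debug output reports cache hits/misses and lock activity
		},
	)
	if err != nil {
		logger.Error("failed to create cache-aside client", "error", err)
		os.Exit(1)
	}
	defer client.Client().Close()

	// Read-through get: the callback runs only when the key is absent and this
	// process holds the distributed lock; other callers wait for the cached value.
	val, err := client.Get(context.Background(), time.Minute, "user:123",
		func(ctx context.Context, key string) (string, error) {
			return "value-from-db", nil // placeholder for a real database lookup
		})
	if err != nil {
		logger.Error("get failed", "error", err)
		return
	}
	logger.Debug("got value", "key", "user:123", "value", val)
}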