diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f00c367..e939deb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -2,15 +2,8 @@ name: Go on: [push] jobs: - build: - services: - redis: - # Docker Hub image - image: redis - ports: - - 6379:6379 + unit-tests: runs-on: ubuntu-latest - steps: - uses: actions/checkout@v5 - name: Setup Go @@ -19,21 +12,86 @@ jobs: go-version-file: 'go.mod' - name: Install dependencies run: go mod vendor + - name: Run unit tests (no Redis required) + run: go test -short -v ./... + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version-file: 'go.mod' - name: Run golangci-lint uses: golangci/golangci-lint-action@v8 with: version: v2.6.1 args: --timeout=5m only-new-issues: false - - name: Test with Go + + integration-tests: + services: + redis: + image: redis:7 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version-file: 'go.mod' + - name: Install dependencies + run: go mod vendor + - name: Run integration tests + run: | + set +e + go test -tags=integration -json ./... > TestResults-integration.json + TEST_EXIT_CODE=$? + exit $TEST_EXIT_CODE + - name: Upload integration test results + uses: actions/upload-artifact@v4 + if: always() + with: + name: Integration-results + path: TestResults-integration.json + + distributed-tests: + services: + redis: + image: redis:7 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Setup Go + uses: actions/setup-go@v6 + with: + go-version-file: 'go.mod' + - name: Install dependencies + run: go mod vendor + - name: Run distributed tests run: | set +e - go test -tags=examples -json ./... > TestResults.json + go test -tags=distributed -json ./... > TestResults-distributed.json TEST_EXIT_CODE=$? 
exit $TEST_EXIT_CODE - - name: Upload Go test results + - name: Upload distributed test results uses: actions/upload-artifact@v4 if: always() with: - name: Go-results - path: TestResults.json \ No newline at end of file + name: Distributed-results + path: TestResults-distributed.json \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 1a7184f..9002042 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -6,6 +6,9 @@ run: modules-download-mode: readonly build-tags: - examples + - integration + - distributed + - cluster formatters: enable: diff --git a/.mockery.yaml b/.mockery.yaml new file mode 100644 index 0000000..386993b --- /dev/null +++ b/.mockery.yaml @@ -0,0 +1,17 @@ +# Mockery v3 Configuration +# Generate moq-style mocks for all internal interfaces + +# Use matryer/moq style templates +template: matryer + +# Global output configuration +dir: "mocks/{{.SrcPackageName}}" +pkgname: "mock{{.SrcPackageName}}" +filename: "mock_{{.InterfaceName}}.go" + +# Packages to scan +packages: + github.com/dcbickfo/redcache/internal: + config: + all: true + recursive: true diff --git a/CLAUDE.md b/CLAUDE.md index 29daf60..6ee29a1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -18,27 +18,37 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ### Testing -```bash -# Run all tests (requires Redis on localhost:6379) -make test +redcache uses Go build tags to organize tests by category: -# Run tests with single test -go test . -run TestName -v +```bash +# Unit tests only (no Redis required, <0.3s) +make test-unit +go test -short ./... -# Run tests with race detector -go test . -race -count=1 - -# Run specific package tests -go test ./internal/writelock -v +# Integration tests (requires Redis on localhost:6379) +make test-integration +go test -tags=integration ./... # Distributed tests (multi-client coordination) make test-distributed +go test -tags=distributed ./... -# Redis Cluster tests (requires cluster on ports 7000-7005) +# Redis Cluster tests (requires cluster on ports 17000-17005) make test-cluster +go test -tags=cluster ./... + +# Run multiple test suites +go test -tags="integration,distributed" ./... -# Complete test suite (unit + distributed + cluster + examples) -make test-complete +# All tests (comprehensive: unit + integration + distributed + cluster + examples) +make test +go test -tags="integration,distributed,cluster,examples" ./... + +# Run specific test with race detector +go test -tags=integration -run TestName -race -count=1 -v + +# Run specific package tests +go test -tags=integration ./internal/writelock -v ``` ### Docker/Redis @@ -210,21 +220,50 @@ The linter enforces cognitive complexity < 15 per function. When implementing co ### Testing Patterns -**Test Categories:** -1. **Unit tests**: Basic functionality, single client -2. **Distributed tests** (`*_distributed_test.go`): Multi-client coordination -3. **Cluster tests** (`*_cluster_test.go`): Redis Cluster specific (gracefully skip if cluster unavailable) -4. **Edge tests** (`*_edge_test.go`): Race conditions, context cancellation, error handling +**Test Categories with Build Tags:** +1. **Unit tests** (`*_unit_test.go`): No build tag, use `-short` flag. Fast isolated tests with mocks, no Redis required. +2. **Integration tests** (`*_test.go`): `//go:build integration` tag. Test with real Redis for end-to-end functionality. +3. **Distributed tests** (`*_distributed_test.go`): `//go:build distributed` tag. Multi-client coordination tests. +4. 
**Cluster tests** (`*_cluster_test.go`): `//go:build cluster` tag. Redis Cluster specific tests (ports 17000-17005). +5. **Examples** (`examples/*_test.go`): `//go:build examples` tag. Runnable documentation examples. + +**Test Organization - IMPORTANT:** +- **Always use subtests** with `t.Run()` for test organization +- **Always use `t.Context()`** instead of `context.Background()` in tests +- **Unit tests must check `testing.Short()`**: Skip when NOT in short mode (`if !testing.Short() { t.Skip(...) }`) **Test Naming Convention:** -- `Test__` -- Example: `TestPrimeableCacheAside_SetMulti_ConcurrentWrites` +- Parent test: `Test_` +- Subtests: Descriptive names like `"CacheHit"`, `"CacheMiss_LockAcquired"`, `"CallbackError"` +- Example: `TestCacheAside_Get` with subtests `"CacheHit"` and `"CacheMiss_LockAcquired"` + +**Test Structure Pattern:** +```go +func TestCacheAside_Get(t *testing.T) { + if !testing.Short() { + t.Skip("Skipping unit test in non-short mode") + } + + // Shared setup here if needed + + t.Run("CacheHit", func(t *testing.T) { + ctx := t.Context() // Use test context + // ... test logic + }) + + t.Run("CacheMiss_LockAcquired", func(t *testing.T) { + ctx := t.Context() + // ... test logic + }) +} +``` **Redis Setup:** - Single Redis: Tests assume `localhost:6379` -- Cluster: Tests check ports 7000-7005, skip if unavailable +- Cluster: Tests use ports 17000-17005 - Always use `makeClient(t)` helper for setup - Always defer `client.Close()` +- Always use `t.Context()` for automatic cleanup ### Benchmarking Best Practices @@ -263,18 +302,21 @@ go tool pprof -inuse_space mem.prof 3. **Test sequence**: ```bash - # Unit tests - go test . -run TestNewFeature -v + # Unit tests (with mocks, no Redis) + go test -short -run TestNewFeature -v + + # Integration tests (with real Redis) + go test -tags=integration -run TestNewFeature -v # Race detector - go test . -run TestNewFeature -race -count=1 + go test -tags=integration -run TestNewFeature -race -count=1 # Distributed coordination - go test . -run TestNewFeature_Distributed -v + go test -tags=distributed -run TestNewFeature -v # Cluster support make docker-cluster-up - go test . -run TestNewFeature_Cluster -v + go test -tags=cluster -run TestNewFeature -v ``` 4. **Linter verification**: @@ -353,9 +395,13 @@ redcache/ 5. **Don't create high cognitive complexity** - Extract helper functions to stay under 15 6. **Don't forget cleanup in tests** - Always `defer client.Close()` and `defer cleanup()` 7. **Don't assume keys are in same slot** - Even similar-looking keys may hash differently +8. **Don't use `context.Background()` in tests** - Always use `t.Context()` for automatic cleanup +9. **Don't skip organizing tests with subtests** - Always use `t.Run()` for better structure and isolation +10. 
**Don't forget build tags** - Unit tests need `testing.Short()` check, integration tests need `//go:build integration` tag ## Additional Resources +- See `TESTING.md` for comprehensive testing guide (build tags, mocks, best practices) - See `README.md` for user-facing documentation - See `DISTRIBUTED_LOCK_SAFETY.md` for lock safety analysis - See `REDIS_CLUSTER.md` for cluster deployment guide diff --git a/Makefile b/Makefile index b42b320..91bf4dd 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help test test-fast test-unit test-distributed test-cluster test-examples test-coverage lint lint-fix build clean vendor install-tools docker-up docker-down docker-cluster-up docker-cluster-down +.PHONY: help test test-fast test-unit test-unit-mocked test-integration test-distributed test-cluster test-examples test-coverage lint lint-fix build clean vendor install-tools mocks docker-up docker-down docker-cluster-up docker-cluster-down # Colors for output CYAN := \033[36m @@ -18,12 +18,16 @@ help: @echo "$(BOLD)Testing:$(RESET)" @echo " $(CYAN)make test$(RESET) - Run all tests including examples (default)" @echo " $(CYAN)make test-fast$(RESET) - Run all tests quickly (no examples)" - @echo " $(CYAN)make test-unit$(RESET) - Run unit tests only (no distributed/cluster)" + @echo " $(CYAN)make test-unit$(RESET) - Run unit tests with mocks (no Redis required)" + @echo " $(CYAN)make test-integration$(RESET) - Run integration tests (requires Redis)" @echo " $(CYAN)make test-distributed$(RESET) - Run distributed tests (multi-client coordination)" @echo " $(CYAN)make test-cluster$(RESET) - Run Redis cluster tests (requires cluster)" @echo " $(CYAN)make test-examples$(RESET) - Run example tests only" @echo " $(CYAN)make test-coverage$(RESET) - Run tests with coverage report" @echo "" + @echo "$(BOLD)Mocking:$(RESET)" + @echo " $(CYAN)make mocks$(RESET) - Generate all mocks using mockery v3" + @echo "" @echo "$(BOLD)Docker/Redis:$(RESET)" @echo " $(CYAN)make docker-up$(RESET) - Start single Redis instance" @echo " $(CYAN)make docker-down$(RESET) - Stop single Redis instance" @@ -53,37 +57,49 @@ install-tools: @echo "$(BOLD)Installed versions:$(RESET)" @asdf current -# Run all tests including examples (default, most comprehensive) +# Run all tests: unit + integration + distributed + cluster + examples (default, most comprehensive) test: - @echo "$(YELLOW)Running all tests including examples...$(RESET)" + @echo "$(YELLOW)Running all tests (unit + integration + distributed + cluster + examples)...$(RESET)" @echo "$(YELLOW)Note: Requires Redis on localhost:6379 AND Redis Cluster on localhost:17000-17005$(RESET)" @echo "$(YELLOW) Start with: make docker-up && make docker-cluster-up$(RESET)" - @go test -tags=examples -v ./... && echo "$(GREEN)✓ All tests passed!$(RESET)" || (echo "$(RED)✗ Tests failed!$(RESET)" && exit 1) + @echo "" + @echo "$(CYAN)Step 1/2: Running unit tests (no Redis)...$(RESET)" + @go test -v ./... && echo "$(GREEN)✓ Unit tests passed!$(RESET)" || (echo "$(RED)✗ Unit tests failed!$(RESET)" && exit 1) + @echo "" + @echo "$(CYAN)Step 2/2: Running integration/distributed/cluster/examples tests...$(RESET)" + @go test -tags="integration,distributed,cluster,examples" -v ./... 
&& echo "$(GREEN)✓ All integration tests passed!$(RESET)" || (echo "$(RED)✗ Integration tests failed!$(RESET)" && exit 1) + @echo "" + @echo "$(GREEN)$(BOLD)✓ All tests passed!$(RESET)" # Run tests quickly without examples test-fast: @echo "$(YELLOW)Running tests (no examples)...$(RESET)" @echo "$(YELLOW)Note: Requires Redis on localhost:6379 AND Redis Cluster on localhost:17000-17005$(RESET)" @echo "$(YELLOW) Start with: make docker-up && make docker-cluster-up$(RESET)" - @go test -v ./... && echo "$(GREEN)✓ Tests passed!$(RESET)" || (echo "$(RED)✗ Tests failed!$(RESET)" && exit 1) + @go test -tags="integration,distributed,cluster" -v ./... && echo "$(GREEN)✓ Tests passed!$(RESET)" || (echo "$(RED)✗ Tests failed!$(RESET)" && exit 1) -# Run only unit tests (no distributed or cluster tests) +# Run only unit tests (no Redis required, no build tags) test-unit: - @echo "$(YELLOW)Running unit tests only (excluding distributed and cluster tests)...$(RESET)" + @echo "$(YELLOW)Running unit tests (no Redis required)...$(RESET)" + @go test -v ./... && echo "$(GREEN)✓ Unit tests passed!$(RESET)" || (echo "$(RED)✗ Unit tests failed!$(RESET)" && exit 1) + +# Run integration tests (requires Redis) +test-integration: + @echo "$(YELLOW)Running integration tests (requires Redis)...$(RESET)" @echo "$(YELLOW)Note: Requires Redis on localhost:6379$(RESET)" - @go test -v -run '^Test[^_]*$$|TestCacheAside_Get$$|TestCacheAside_GetMulti$$|TestPrimeableCacheAside_Set$$' ./... && echo "$(GREEN)✓ Unit tests passed!$(RESET)" || (echo "$(RED)✗ Unit tests failed!$(RESET)" && exit 1) + @go test -tags=integration -v ./... && echo "$(GREEN)✓ Integration tests passed!$(RESET)" || (echo "$(RED)✗ Integration tests failed!$(RESET)" && exit 1) # Run distributed tests (multi-client, single Redis instance) test-distributed: @echo "$(YELLOW)Running distributed tests (multi-client coordination)...$(RESET)" @echo "$(YELLOW)Note: Requires Redis on localhost:6379 (start with: make docker-up)$(RESET)" - @go test -v -run 'Distributed' ./... && echo "$(GREEN)✓ Distributed tests passed!$(RESET)" || (echo "$(RED)✗ Distributed tests failed!$(RESET)" && exit 1) + @go test -tags=distributed -v ./... && echo "$(GREEN)✓ Distributed tests passed!$(RESET)" || (echo "$(RED)✗ Distributed tests failed!$(RESET)" && exit 1) # Run Redis cluster tests (requires cluster setup) test-cluster: @echo "$(YELLOW)Running Redis cluster tests...$(RESET)" @echo "$(YELLOW)Note: Requires Redis Cluster on localhost:17000-17005 (start with: make docker-cluster-up)$(RESET)" - @go test -v -run 'Cluster' ./... && echo "$(GREEN)✓ Cluster tests passed!$(RESET)" || (echo "$(RED)✗ Cluster tests failed!$(RESET)" && exit 1) + @go test -tags=cluster -v ./... && echo "$(GREEN)✓ Cluster tests passed!$(RESET)" || (echo "$(RED)✗ Cluster tests failed!$(RESET)" && exit 1) # Run example tests with build tag test-examples: @@ -95,7 +111,7 @@ test-coverage: @echo "$(YELLOW)Running tests with coverage...$(RESET)" @echo "$(YELLOW)Note: Requires Redis on localhost:6379 AND Redis Cluster on localhost:17000-17005$(RESET)" @echo "$(YELLOW) Start with: make docker-up && make docker-cluster-up$(RESET)" - @go test -v -coverprofile=coverage.out -covermode=atomic ./... && echo "$(GREEN)✓ Tests passed!$(RESET)" || (echo "$(RED)✗ Tests failed!$(RESET)" && exit 1) + @go test -tags="integration,distributed,cluster" -v -coverprofile=coverage.out -covermode=atomic ./... 
&& echo "$(GREEN)✓ Tests passed!$(RESET)" || (echo "$(RED)✗ Tests failed!$(RESET)" && exit 1) @echo "" @echo "$(CYAN)Coverage report:$(RESET)" @go tool cover -func=coverage.out | tail -1 @@ -122,6 +138,12 @@ build-examples: @echo "$(YELLOW)Building all packages (including examples)...$(RESET)" @go build -tags=examples ./... && echo "$(GREEN)✓ Build successful!$(RESET)" || (echo "$(RED)✗ Build failed!$(RESET)" && exit 1) +# Generate mocks using mockery v3 +mocks: + @echo "$(YELLOW)Generating mocks with mockery v3...$(RESET)" + @command -v mockery >/dev/null 2>&1 || { echo "$(RED)Error: mockery is not installed. Install with: go install github.com/vektra/mockery/v3@latest$(RESET)"; exit 1; } + @mockery --config .mockery.yaml && echo "$(GREEN)✓ Mocks generated successfully!$(RESET)" || (echo "$(RED)✗ Mock generation failed!$(RESET)" && exit 1) + # Clean build artifacts clean: @echo "$(YELLOW)Cleaning build artifacts...$(RESET)" diff --git a/README.md b/README.md index 040eaca..c1a3ea7 100644 --- a/README.md +++ b/README.md @@ -212,28 +212,75 @@ tx.Exec("UPDATE accounts SET balance = balance - $1 WHERE id = $2 AND balance >= ## Testing -Run tests with different configurations: +redcache uses a layered testing approach with build tags for fast feedback and comprehensive coverage: + +### Quick Start ```bash -# View all available test targets -make help +# Unit tests with mocks (no Redis required, <0.3s) +make test-unit +# or +go test -short ./... -# Run unit tests (requires single Redis instance) +# Integration tests (requires Redis on localhost:6379) make docker-up -make test +make test-integration +# or +go test -tags=integration ./... -# Run distributed tests (multi-client coordination) +# Distributed coordination tests (requires Redis) make test-distributed +# or +go test -tags=distributed ./... -# Run Redis cluster tests +# Redis cluster tests (requires cluster on ports 17000-17005) make docker-cluster-up make test-cluster +# or +go test -tags=cluster ./... + +# Run multiple test suites together +go test -tags="integration,distributed" ./... -# Run everything (unit + distributed + cluster + examples) +# All tests (comprehensive, includes all categories) make docker-up && make docker-cluster-up -make test-complete +make test +# or +go test -tags="integration,distributed,cluster,examples" ./... +``` + +### Test Categories + +redcache uses Go build tags to organize tests: + +- **Unit Tests**: No build tag, use `-short` flag. Fast isolated tests using mocks, no external dependencies. +- **Integration Tests**: `//go:build integration` tag. Test with real Redis to verify end-to-end functionality. +- **Distributed Tests**: `//go:build distributed` tag. Multi-client coordination tests to verify locking behavior. +- **Cluster Tests**: `//go:build cluster` tag. Redis Cluster specific tests for slot routing and cluster operations. 
+ +### Mock Generation -# Cleanup +redcache uses mockery v3 to generate moq-style mocks for all internal interfaces: + +```bash +# Generate mocks (auto-scans internal/ packages) +make mocks + +# Mocks are generated in mocks/{package}/ directories +# Example: mocks/lockmanager/mock_LockManager.go +``` + +### Writing Tests + +See [TESTING.md](./TESTING.md) for comprehensive testing guide including: +- How to write unit tests with mocks +- Integration test patterns +- Best practices and troubleshooting +- CI/CD integration examples + +### Cleanup + +```bash make docker-down && make docker-cluster-down ``` diff --git a/bench_test.go b/bench_test.go index 9987dcf..b0c191a 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,3 +1,5 @@ +//go:build integration + package redcache_test import ( diff --git a/cacheaside.go b/cacheaside.go index ba5ecd4..a84a2d6 100644 --- a/cacheaside.go +++ b/cacheaside.go @@ -62,50 +62,21 @@ import ( "context" "errors" "fmt" - "iter" "log/slog" - "maps" - "strconv" - "strings" - "sync" "time" - "github.com/google/uuid" "github.com/redis/rueidis" - "golang.org/x/sync/errgroup" - "github.com/dcbickfo/redcache/internal/cmdx" - "github.com/dcbickfo/redcache/internal/lockpool" + "github.com/dcbickfo/redcache/internal/contextx" + "github.com/dcbickfo/redcache/internal/invalidation" + "github.com/dcbickfo/redcache/internal/lockmanager" + "github.com/dcbickfo/redcache/internal/logger" "github.com/dcbickfo/redcache/internal/mapsx" - "github.com/dcbickfo/redcache/internal/syncx" ) -// Pools for map reuse in hot paths to reduce allocations. -var ( - // stringStringMapPool is used for temporary maps in cleanup operations. - stringStringMapPool = sync.Pool{ - New: func() any { - m := make(map[string]string, 100) - return &m - }, - } -) - -type lockEntry struct { - ctx context.Context - cancel context.CancelFunc -} - -// Logger defines the logging interface used by CacheAside. -// Implementations must be safe for concurrent use and should handle log levels internally. -type Logger interface { - // Error logs error messages. Should be used for unexpected failures or critical issues. - Error(msg string, args ...any) - // Debug logs detailed diagnostic information useful for development and troubleshooting. - // Call Debug to record verbose output about internal state, cache operations, or lock handling. - // Debug messages should not include sensitive information and may be omitted in production. - Debug(msg string, args ...any) -} +// Logger is the logging interface used by CacheAside. +// This is a type alias for the shared logger interface. +type Logger = logger.Logger // CacheAside implements the cache-aside pattern with distributed locking for Redis. // It coordinates concurrent access to prevent cache stampedes and ensures only one @@ -119,12 +90,10 @@ type Logger interface { // - Context-aware cleanup ensures locks are released even on errors type CacheAside struct { client rueidis.Client - locks syncx.Map[string, *lockEntry] + lockManager lockmanager.LockManager lockTTL time.Duration logger Logger - lockPrefix string maxRetries int - lockValPool *lockpool.Pool } // CacheAsideOption configures the behavior of the CacheAside instance. @@ -136,11 +105,25 @@ type CacheAsideOption struct { // enough to recover quickly from process failures. LockTTL time.Duration - // ClientBuilder allows customizing the Redis client creation. - // If nil, rueidis.NewClient is used with the provided options. - // This is useful for injecting mock clients in tests or adding middleware. 
+ // Client allows injecting a pre-configured rueidis.Client for testing or custom implementations. + // If provided, this takes precedence and ClientBuilder is ignored. + // This is the primary mechanism for dependency injection in tests. + // Since rueidis.Client is an interface, you can easily mock it for testing. + // Example: + // caOption := CacheAsideOption{Client: mockRedisClient} + Client rueidis.Client + + // ClientBuilder allows customizing the Redis client creation for advanced use cases. + // If nil and Client is nil, rueidis.NewClient is used with the provided options. ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) + // LockManager allows injecting a custom lock manager for testing or alternative implementations. + // If provided, this is used instead of creating a DistributedLockManager. + // This enables mocking lock behavior in unit tests without Redis. + // Example: + // caOption := CacheAsideOption{LockManager: mockLockManager} + LockManager lockmanager.LockManager + // Logger for logging errors and debug information. Defaults to slog.Default(). // The logger should handle log levels internally (e.g., only log Debug if level is enabled). Logger Logger @@ -207,144 +190,79 @@ func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOpti caOption.MaxRetries = 100 } + // Create temporary struct for initialization (needed for callbacks) rca := &CacheAside{ - lockTTL: caOption.LockTTL, - logger: caOption.Logger, - lockPrefix: caOption.LockPrefix, - maxRetries: caOption.MaxRetries, - lockValPool: lockpool.New(caOption.LockPrefix, 10000), + lockTTL: caOption.LockTTL, + logger: caOption.Logger, + maxRetries: caOption.MaxRetries, } - clientOption.OnInvalidations = rca.onInvalidate - var err error - if caOption.ClientBuilder != nil { - rca.client, err = caOption.ClientBuilder(clientOption) - } else { - rca.client, err = rueidis.NewClient(clientOption) - } + // Create invalidation handler (will be owned by lock manager) + invalidHandler := invalidation.NewRedisInvalidationHandler(invalidation.Config{ + LockTTL: caOption.LockTTL, + Logger: caOption.Logger, + }) + + // Initialize Redis client with invalidation support + // Pass invalidHandler directly to avoid race condition during initialization + client, err := initializeClient(clientOption, caOption, rca, invalidHandler) if err != nil { return nil, err } - return rca, nil -} - -// Client returns the underlying rueidis.Client for advanced operations. -// Most users should not need direct client access. Use with caution as -// direct operations bypass the cache-aside pattern and distributed locking. -func (rca *CacheAside) Client() rueidis.Client { - return rca.client -} + rca.client = client -// Close cleans up resources used by the CacheAside instance. -// It cancels all pending lock wait operations and cleans up internal state. -// Note: This does NOT close the underlying Redis client, as that's owned by the caller. 
-func (rca *CacheAside) Close() { - // Cancel all pending lock wait operations - rca.locks.Range(func(_ string, entry *lockEntry) bool { - if entry != nil { - entry.cancel() // Cancel context, which closes the channel - } - return true - }) -} + // Initialize lock manager (injected or default) - pass invalidation handler + rca.lockManager = initializeLockManager(caOption, rca, invalidHandler) -func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { - for _, m := range messages { - key, err := m.ToString() - if err != nil { - rca.logger.Error("failed to parse invalidation message", "error", err) - continue - } - entry, loaded := rca.locks.LoadAndDelete(key) - if loaded { - entry.cancel() // Cancel context, which closes the channel - } - } + return rca, nil } -var ( - delKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) - setKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) -) - -//nolint:gocognit // Complex due to atomic operations and retry logic -func (rca *CacheAside) register(key string) <-chan struct{} { -retry: - // First check if an entry already exists (common case for concurrent requests) - // This avoids creating a context unnecessarily - if existing, ok := rca.locks.Load(key); ok { - // Check if the existing context is still active - select { - case <-existing.ctx.Done(): - // Context is done - try to atomically delete it and retry - if rca.locks.CompareAndDelete(key, existing) { - goto retry - } - // Another goroutine modified it, try loading again - if newEntry, found := rca.locks.Load(key); found { - return newEntry.ctx.Done() - } - // Entry was deleted, retry - goto retry - default: - // Context is still active, use it - return existing.ctx.Done() - } +// initializeClient sets up the Redis client with invalidation callbacks. +func initializeClient(clientOption rueidis.ClientOption, caOption CacheAsideOption, _ *CacheAside, invalidHandler invalidation.Handler) (rueidis.Client, error) { + // Priority: 1. Injected Client, 2. ClientBuilder, 3. Default rueidis.NewClient + if caOption.Client != nil { + return caOption.Client, nil } - // No existing entry or it was expired, create new one - // The extra time allows the invalidation message to arrive (primary flow) - // while still providing a fallback timeout for missed messages. - // We use a proportional buffer (20% of lockTTL) with a minimum of 200ms - // to account for network delays, ensuring the timeout scales appropriately - // with different lock durations. 
- buffer := rca.lockTTL / 5 // 20% - if buffer < 200*time.Millisecond { - buffer = 200 * time.Millisecond + // Setup invalidation callback - delegates directly to handler to avoid race condition + // (lockManager hasn't been assigned yet during initialization) + clientOption.OnInvalidations = func(messages []rueidis.RedisMessage) { + invalidHandler.OnInvalidate(messages) } - ctx, cancel := context.WithTimeout(context.Background(), rca.lockTTL+buffer) - newEntry := &lockEntry{ - ctx: ctx, - cancel: cancel, + if caOption.ClientBuilder != nil { + return caOption.ClientBuilder(clientOption) } + return rueidis.NewClient(clientOption) +} - // Store or get existing entry atomically - actual, loaded := rca.locks.LoadOrStore(key, newEntry) - - // If we successfully stored, schedule automatic cleanup on expiration - if !loaded { - // Use context.AfterFunc to clean up expired entry without blocking goroutine - context.AfterFunc(ctx, func() { - rca.locks.CompareAndDelete(key, newEntry) - }) - return ctx.Done() +// initializeLockManager creates or uses injected lock manager. +func initializeLockManager(caOption CacheAsideOption, rca *CacheAside, invalidHandler invalidation.Handler) lockmanager.LockManager { + if caOption.LockManager != nil { + return caOption.LockManager } + return lockmanager.NewDistributedLockManager(lockmanager.Config{ + Client: rca.client, + LockTTL: caOption.LockTTL, + LockPrefix: caOption.LockPrefix, + Logger: caOption.Logger, + InvalidationHandler: invalidHandler, + }) +} - // Another goroutine stored first, cancel our context to prevent leak - cancel() +// Client returns the underlying rueidis.Client for advanced operations. +// Most users should not need direct client access. Use with caution as +// direct operations bypass the cache-aside pattern and distributed locking. +func (rca *CacheAside) Client() rueidis.Client { + return rca.client +} - // Check if their context is still active (not cancelled/timed out) - select { - case <-actual.ctx.Done(): - // Context is done - try to atomically delete it and retry - if rca.locks.CompareAndDelete(key, actual) { - // We successfully deleted the expired entry, retry - goto retry - } - // CompareAndDelete failed - another goroutine modified it - // Load the new entry and use it - waitEntry, ok := rca.locks.Load(key) - if !ok { - // Entry was deleted by another goroutine, retry registration - goto retry - } - // Use the new entry's context - return waitEntry.ctx.Done() - default: - // Context is still active - use it - return actual.ctx.Done() - } +// Close cleans up resources used by the CacheAside instance. +// Note: Invalidation handler cleanup is automatic via context.AfterFunc. +// Note: This does NOT close the underlying Redis client, as that's owned by the caller. +func (rca *CacheAside) Close() { + // Invalidation handler uses context.AfterFunc for automatic cleanup + // No explicit cleanup needed } // Get retrieves a value from cache or computes it using the provided callback function. @@ -402,8 +320,9 @@ func (rca *CacheAside) Get( return rca.getWithRegistration(ctx, ttl, key, fn) } -// getWithRegistration handles the full cache-aside flow with registration +// getWithRegistration handles the full cache-aside flow with lock acquisition. // This is used when we have a cache miss or need to wait for locks. +// Uses the new TryAcquire API which handles pre-registration internally. 
func (rca *CacheAside) getWithRegistration( ctx context.Context, ttl time.Duration, @@ -411,36 +330,78 @@ func (rca *CacheAside) getWithRegistration( fn func(ctx context.Context, key string) (val string, err error), ) (string, error) { retry: - wait := rca.register(key) + // Check if value was populated while we were waiting val, err := rca.tryGet(ctx, ttl, key) - + if err == nil && val != "" { + return val, nil + } if err != nil && !errors.Is(err, errNotFound) { return "", err } - if err == nil && val != "" { - return val, nil + // Try to acquire lock (pre-registers for invalidations internally) + lockVal, waitHandle, err := rca.lockManager.TryAcquire(ctx, key) + if err != nil { + return "", err } - if val == "" { - val, err = rca.trySetKeyFunc(ctx, ttl, key, fn) + // If we got a wait handle, lock contention occurred - wait and retry + if waitHandle != nil { + if waitErr := waitHandle.Wait(ctx); waitErr != nil { + return "", waitErr + } + goto retry } - if err != nil && !errors.Is(err, errLockFailed) && !errors.Is(err, ErrLockLost) { + // We acquired the lock - populate the cache + val, err = rca.populateAndCache(ctx, ttl, key, lockVal, fn) + if err != nil { + // Check if lock was lost (overridden by ForceSet) + if errors.Is(err, ErrLockLost) { + // Lock was lost, value may have been updated by another process + // Retry immediately - the retry will check cache first + goto retry + } return "", err } + return val, nil +} - if val == "" || errors.Is(err, ErrLockLost) { - // Wait for lock release or invalidation - select { - case <-wait: - goto retry - case <-ctx.Done(): - return "", ctx.Err() +// populateAndCache executes the callback and caches the result, handling lock cleanup. +// Returns ErrLockLost if the lock was overridden during the operation. +func (rca *CacheAside) populateAndCache( + ctx context.Context, + ttl time.Duration, + key string, + lockVal string, + fn func(ctx context.Context, key string) (string, error), +) (string, error) { + setVal := false + defer func() { + if !setVal { + cleanupCtx, cancel := contextx.WithCleanupTimeout(ctx, rca.lockTTL) + defer cancel() + if unlockErr := rca.lockManager.ReleaseLock(cleanupCtx, key, lockVal); unlockErr != nil { + rca.logger.Error("failed to unlock key", "key", key, "error", unlockErr) + } } + }() + + // Execute callback to get the value + val, err := fn(ctx, key) + if err != nil { + return "", err + } + + // Store value in cache + val, err = rca.setWithLock(ctx, ttl, key, lockVal, val) + if err == nil { + setVal = true + return val, nil } - return val, err + // Return error (caller will check for ErrLockLost and retry if needed) + return "", err } // Del removes a key from Redis cache. @@ -483,8 +444,7 @@ func (rca *CacheAside) DelMulti(ctx context.Context, keys ...string) error { } var ( - errNotFound = errors.New("not found") - errLockFailed = errors.New("lock failed") + errNotFound = errors.New("not found") ) // ErrLockLost is now defined in errors.go for consistency across the package. 
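For orientation while reading the refactored internals, here is a minimal caller-side sketch of the public `Get` API, which this change leaves intact (the address, key, TTLs, and callback body are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/redis/rueidis"

	"github.com/dcbickfo/redcache"
)

func main() {
	cache, err := redcache.NewRedCacheAside(
		rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}},
		redcache.CacheAsideOption{LockTTL: 10 * time.Second},
	)
	if err != nil {
		panic(err)
	}
	defer cache.Client().Close()

	// On a cache miss the callback runs under the distributed lock; concurrent
	// callers receive a WaitHandle from the lock manager, wait on it, and then
	// re-read the freshly cached value instead of recomputing it.
	val, err := cache.Get(context.Background(), time.Minute, "user:42",
		func(ctx context.Context, key string) (string, error) {
			return "loaded-from-db", nil // e.g. a database or upstream lookup
		})
	fmt.Println(val, err)
}
```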
@@ -492,7 +452,7 @@ var ( func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string) (string, error) { resp := rca.client.DoCache(ctx, rca.client.B().Get().Key(key).Cache(), ttl) val, err := resp.ToString() - if rueidis.IsRedisNil(err) || strings.HasPrefix(val, rca.lockPrefix) { // no response or is a lock value + if rueidis.IsRedisNil(err) || rca.lockManager.IsLockValue(val) { // no response or is a lock value if rueidis.IsRedisNil(err) { rca.logger.Debug("cache miss - key not found", "key", key) } else { @@ -507,86 +467,29 @@ func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string return val, nil } -// trySetKeyFunc is used internally by Get operations to populate cache on miss. -// This is cache-aside behavior - it only acquires locks when the key doesn't exist. -func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key string, fn func(ctx context.Context, key string) (string, error)) (val string, err error) { - setVal := false - lockVal, err := rca.tryLock(ctx, key) - if err != nil { - return "", err - } - defer func() { - if !setVal { - rca.unlockWithCleanup(ctx, key, lockVal) - } - }() - if val, err = fn(ctx, key); err == nil { - val, err = rca.setWithLock(ctx, ttl, key, valAndLock{val, lockVal}) - if err == nil { - setVal = true - } - return val, err - } - return "", err -} +func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key string, lockVal string, val string) (string, error) { + // Commit read lock: atomically replace lock value with actual value + succeeded, needsRetry, err := rca.lockManager.CommitReadLocks(ctx, ttl, + map[string]string{key: lockVal}, + map[string]string{key: val}) -// tryLock attempts to acquire a distributed lock for cache-aside operations. -// For Get operations, if a real value already exists, we fail to acquire the lock -// (another process has already populated the cache). -func (rca *CacheAside) tryLock(ctx context.Context, key string) (string, error) { - uuidv7, err := uuid.NewV7() if err != nil { - return "", fmt.Errorf("failed to generate lock UUID for key %q: %w", key, err) - } - lockVal := rca.lockPrefix + uuidv7.String() - err = rca.client.Do(ctx, rca.client.B().Set().Key(key).Value(lockVal).Nx().Get().Px(rca.lockTTL).Build()).Error() - if !rueidis.IsRedisNil(err) { - rca.logger.Debug("lock contention - failed to acquire lock", "key", key) - return "", fmt.Errorf("failed to acquire lock for key %q: %w", key, errLockFailed) + return "", fmt.Errorf("failed to commit lock for key %q: %w", key, err) } - rca.logger.Debug("lock acquired", "key", key, "lockVal", lockVal) - - // Note: CSC subscription already established by tryGet's DoCache call (line 493). - // No need for additional DoCache here - invalidation notifications are already active. - - return lockVal, nil -} - -// generateLockValue creates a unique lock identifier using UUID v7. -// UUID v7 provides time-ordered uniqueness which helps with debugging and monitoring. -// Values are pooled to reduce allocation overhead during high-throughput operations. 
-func (rca *CacheAside) generateLockValue() string { - // Use pool for better performance (~15% improvement in lock acquisition) - return rca.lockValPool.Get() -} - -func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key string, valLock valAndLock) (string, error) { - result := setKeyLua.Exec(ctx, rca.client, []string{key}, []string{valLock.lockVal, valLock.val, strconv.FormatInt(ttl.Milliseconds(), 10)}) - // Check for Redis errors first - if err := result.Error(); err != nil { - if !rueidis.IsRedisNil(err) { - return "", fmt.Errorf("failed to set value for key %q: %w", key, err) - } + // Check if lock was lost + if len(needsRetry) > 0 { rca.logger.Debug("lock lost during set operation", "key", key) return "", fmt.Errorf("lock lost for key %q: %w", key, ErrLockLost) } - // Check the Lua script return value - // The script returns 0 if the lock doesn't match, or the SET result if successful - returnValue, err := result.AsInt64() - if err == nil && returnValue == 0 { - // Lock was lost - the current value doesn't match our lock - rca.logger.Debug("lock lost during set operation - lock value mismatch", "key", key) - return "", fmt.Errorf("lock lost for key %q: %w", key, ErrLockLost) + // Verify success + if len(succeeded) == 0 { + return "", fmt.Errorf("failed to set value for key %q", key) } rca.logger.Debug("value set successfully", "key", key) - return valLock.val, nil -} - -func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) error { - return delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock}).Error() + return val, nil } // GetMulti retrieves multiple values from cache or computes them using the provided callback. @@ -638,27 +541,27 @@ func (rca *CacheAside) GetMulti( ) (map[string]string, error) { res := make(map[string]string, len(keys)) - waitLock := make(map[string]<-chan struct{}, len(keys)) + // Track which keys still need processing (use empty struct as marker) + waitKeys := make(map[string]struct{}, len(keys)) for _, key := range keys { - waitLock[key] = nil + waitKeys[key] = struct{}{} } retry: - waitLock = rca.registerAll(maps.Keys(waitLock), len(waitLock)) - - vals, err := rca.tryGetMulti(ctx, ttl, mapsx.Keys(waitLock)) + // No upfront registration needed - TryAcquireMulti handles it internally + vals, err := rca.tryGetMulti(ctx, ttl, mapsx.Keys(waitKeys)) if err != nil && !rueidis.IsRedisNil(err) { return nil, err } for k, v := range vals { res[k] = v - delete(waitLock, k) + delete(waitKeys, k) } - if len(waitLock) > 0 { + if len(waitKeys) > 0 { var shouldRetry bool - shouldRetry, err = rca.processRemainingKeys(ctx, ttl, waitLock, res, fn) + shouldRetry, err = rca.processRemainingKeys(ctx, ttl, waitKeys, res, fn) if err != nil { return nil, err } @@ -675,11 +578,11 @@ retry: func (rca *CacheAside) processRemainingKeys( ctx context.Context, ttl time.Duration, - waitLock map[string]<-chan struct{}, + waitKeys map[string]struct{}, res map[string]string, fn func(ctx context.Context, key []string) (val map[string]string, err error), ) (bool, error) { - shouldWait, handleErr := rca.handleMissingKeys(ctx, ttl, waitLock, res, fn) + waitHandles, handleErr := rca.handleMissingKeys(ctx, ttl, waitKeys, res, fn) if handleErr != nil { // Check if locks expired (don't retry in this case) if errors.Is(handleErr, ErrLockLost) { @@ -688,10 +591,10 @@ func (rca *CacheAside) processRemainingKeys( return false, handleErr } - if shouldWait { - // Convert map values to slice for WaitForAll - channels := mapsx.Values(waitLock) - err := 
syncx.WaitForAll(ctx, channels) + if len(waitHandles) > 0 { + // Wait for all handles to complete + handles := mapsx.Values(waitHandles) + err := lockmanager.WaitForAll(ctx, handles) if err != nil { return false, err } @@ -702,63 +605,117 @@ func (rca *CacheAside) processRemainingKeys( } // handleMissingKeys attempts to acquire locks and populate missing keys. -// Returns true if we should wait for other processes to populate remaining keys. +// Returns WaitHandles for keys that need retry. func (rca *CacheAside) handleMissingKeys( ctx context.Context, ttl time.Duration, - waitLock map[string]<-chan struct{}, + waitKeys map[string]struct{}, res map[string]string, fn func(ctx context.Context, key []string) (val map[string]string, err error), -) (bool, error) { - acquiredVals, err := rca.tryAcquireAndExecute(ctx, ttl, waitLock, fn) +) (map[string]lockmanager.WaitHandle, error) { + acquiredVals, waitHandles, err := rca.tryAcquireAndExecute(ctx, ttl, mapsx.Keys(waitKeys), fn) if err != nil { - return false, err + return nil, err } - // Merge acquired values into result and remove from waitLock + // Merge acquired values into result and remove from waitKeys for k, v := range acquiredVals { res[k] = v - delete(waitLock, k) + delete(waitKeys, k) } - // Return whether there are still keys we need to wait for - // This handles the case where we acquired all locks but some SET operations failed - return len(waitLock) > 0, nil + // Remove successfully acquired keys from wait handles as well + for k := range acquiredVals { + delete(waitHandles, k) + } + + return waitHandles, nil } // tryAcquireAndExecute attempts to acquire locks and execute the callback for missing keys. -// Returns the values retrieved and any error. +// Returns the values retrieved and WaitHandles for keys that need retry. // It uses an optimistic approach: if not all locks can be acquired, it releases them -// and assumes other processes will populate the keys. +// and returns wait handles for the caller to wait on. 
func (rca *CacheAside) tryAcquireAndExecute( ctx context.Context, ttl time.Duration, - waitLock map[string]<-chan struct{}, + keysNeeded []string, fn func(ctx context.Context, key []string) (val map[string]string, err error), -) (map[string]string, error) { - keysNeeded := mapsx.Keys(waitLock) - lockVals := rca.tryLockMulti(ctx, keysNeeded) +) (map[string]string, map[string]lockmanager.WaitHandle, error) { + // Try to acquire locks (pre-registers internally) + acquired, retryHandles, err := rca.lockManager.TryAcquireMulti(ctx, keysNeeded) + if err != nil { + return nil, nil, err + } // If we got all locks, execute callback - if len(lockVals) == len(keysNeeded) { - vals, err := rca.executeAndCacheMulti(ctx, ttl, keysNeeded, lockVals, fn) - if err != nil { - return nil, err + if len(acquired) == len(keysNeeded) && len(retryHandles) == 0 { + vals, execErr := rca.executeAndCacheMulti(ctx, ttl, keysNeeded, acquired, fn) + if execErr != nil { + return nil, nil, execErr } - return vals, nil + + // Check if some keys lost their locks during execution + if waitHandles := rca.buildWaitHandlesForLostLocks(keysNeeded, vals); waitHandles != nil { + return vals, waitHandles, nil + } + + return vals, nil, nil + } + + // Didn't get all locks - release what we got and return retry handles + if len(acquired) > 0 { + cleanupCtx, cancel := contextx.WithCleanupTimeout(ctx, rca.lockTTL) + defer cancel() + rca.lockManager.ReleaseMultiLocks(cleanupCtx, acquired) } - // Didn't get all locks - release what we got and wait optimistically - rca.unlockMultiWithCleanup(ctx, lockVals) - return nil, nil + // For keys that had errors, add immediate-return handles + rca.addImmediateHandlesForErrors(keysNeeded, retryHandles) + + return nil, retryHandles, nil } -func (rca *CacheAside) registerAll(keys iter.Seq[string], length int) map[string]<-chan struct{} { - res := make(map[string]<-chan struct{}, length) - for key := range keys { - res[key] = rca.register(key) +// buildWaitHandlesForLostLocks creates wait handles for keys that lost their locks during execution. +// Returns nil if all keys were successful. +func (rca *CacheAside) buildWaitHandlesForLostLocks( + keysNeeded []string, + vals map[string]string, +) map[string]lockmanager.WaitHandle { + // All keys succeeded - no wait handles needed + if len(vals) >= len(keysNeeded) { + return nil + } + + // Some keys lost their locks (e.g., due to invalidation/deletion) + // These keys need immediate retry since the invalidation already happened + waitHandles := make(map[string]lockmanager.WaitHandle, len(keysNeeded)) + immediateHandle := rca.lockManager.CreateImmediateWaitHandle() + + for _, key := range keysNeeded { + if _, ok := vals[key]; !ok { + // This key lost its lock - retry immediately (invalidation already happened) + waitHandles[key] = immediateHandle + } + } + + return waitHandles +} + +// addImmediateHandlesForErrors adds immediate-return wait handles for keys that had errors. +// Modifies retryHandles in-place. 
+func (rca *CacheAside) addImmediateHandlesForErrors( + keysNeeded []string, + retryHandles map[string]lockmanager.WaitHandle, +) { + immediateHandle := rca.lockManager.CreateImmediateWaitHandle() + + for _, key := range keysNeeded { + if _, hasHandle := retryHandles[key]; !hasHandle { + // Key had error - add immediate handle so we retry right away + retryHandles[key] = immediateHandle + } } - return res } func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys []string) (map[string]string, error) { @@ -780,7 +737,7 @@ func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys } else if err != nil { return nil, fmt.Errorf("failed to get key %q: %w", keys[i], err) } - if !strings.HasPrefix(val, rca.lockPrefix) { + if !rca.lockManager.IsLockValue(val) { res[keys[i]] = val continue } @@ -800,7 +757,7 @@ func (rca *CacheAside) executeAndCacheMulti( // Defer cleanup of locks that weren't successfully set defer func() { - rca.cleanupUnusedLocks(ctx, lockVals, res) + rca.lockManager.CleanupUnusedLocks(ctx, lockVals, res) }() // Execute callback @@ -809,262 +766,16 @@ func (rca *CacheAside) executeAndCacheMulti( return nil, err } - // Build value-lock pairs - vL := make(map[string]valAndLock, len(vals)) - for k, v := range vals { - vL[k] = valAndLock{v, lockVals[k]} - } - - // Cache values with locks - keysSet, err := rca.setMultiWithLock(ctx, ttl, vL) + // Commit read locks: atomically replace lock values with actual values + succeeded, _, err := rca.lockManager.CommitReadLocks(ctx, ttl, lockVals, vals) if err != nil { return nil, err } - // Build result map - for _, keySet := range keysSet { - res[keySet] = vals[keySet] + // Build result map from succeeded keys + for _, key := range succeeded { + res[key] = vals[key] } return res, nil } - -// cleanupUnusedLocks releases locks that were acquired but not successfully cached. -func (rca *CacheAside) cleanupUnusedLocks(ctx context.Context, lockVals map[string]string, successfulKeys map[string]string) { - // Use pooled map for temp toUnlock map - toUnlockPtr := stringStringMapPool.Get().(*map[string]string) - toUnlock := *toUnlockPtr - defer func() { - clear(toUnlock) - stringStringMapPool.Put(toUnlockPtr) - }() - - for key, lockVal := range lockVals { - if _, ok := successfulKeys[key]; !ok { - toUnlock[key] = lockVal - } - } - - if len(toUnlock) == 0 { - return - } - - rca.unlockMultiWithCleanup(ctx, toUnlock) -} - -// buildLockCommands generates lock values and builds SET NX GET commands for the given keys. -// Returns a map of key->lockValue and the commands to execute. -func (rca *CacheAside) buildLockCommands(keys []string) (map[string]string, rueidis.Commands) { - lockVals := make(map[string]string, len(keys)) - cmds := make(rueidis.Commands, 0, len(keys)) - - for _, k := range keys { - lockVal := rca.generateLockValue() - lockVals[k] = lockVal - // SET NX GET returns the old value if key exists, or nil if SET succeeded - cmds = append(cmds, rca.client.B().Set().Key(k).Value(lockVal).Nx().Get().Px(rca.lockTTL).Build()) - } - - return lockVals, cmds -} - -// tryLockMulti attempts to acquire distributed locks for cache-aside operations. -// For Get operations, if real values already exist, we fail to acquire those locks -// (another process has already populated the cache). -func (rca *CacheAside) tryLockMulti(ctx context.Context, keys []string) map[string]string { - lockVals, cmds := rca.buildLockCommands(keys) - - resps := rca.client.DoMulti(ctx, cmds...) 
- - // Process responses - remove keys we couldn't lock - for i, r := range resps { - key := keys[i] - err := r.Error() - if !rueidis.IsRedisNil(err) { - if err != nil { - rca.logger.Error("failed to acquire lock", "key", key, "error", err) - } else { - // Key already exists (either lock from another process or real value) - rca.logger.Debug("key already exists, cannot acquire lock", "key", key) - } - delete(lockVals, key) - } - } - - // Note: CSC subscriptions already established by tryGetMulti's DoMultiCache call (line 752). - // No need for additional DoMultiCache here - invalidation notifications are already active. - - return lockVals -} - -type valAndLock struct { - val string - lockVal string -} - -type keyOrderAndSet struct { - keyOrder []string - setStmts []rueidis.LuaExec -} - -// groupBySlot groups keys by their Redis cluster slot for Lua script execution. -// This is necessary because Lua scripts in Redis Cluster must execute on a single node. -// Unlike regular SET commands which rueidis.DoMulti routes automatically, Lua scripts -// (LuaExec) need manual slot grouping to ensure each script runs on the correct node. -func groupBySlot(keyValLock map[string]valAndLock, ttl time.Duration) map[uint16]keyOrderAndSet { - // Pre-allocate with estimated capacity (avg ~8 slots for 50 keys) - estimatedSlots := len(keyValLock) / 8 - if estimatedSlots < 1 { - estimatedSlots = 1 - } - stmts := make(map[uint16]keyOrderAndSet, estimatedSlots) - - // Pre-calculate TTL string once - ttlStr := strconv.FormatInt(ttl.Milliseconds(), 10) - - for k, vl := range keyValLock { - slot := cmdx.Slot(k) - kos := stmts[slot] - - // Pre-allocate slices on first access to this slot - if kos.keyOrder == nil { - // Estimate ~6-7 keys per slot for typical workloads - estimatedKeysPerSlot := (len(keyValLock) / estimatedSlots) + 1 - kos.keyOrder = make([]string, 0, estimatedKeysPerSlot) - kos.setStmts = make([]rueidis.LuaExec, 0, estimatedKeysPerSlot) - } - - kos.keyOrder = append(kos.keyOrder, k) - kos.setStmts = append(kos.setStmts, rueidis.LuaExec{ - Keys: []string{k}, - Args: []string{vl.lockVal, vl.val, ttlStr}, - }) - stmts[slot] = kos - } - - return stmts -} - -// processSetResponse checks if a set operation succeeded and returns true if it did. -func (rca *CacheAside) processSetResponse(resp rueidis.RedisResult) (bool, error) { - // Check for Redis errors first - if err := resp.Error(); err != nil { - if !rueidis.IsRedisNil(err) { - return false, err - } - return false, nil - } - - // Check the Lua script return value - // The script returns 0 if the lock doesn't match - returnValue, err := resp.AsInt64() - if err == nil && returnValue == 0 { - // Lock was lost for this key - return false, nil - } - - return true, nil -} - -// executeSetStatements executes Lua set statements in parallel, grouped by slot. -func (rca *CacheAside) executeSetStatements(ctx context.Context, stmts map[uint16]keyOrderAndSet) ([]string, error) { - // Calculate total keys for pre-allocation - totalKeys := 0 - for _, kos := range stmts { - totalKeys += len(kos.keyOrder) - } - - keyByStmt := make([][]string, len(stmts)) - i := 0 - eg, ctx := errgroup.WithContext(ctx) - - for _, kos := range stmts { - ii := i - // Pre-allocate slice for this statement's successful keys - keyByStmt[ii] = make([]string, 0, len(kos.keyOrder)) - - eg.Go(func() error { - setResps := setKeyLua.ExecMulti(ctx, rca.client, kos.setStmts...) 
- for j, resp := range setResps { - success, err := rca.processSetResponse(resp) - if err != nil { - return err - } - if success { - keyByStmt[ii] = append(keyByStmt[ii], kos.keyOrder[j]) - } - } - return nil - }) - i++ - } - - if err := eg.Wait(); err != nil { - return nil, err - } - - // Pre-allocate output slice with exact capacity - out := make([]string, 0, totalKeys) - for _, keys := range keyByStmt { - out = append(out, keys...) - } - return out, nil -} - -func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, keyValLock map[string]valAndLock) ([]string, error) { - stmts := groupBySlot(keyValLock, ttl) - return rca.executeSetStatements(ctx, stmts) -} - -// unlockWithCleanup releases a single lock using a cleanup context derived from the parent. -// It preserves tracing/request context while allowing cleanup even if parent is cancelled. -// Best effort - errors are logged but non-fatal as locks will expire. -func (rca *CacheAside) unlockWithCleanup(ctx context.Context, key string, lockVal string) { - cleanupCtx := context.WithoutCancel(ctx) - toCtx, cancel := context.WithTimeout(cleanupCtx, rca.lockTTL) - defer cancel() - if unlockErr := rca.unlock(toCtx, key, lockVal); unlockErr != nil { - rca.logger.Error("failed to unlock key", "key", key, "error", unlockErr) - } -} - -// unlockMultiWithCleanup releases multiple locks using a cleanup context derived from the parent. -// It preserves tracing/request context while allowing cleanup even if parent is cancelled. -func (rca *CacheAside) unlockMultiWithCleanup(ctx context.Context, lockVals map[string]string) { - if len(lockVals) == 0 { - return - } - cleanupCtx := context.WithoutCancel(ctx) - toCtx, cancel := context.WithTimeout(cleanupCtx, rca.lockTTL) - defer cancel() - rca.unlockMulti(toCtx, lockVals) -} - -func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]string) { - if len(lockVals) == 0 { - return - } - delStmts := make(map[uint16][]rueidis.LuaExec) - for key, lockVal := range lockVals { - slot := cmdx.Slot(key) - delStmts[slot] = append(delStmts[slot], rueidis.LuaExec{ - Keys: []string{key}, - Args: []string{lockVal}, - }) - } - wg := sync.WaitGroup{} - for _, stmts := range delStmts { - wg.Add(1) - go func() { - defer wg.Done() - // Best effort unlock - errors are non-fatal as locks will expire - resps := delKeyLua.ExecMulti(ctx, rca.client, stmts...) 
- for _, resp := range resps { - if err := resp.Error(); err != nil { - rca.logger.Error("failed to unlock key in batch", "error", err) - } - } - }() - } - wg.Wait() -} diff --git a/cacheaside_cluster_test.go b/cacheaside_cluster_test.go index 552f122..85c8890 100644 --- a/cacheaside_cluster_test.go +++ b/cacheaside_cluster_test.go @@ -1,3 +1,5 @@ +//go:build cluster + package redcache_test import ( @@ -50,7 +52,7 @@ func makeClusterCacheAside(t *testing.T) *redcache.CacheAside { } // Test cluster connectivity - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() innerClient := cacheAside.Client() if pingErr := innerClient.Do(ctx, innerClient.B().Ping().Build()).Error(); pingErr != nil { @@ -71,7 +73,7 @@ func TestCacheAside_Cluster_BasicOperations(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "cluster:basic:" + uuid.New().String() expectedValue := "value:" + uuid.New().String() called := false @@ -98,7 +100,7 @@ func TestCacheAside_Cluster_BasicOperations(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Use hash tags to ensure keys are in the same slot keys := []string{ @@ -146,7 +148,7 @@ func TestCacheAside_Cluster_BasicOperations(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Generate keys guaranteed to be in different hash slots keys := generateKeysInDifferentSlots("cluster:multiSlot", 3) @@ -187,7 +189,7 @@ func TestCacheAside_Cluster_LargeKeySet(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Create 100 keys that will span multiple slots numKeys := 100 @@ -249,7 +251,7 @@ func TestCacheAside_Cluster_ConcurrentOperations(t *testing.T) { client2 := makeClusterCacheAside(t) defer client2.Client().Close() - ctx := context.Background() + ctx := t.Context() // Create keys in different slots key1 := "{shard:1}:key" @@ -305,7 +307,7 @@ func TestCacheAside_Cluster_ConcurrentOperations(t *testing.T) { client2 := makeClusterCacheAside(t) defer client2.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "cluster:concurrent:" + uuid.New().String() var callbackCount atomic.Int32 @@ -354,7 +356,7 @@ func TestCacheAside_Cluster_PartialResults(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Create keys across different slots keys := []string{ @@ -406,7 +408,7 @@ func TestCacheAside_Cluster_Invalidation(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "cluster:del:" + uuid.New().String() // Set a value @@ -435,7 +437,7 @@ func TestCacheAside_Cluster_Invalidation(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Create keys across different slots keys := []string{ @@ -472,7 +474,7 @@ func TestCacheAside_Cluster_ErrorHandling(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "cluster:error:" + uuid.New().String() callCount := 0 @@ -505,7 +507,7 @@ func TestCacheAside_Cluster_ErrorHandling(t *testing.T) { key := "cluster:cancel:" + uuid.New().String() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) // Set a lock manually to force waiting innerClient := client.Client() @@ -539,7 +541,7 @@ func 
TestCacheAside_Cluster_StressTest(t *testing.T) { } defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() // Create many keys across all slots numKeys := 50 diff --git a/cacheaside_distributed_test.go b/cacheaside_distributed_test.go index d4c255d..885f897 100644 --- a/cacheaside_distributed_test.go +++ b/cacheaside_distributed_test.go @@ -1,3 +1,5 @@ +//go:build distributed + package redcache_test import ( @@ -20,7 +22,7 @@ import ( // coordinate correctly across multiple clients func TestCacheAside_DistributedCoordination(t *testing.T) { t.Run("multiple clients Get same key - only one calls callback", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "dist:get:" + uuid.New().String() // Create multiple clients @@ -78,7 +80,7 @@ func TestCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("multiple clients GetMulti with overlapping keys", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "dist:multi:1:" + uuid.New().String() key2 := "dist:multi:2:" + uuid.New().String() key3 := "dist:multi:3:" + uuid.New().String() @@ -167,7 +169,7 @@ func TestCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("client invalidation propagates across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "dist:invalidate:" + uuid.New().String() // Create two clients @@ -218,7 +220,7 @@ func TestCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("lock expiration handled correctly across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "dist:expire:" + uuid.New().String() // Create client with short lock TTL @@ -242,7 +244,7 @@ func TestCacheAside_DistributedCoordination(t *testing.T) { go func() { defer wg.Done() // This will timeout and lock will expire - timeoutCtx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond) + timeoutCtx, cancel := context.WithTimeout(t.Context(), 400*time.Millisecond) defer cancel() _, getErr := client1.Get(timeoutCtx, 10*time.Second, key, func(ctx context.Context, key string) (string, error) { // Hang until context cancels @@ -273,7 +275,7 @@ func TestCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("concurrent Gets from many clients - stress test", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Create many clients numClients := 20 diff --git a/cacheaside_test.go b/cacheaside_test.go index c14ef7f..41f8426 100644 --- a/cacheaside_test.go +++ b/cacheaside_test.go @@ -1,3 +1,5 @@ +//go:build integration + package redcache_test import ( @@ -18,8 +20,6 @@ import ( "github.com/dcbickfo/redcache" ) -var addr = []string{"127.0.0.1:6379"} - func makeClient(t *testing.T, addr []string) *redcache.CacheAside { client, err := redcache.NewRedCacheAside( rueidis.ClientOption{ @@ -38,7 +38,7 @@ func makeClient(t *testing.T, addr []string) *redcache.CacheAside { func TestCacheAside_Get(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + uuid.New().String() val := "val:" + uuid.New().String() called := false @@ -62,7 +62,7 @@ func TestCacheAside_Get(t *testing.T) { func TestCacheAside_GetMulti(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 3 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, 
uuid.New().String()) @@ -95,7 +95,7 @@ func TestCacheAside_GetMulti(t *testing.T) { func TestCacheAside_GetMulti_Partial(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 3 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -157,7 +157,7 @@ func TestCacheAside_GetMulti_Partial(t *testing.T) { func TestCacheAside_GetMulti_PartLock(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 3 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -201,7 +201,7 @@ func TestCacheAside_GetMulti_PartLock(t *testing.T) { func TestCacheAside_Del(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + uuid.New().String() val := "val:" + uuid.New().String() @@ -226,7 +226,7 @@ func TestCBWrapper_GetMultiCheckConcurrent(t *testing.T) { client2 := makeClient(t, addr) defer client2.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 6 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -324,7 +324,7 @@ func TestCBWrapper_GetMultiCheckConcurrentOverlapDifferentClients(t *testing.T) client4 := makeClient(t, addr) defer client4.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 6 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -452,7 +452,7 @@ func TestCBWrapper_GetMultiCheckConcurrentOverlap(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 6 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -579,7 +579,7 @@ func TestCBWrapper_GetMultiCheckConcurrentOverlap(t *testing.T) { func TestCacheAside_DelMulti(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 3 { @@ -614,7 +614,7 @@ func TestCacheAside_GetParentContextCancellation(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) key := "key:" + uuid.New().String() val := "val:" + uuid.New().String() @@ -655,7 +655,7 @@ func TestConcurrentRegisterRace(t *testing.T) { require.NoError(t, err) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + uuid.New().String() val := "val:" + uuid.New().String() @@ -713,7 +713,7 @@ func TestConcurrentGetSameKeySingleClient(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + uuid.New().String() val := "val:" + uuid.New().String() @@ -771,7 +771,7 @@ func TestConcurrentInvalidation(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + 
uuid.New().String() callCount := 0 @@ -834,7 +834,7 @@ func TestDeleteDuringGetWithLock(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() key := "key:" + uuid.New().String() expectedValue := "val:" + uuid.New().String() @@ -901,7 +901,7 @@ func TestDeleteDuringGetMultiWithLocks(t *testing.T) { client := makeClient(t, addr) defer client.Client().Close() - ctx := context.Background() + ctx := t.Context() keyAndVals := make(map[string]string) for i := range 3 { keyAndVals[fmt.Sprintf("key:%d:%s", i, uuid.New().String())] = fmt.Sprintf("val:%d:%s", i, uuid.New().String()) @@ -975,7 +975,7 @@ func TestDeleteDuringGetMultiWithLocks(t *testing.T) { // TestCacheAside_LockExpiration tests Get/GetMulti behavior when locks expire naturally func TestCacheAside_LockExpiration(t *testing.T) { t.Run("Get with callback exceeding lock TTL", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "get-exceed-ttl:" + uuid.New().String() callCount := 0 @@ -1012,7 +1012,7 @@ func TestCacheAside_LockExpiration(t *testing.T) { }) t.Run("GetMulti with callback exceeding lock TTL", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "getmulti-exceed-1:" + uuid.New().String() key2 := "getmulti-exceed-2:" + uuid.New().String() diff --git a/cacheaside_unit_test.go b/cacheaside_unit_test.go new file mode 100644 index 0000000..81d3f9d --- /dev/null +++ b/cacheaside_unit_test.go @@ -0,0 +1,680 @@ +package redcache + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/dcbickfo/redcache/internal/lockmanager" + mocklockmanager "github.com/dcbickfo/redcache/mocks/lockmanager" + mocklogger "github.com/dcbickfo/redcache/mocks/logger" +) + +// TestCacheAside_Get tests all Get() scenarios using subtests. +func TestCacheAside_Get(t *testing.T) { + t.Run("CacheHit", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + cachedValue := "cached-value" + + // Create mocked logger + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + // Create mocked Redis client + mockClient := mock.NewClient(ctrl) + + // Expect DoCache call for cache hit + mockClient.EXPECT(). + DoCache(gomock.Any(), mock.Match("GET", key), gomock.Any()). 
+ Return(mock.Result(mock.RedisString(cachedValue))) + + // Create mocked lock manager + tryAcquireCalled := false + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireFunc: func(ctx context.Context, key string) (string, lockmanager.WaitHandle, error) { + tryAcquireCalled = true + return "", nil, nil + }, + } + + // Create CacheAside instance + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + // Callback should not be called + callbackCalled := false + callback := func(ctx context.Context, key string) (string, error) { + callbackCalled = true + return "should-not-be-called", nil + } + + // Execute Get + result, err := ca.Get(ctx, time.Minute, key, callback) + + // Verify results + require.NoError(t, err) + assert.Equal(t, cachedValue, result) + assert.False(t, callbackCalled, "Callback should not be called on cache hit") + assert.False(t, tryAcquireCalled, "Lock should not be acquired on cache hit") + }) + + t.Run("CacheMiss_LockAcquired", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := "lock-value-123" + callbackValue := "fetched-from-db" + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoCache(gomock.Any(), mock.Match("GET", key), gomock.Any()). + Return(mock.Result(mock.RedisNil())). + Times(2) + + lockAcquired := false + lockCommitted := false + lockReleased := false + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireFunc: func(ctx context.Context, k string) (string, lockmanager.WaitHandle, error) { + lockAcquired = true + assert.Equal(t, key, k) + return lockValue, nil, nil + }, + CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) { + lockCommitted = true + require.Contains(t, lockValues, key) + require.Contains(t, actualValues, key) + assert.Equal(t, lockValue, lockValues[key]) + assert.Equal(t, callbackValue, actualValues[key]) + return []string{key}, nil, nil + }, + ReleaseLockFunc: func(ctx context.Context, k string, lv string) error { + lockReleased = true + t.Errorf("Lock should be committed, not released") + return nil + }, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callbackCalled := false + callback := func(ctx context.Context, k string) (string, error) { + callbackCalled = true + assert.Equal(t, key, k) + return callbackValue, nil + } + + result, err := ca.Get(ctx, time.Minute, key, callback) + + require.NoError(t, err) + assert.Equal(t, callbackValue, result) + assert.True(t, lockAcquired) + assert.True(t, callbackCalled) + assert.True(t, lockCommitted) + assert.False(t, lockReleased) + }) + + t.Run("LockContention", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := 
"lock-value-456" + callbackValue := "fetched-after-retry" + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoCache(gomock.Any(), mock.Match("GET", key), gomock.Any()). + Return(mock.Result(mock.RedisNil())). + MinTimes(2). + MaxTimes(6) + + waitChan := make(chan struct{}) + close(waitChan) + + mockWaitHandle := &mocklockmanager.MockWaitHandle{ + WaitFunc: func(ctx context.Context) error { + <-waitChan + return nil + }, + } + + retryAttempts := 0 + lockCommitted := false + callbackCalled := false + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireFunc: func(ctx context.Context, k string) (string, lockmanager.WaitHandle, error) { + retryAttempts++ + if retryAttempts == 1 { + return "", mockWaitHandle, nil + } + return lockValue, nil, nil + }, + CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) { + lockCommitted = true + require.Contains(t, lockValues, key) + require.Contains(t, actualValues, key) + assert.Equal(t, lockValue, lockValues[key]) + assert.Equal(t, callbackValue, actualValues[key]) + return []string{key}, nil, nil + }, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callback := func(ctx context.Context, k string) (string, error) { + callbackCalled = true + assert.Equal(t, key, k) + return callbackValue, nil + } + + result, err := ca.Get(ctx, time.Minute, key, callback) + + require.NoError(t, err) + assert.Equal(t, callbackValue, result) + assert.Equal(t, 2, retryAttempts, "Expected 2 lock acquisition attempts (1 retry)") + assert.True(t, callbackCalled) + assert.True(t, lockCommitted) + }) + + t.Run("CallbackError", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := "lock-value-789" + expectedErr := errors.New("database connection failed") + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) {}, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoCache(gomock.Any(), mock.Match("GET", key), gomock.Any()). + Return(mock.Result(mock.RedisNil())). 
+ Times(2) + + lockAcquired := false + lockReleased := false + lockCommitted := false + callbackCalled := false + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireFunc: func(ctx context.Context, k string) (string, lockmanager.WaitHandle, error) { + lockAcquired = true + assert.Equal(t, key, k) + return lockValue, nil, nil + }, + ReleaseLockFunc: func(ctx context.Context, k string, lv string) error { + lockReleased = true + assert.Equal(t, key, k) + assert.Equal(t, lockValue, lv) + return nil + }, + CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) { + lockCommitted = true + t.Errorf("Lock should be released, not committed, when callback returns error") + return nil, nil, nil + }, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callback := func(ctx context.Context, k string) (string, error) { + callbackCalled = true + assert.Equal(t, key, k) + return "", expectedErr + } + + result, err := ca.Get(ctx, time.Minute, key, callback) + + require.Error(t, err) + assert.ErrorIs(t, err, expectedErr) + assert.Empty(t, result) + assert.True(t, lockAcquired) + assert.True(t, callbackCalled) + assert.True(t, lockReleased) + assert.False(t, lockCommitted) + }) +} + +// TestCacheAside_GetMulti tests all GetMulti() scenarios using subtests. +func TestCacheAside_GetMulti(t *testing.T) { + t.Run("AllCacheHits", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"key1", "key2", "key3"} + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoMultiCache(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, commands ...rueidis.CacheableTTL) []rueidis.RedisResult { + // Build key-based response map to handle any key order + cacheValues := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + results := make([]rueidis.RedisResult, len(commands)) + for i, cmd := range commands { + // Extract the key from the command + cmdStrings := cmd.Cmd.Commands() + if len(cmdStrings) >= 2 { + key := cmdStrings[1] // GET + if val, exists := cacheValues[key]; exists { + results[i] = mock.Result(mock.RedisString(val)) + } else { + results[i] = mock.Result(mock.RedisNil()) + } + } else { + results[i] = mock.Result(mock.RedisNil()) + } + } + return results + }) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callbackCalled := false + callback := func(ctx context.Context, keys []string) (map[string]string, error) { + callbackCalled = true + return nil, errors.New("callback should not be called") + } + + result, err := ca.GetMulti(ctx, time.Minute, keys, callback) + + require.NoError(t, err) + assert.Len(t, result, 3) + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, "value2", result["key2"]) + assert.Equal(t, "value3", result["key3"]) + assert.False(t, callbackCalled) + }) + + t.Run("AllMiss", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"miss1", "miss2", "miss3"} + callbackValues := map[string]string{ + "miss1": "fetched1", + "miss2": "fetched2", + "miss3": "fetched3", + } + lockValues := map[string]string{ + "miss1": "lock1", + "miss2": "lock2", + "miss3": "lock3", + } + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoMultiCache(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, commands ...rueidis.CacheableTTL) []rueidis.RedisResult { + results := make([]rueidis.RedisResult, len(commands)) + for i := range commands { + results[i] = mock.Result(mock.RedisNil()) + } + return results + }) + + tryAcquireCalled := false + callbackCalled := false + commitCalled := false + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireMultiFunc: func(ctx context.Context, acquireKeys []string) (map[string]string, map[string]lockmanager.WaitHandle, error) { + tryAcquireCalled = true + return lockValues, nil, nil + }, + CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockVals map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) { + commitCalled = true + for key, lockVal := range lockValues { + assert.Equal(t, lockVal, lockVals[key]) + } + for key, expectedVal := range callbackValues { + assert.Equal(t, expectedVal, actualValues[key]) + } + return keys, nil, nil + }, + CleanupUnusedLocksFunc: func(ctx context.Context, acquired map[string]string, used map[string]string) {}, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callback := func(ctx context.Context, cbKeys []string) (map[string]string, error) { + callbackCalled = true + assert.ElementsMatch(t, keys, cbKeys) + return callbackValues, nil + } + + result, err := ca.GetMulti(ctx, time.Minute, keys, callback) + + require.NoError(t, err) + assert.Len(t, result, 3) + assert.Equal(t, "fetched1", result["miss1"]) + assert.Equal(t, "fetched2", result["miss2"]) + assert.Equal(t, "fetched3", result["miss3"]) + assert.True(t, tryAcquireCalled) + assert.True(t, callbackCalled) + assert.True(t, commitCalled) + }) + + t.Run("PartialHits", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"cached1", "miss1", "cached2", "miss2"} + missedKeys := []string{"miss1", "miss2"} + fetchedValues := map[string]string{ + "miss1": "fetched1", + "miss2": "fetched2", + } + lockValues := map[string]string{ + "miss1": "lock1", + "miss2": "lock2", + } + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoMultiCache(gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, commands ...rueidis.CacheableTTL) []rueidis.RedisResult { + // Build key-based response map to handle any key order + cacheValues := map[string]string{ + "cached1": "value1", + "cached2": "value2", + } + + results := make([]rueidis.RedisResult, len(commands)) + for i, cmd := range commands { + // Extract the key from the command + cmdStrings := cmd.Cmd.Commands() + if len(cmdStrings) >= 2 { + key := cmdStrings[1] // GET + if val, exists := cacheValues[key]; exists { + results[i] = mock.Result(mock.RedisString(val)) + } else { + results[i] = mock.Result(mock.RedisNil()) + } + } else { + results[i] = mock.Result(mock.RedisNil()) + } + } + return results + }) + + tryAcquireCalled := false + callbackCalled := false + commitCalled := false + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + IsLockValueFunc: func(value string) bool { + return false + }, + TryAcquireMultiFunc: func(ctx context.Context, acquireKeys []string) (map[string]string, map[string]lockmanager.WaitHandle, error) { + tryAcquireCalled = true + assert.ElementsMatch(t, missedKeys, acquireKeys) + result := make(map[string]string) + for _, key := range acquireKeys { + result[key] = lockValues[key] + } + return result, nil, nil + }, + CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockVals map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) { + commitCalled = true + assert.Len(t, lockVals, 2) + assert.Len(t, actualValues, 2) + for key, expectedVal := range fetchedValues { + assert.Equal(t, expectedVal, actualValues[key]) + } + return missedKeys, nil, nil + }, + CleanupUnusedLocksFunc: func(ctx context.Context, acquired map[string]string, used map[string]string) {}, + } + + ca := &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + callback := func(ctx context.Context, cbKeys []string) (map[string]string, error) { + callbackCalled = true + assert.ElementsMatch(t, missedKeys, cbKeys) + return fetchedValues, nil + } + + result, err := ca.GetMulti(ctx, time.Minute, keys, callback) + + require.NoError(t, err) + assert.Len(t, result, 4) + assert.Equal(t, "value1", result["cached1"]) + assert.Equal(t, "value2", result["cached2"]) + assert.Equal(t, "fetched1", result["miss1"]) + assert.Equal(t, "fetched2", result["miss2"]) + assert.True(t, tryAcquireCalled) + assert.True(t, callbackCalled) + assert.True(t, commitCalled) + }) +} + +// TestCacheAside_Del tests deletion operations using subtests. +func TestCacheAside_Del(t *testing.T) { + t.Run("SingleKey", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + Do(gomock.Any(), mock.Match("DEL", key)). 
+ Return(mock.Result(mock.RedisInt64(1))) + + ca := &CacheAside{ + client: mockClient, + lockManager: &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + }, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + err := ca.Del(ctx, key) + require.NoError(t, err) + }) + + t.Run("MultipleKeys", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"key1", "key2", "key3"} + + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, args ...any) {}, + ErrorFunc: func(msg string, args ...any) { + t.Errorf("Unexpected error log: %s %v", msg, args) + }, + } + + mockClient := mock.NewClient(ctrl) + mockClient.EXPECT(). + DoMulti(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, commands ...rueidis.Completed) []rueidis.RedisResult { + results := make([]rueidis.RedisResult, len(commands)) + for i := range commands { + results[i] = mock.Result(mock.RedisInt64(1)) + } + return results + }) + + ca := &CacheAside{ + client: mockClient, + lockManager: &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { + return "__redcache:lock:" + }, + }, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + } + + err := ca.DelMulti(ctx, keys...) + require.NoError(t, err) + }) +} diff --git a/errors.go b/errors.go index 79d08c4..daa4f1d 100644 --- a/errors.go +++ b/errors.go @@ -3,6 +3,8 @@ package redcache import ( "errors" "fmt" + + "github.com/dcbickfo/redcache/internal/errs" ) // Common errors returned by redcache operations. @@ -10,12 +12,14 @@ var ( // ErrNoKeys is returned when an operation is called with an empty key list. ErrNoKeys = errors.New("no keys provided") - // ErrLockFailed is returned when a lock cannot be acquired. - ErrLockFailed = errors.New("failed to acquire lock") + // ErrLockFailed is an alias for the shared lock failed error. + // It is returned when a lock cannot be acquired. + ErrLockFailed = errs.ErrLockFailed - // ErrLockLost indicates the distributed lock was lost or expired before the value could be set. + // ErrLockLost is an alias for the shared lock lost error. + // It indicates the distributed lock was lost or expired before the value could be set. // This can occur if the lock TTL expires during callback execution or if Redis invalidates the lock. - ErrLockLost = errors.New("lock was lost or expired before value could be set") + ErrLockLost = errs.ErrLockLost // ErrInvalidTTL is returned when a TTL value is invalid (e.g., negative or zero). 
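// Illustrative note (not part of this diff): because ErrLockFailed and ErrLockLost are now
// aliases of the sentinels in internal/errs, existing caller checks keep working unchanged,
// e.g. errors.Is(err, redcache.ErrLockFailed) still matches errors produced and wrapped by
// the internal packages.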
ErrInvalidTTL = errors.New("invalid TTL value") diff --git a/go.mod b/go.mod index 3ec5e4f..afe82ef 100644 --- a/go.mod +++ b/go.mod @@ -6,14 +6,14 @@ require ( github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/redis/rueidis v1.0.68 + github.com/redis/rueidis/mock v1.0.68 github.com/stretchr/testify v1.11.1 - golang.org/x/sync v0.18.0 + go.uber.org/mock v0.6.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - go.uber.org/goleak v1.3.0 // indirect golang.org/x/sys v0.38.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 02e927b..48b0209 100644 --- a/go.sum +++ b/go.sum @@ -12,24 +12,16 @@ github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/rueidis v1.0.56 h1:DwPjFIgas1OMU/uCqBELOonu9TKMYt3MFPq6GtwEWNY= -github.com/redis/rueidis v1.0.56/go.mod h1:g660/008FMYmAF46HG4lmcpcgFNj+jCjCAZUUM+wEbs= github.com/redis/rueidis v1.0.68 h1:gept0E45JGxVigWb3zoWHvxEc4IOC7kc4V/4XvN8eG8= github.com/redis/rueidis v1.0.68/go.mod h1:Lkhr2QTgcoYBhxARU7kJRO8SyVlgUuEkcJO1Y8MCluA= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/redis/rueidis/mock v1.0.68 h1:4KH+DOg8uWrccRpLzjQGJD0xk1YQtBXQhbIYoIGuBro= +github.com/redis/rueidis/mock v1.0.68/go.mod h1:a+M+Z+czot8TnSTFwfbd9Ru20B5iE4pjyWV1aBIbSrU= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= -go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= diff --git a/internal/cmdx/slot.go b/internal/cmdx/slot.go index e596a06..6a32d44 100644 --- a/internal/cmdx/slot.go +++ b/internal/cmdx/slot.go @@ -7,6 +7,25 @@ const ( RedisClusterSlots = 16383 ) +// EstimateSlotDistribution estimates the number of Redis cluster slots and keys per slot +// for 
pre-allocating maps and slices when grouping operations by slot. +// +// The estimation assumes approximately 8 keys will be distributed across different slots +// (a reasonable assumption for hash-distributed keys), and calculates the expected number +// of keys per slot based on this distribution. +// +// Returns: +// - estimatedSlots: estimated number of unique slots that will contain keys +// - estimatedPerSlot: estimated number of keys per slot +func EstimateSlotDistribution(itemCount int) (estimatedSlots, estimatedPerSlot int) { + estimatedSlots = itemCount / 8 + if estimatedSlots < 1 { + estimatedSlots = 1 + } + estimatedPerSlot = (itemCount / estimatedSlots) + 1 + return +} + func Slot(key string) uint16 { var s, e int for ; s < len(key); s++ { diff --git a/internal/contextx/cleanup.go b/internal/contextx/cleanup.go new file mode 100644 index 0000000..d4999b5 --- /dev/null +++ b/internal/contextx/cleanup.go @@ -0,0 +1,30 @@ +// Package contextx provides context utilities for the redcache package. +package contextx + +import ( + "context" + "time" +) + +// WithCleanupTimeout creates a context for cleanup operations that will not be +// cancelled when the parent context is cancelled. This is useful for ensuring +// cleanup operations (like releasing locks) complete even if the parent context +// times out or is cancelled. +// +// The returned context will have a timeout applied to prevent indefinite blocking. +// The caller must call the returned cancel function to release resources. +// +// Example usage: +// +// ctx, cancel := contextx.WithCleanupTimeout(parentCtx, 5*time.Second) +// defer cancel() +// if err := releaseLock(ctx, lockID); err != nil { +// log.Error("failed to release lock", "error", err) +// } +func WithCleanupTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + // Create a new context that won't be cancelled when parent is cancelled + cleanupCtx := context.WithoutCancel(parent) + + // Add a timeout to prevent indefinite blocking + return context.WithTimeout(cleanupCtx, timeout) +} diff --git a/internal/contextx/cleanup_test.go b/internal/contextx/cleanup_test.go new file mode 100644 index 0000000..6355563 --- /dev/null +++ b/internal/contextx/cleanup_test.go @@ -0,0 +1,115 @@ +package contextx_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache/internal/contextx" +) + +func TestWithCleanupTimeout(t *testing.T) { + t.Run("creates context that survives parent cancellation", func(t *testing.T) { + // Create parent context that we'll cancel + parentCtx, parentCancel := context.WithCancel(context.Background()) + defer parentCancel() + + // Create cleanup context + cleanupCtx, cleanupCancel := contextx.WithCleanupTimeout(parentCtx, 5*time.Second) + defer cleanupCancel() + + // Cancel parent immediately + parentCancel() + + // Cleanup context should still be valid + select { + case <-cleanupCtx.Done(): + t.Fatal("Cleanup context should not be cancelled when parent is cancelled") + case <-time.After(10 * time.Millisecond): + // Success - cleanup context is still active + } + }) + + t.Run("applies timeout to cleanup context", func(t *testing.T) { + parentCtx := context.Background() + + // Create cleanup context with very short timeout + cleanupCtx, cleanupCancel := contextx.WithCleanupTimeout(parentCtx, 50*time.Millisecond) + defer cleanupCancel() + + // Wait for timeout + select { + case <-cleanupCtx.Done(): + // 
Success - context timed out + assert.ErrorIs(t, cleanupCtx.Err(), context.DeadlineExceeded) + case <-time.After(200 * time.Millisecond): + t.Fatal("Cleanup context should have timed out") + } + }) + + t.Run("cancel function releases resources", func(t *testing.T) { + parentCtx := context.Background() + + cleanupCtx, cleanupCancel := contextx.WithCleanupTimeout(parentCtx, 5*time.Second) + + // Cancel explicitly + cleanupCancel() + + // Context should be done + select { + case <-cleanupCtx.Done(): + // Success - context was cancelled + assert.ErrorIs(t, cleanupCtx.Err(), context.Canceled) + case <-time.After(10 * time.Millisecond): + t.Fatal("Cleanup context should be cancelled after calling cancel function") + } + }) + + t.Run("inherits values from parent context", func(t *testing.T) { + type contextKey string + const key contextKey = "test-key" + expectedValue := "test-value" + + // Create parent with value + parentCtx := context.WithValue(context.Background(), key, expectedValue) + + // Create cleanup context + cleanupCtx, cleanupCancel := contextx.WithCleanupTimeout(parentCtx, 5*time.Second) + defer cleanupCancel() + + // Value should be accessible in cleanup context + actualValue := cleanupCtx.Value(key) + require.NotNil(t, actualValue) + assert.Equal(t, expectedValue, actualValue.(string)) + }) + + t.Run("timeout occurs before parent cancellation", func(t *testing.T) { + // Create parent that will be cancelled later + parentCtx, parentCancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer parentCancel() + + // Create cleanup context with shorter timeout + cleanupCtx, cleanupCancel := contextx.WithCleanupTimeout(parentCtx, 50*time.Millisecond) + defer cleanupCancel() + + // Wait for cleanup context to timeout (should happen first) + select { + case <-cleanupCtx.Done(): + // Success - cleanup context timed out before parent + assert.ErrorIs(t, cleanupCtx.Err(), context.DeadlineExceeded) + case <-time.After(100 * time.Millisecond): + t.Fatal("Cleanup context should have timed out") + } + + // Parent should still be active at this point + select { + case <-parentCtx.Done(): + t.Fatal("Parent context should still be active") + default: + // Success - parent is still active + } + }) +} diff --git a/internal/errs/errors.go b/internal/errs/errors.go new file mode 100644 index 0000000..04a0bae --- /dev/null +++ b/internal/errs/errors.go @@ -0,0 +1,16 @@ +// Package errs provides common error definitions used throughout redcache. +package errs + +import ( + "errors" +) + +// Common errors used by lock management operations. +var ( + // ErrLockFailed is returned when a lock cannot be acquired due to contention. + ErrLockFailed = errors.New("lock acquisition failed") + + // ErrLockLost indicates the distributed lock was lost or expired before the value could be set. + // This can occur if the lock TTL expires during callback execution or if Redis invalidates the lock. 
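// Illustrative note (not part of this diff): these sentinels compose with fmt.Errorf and %w,
// e.g. fmt.Errorf("commit %q: %w", key, ErrLockLost), and errors.Is(err, ErrLockLost)
// still reports true for the wrapped error.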
+ ErrLockLost = errors.New("lock was lost or expired before value could be set") +) diff --git a/internal/errs/errors_test.go b/internal/errs/errors_test.go new file mode 100644 index 0000000..e2a62b5 --- /dev/null +++ b/internal/errs/errors_test.go @@ -0,0 +1,36 @@ +package errs_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dcbickfo/redcache/internal/errs" +) + +func TestErrors(t *testing.T) { + t.Run("ErrLockFailed has correct message", func(t *testing.T) { + assert.Equal(t, "lock acquisition failed", errs.ErrLockFailed.Error()) + }) + + t.Run("ErrLockLost has correct message", func(t *testing.T) { + assert.Equal(t, "lock was lost or expired before value could be set", errs.ErrLockLost.Error()) + }) + + t.Run("ErrLockFailed can be wrapped and unwrapped", func(t *testing.T) { + wrappedErr := errors.Join(errs.ErrLockFailed, errors.New("additional context")) + assert.ErrorIs(t, wrappedErr, errs.ErrLockFailed) + }) + + t.Run("ErrLockLost can be wrapped and unwrapped", func(t *testing.T) { + wrappedErr := errors.Join(errs.ErrLockLost, errors.New("additional context")) + assert.ErrorIs(t, wrappedErr, errs.ErrLockLost) + }) + + t.Run("errors are distinct", func(t *testing.T) { + assert.NotEqual(t, errs.ErrLockFailed, errs.ErrLockLost) + assert.False(t, errors.Is(errs.ErrLockFailed, errs.ErrLockLost)) + assert.False(t, errors.Is(errs.ErrLockLost, errs.ErrLockFailed)) + }) +} diff --git a/internal/invalidation/handler.go b/internal/invalidation/handler.go new file mode 100644 index 0000000..6191665 --- /dev/null +++ b/internal/invalidation/handler.go @@ -0,0 +1,203 @@ +package invalidation + +import ( + "context" + "iter" + "time" + + "github.com/redis/rueidis" + + "github.com/dcbickfo/redcache/internal/logger" + "github.com/dcbickfo/redcache/internal/syncx" +) + +// Handler manages cache invalidation tracking and coordination. +// It tracks pending requests waiting for cache updates and notifies them +// when invalidation messages arrive from Redis. +// +// This is used to coordinate multiple concurrent requests for the same key: +// - When a cache miss occurs, the first request acquires a lock +// - Subsequent requests register themselves as waiters +// - When the value is computed and stored, Redis sends invalidation messages +// - All waiters are notified to retry their cache lookups. +type Handler interface { + // OnInvalidate processes Redis invalidation messages. + // It notifies all registered waiters for the invalidated keys. + OnInvalidate(messages []rueidis.RedisMessage) + + // Register registers interest in a key and returns a channel that will be closed + // when the key is invalidated or times out. + // Multiple goroutines can register for the same key safely. + Register(key string) <-chan struct{} + + // RegisterAll registers interest in multiple keys at once. + // Returns a map of key -> wait channel. + RegisterAll(keys iter.Seq[string], length int) map[string]<-chan struct{} +} + +// Logger is the logging interface used for invalidation operations. +// This is a type alias for the shared logger interface. +type Logger = logger.Logger + +// lockEntry tracks a registered waiter for a key. +type lockEntry struct { + ctx context.Context + cancel context.CancelFunc +} + +// Config holds configuration for the invalidation handler. +type Config struct { + // LockTTL is the timeout for waiting on invalidations. + // This should match the lock TTL used for distributed locking. + LockTTL time.Duration + + // Logger for error reporting. 
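// For example (illustrative, assuming the Logger interface is just Debug/Error with
// slog-style key/value arguments, as the tests exercise it): a *slog.Logger such as
// slog.Default() can be supplied here directly.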
+ Logger Logger +} + +// RedisInvalidationHandler implements Handler using a sync.Map to track waiters. +type RedisInvalidationHandler struct { + lockTTL time.Duration + logger Logger + locks *lockEntryMap // map[string]*lockEntry +} + +// lockEntryMap is a type-safe wrapper around syncx.Map for lock entries. +type lockEntryMap struct { + m *syncx.Map[string, *lockEntry] +} + +// NewRedisInvalidationHandler creates a new invalidation handler. +func NewRedisInvalidationHandler(cfg Config) *RedisInvalidationHandler { + return &RedisInvalidationHandler{ + lockTTL: cfg.LockTTL, + logger: cfg.Logger, + locks: &lockEntryMap{ + m: syncx.NewMap[string, *lockEntry](), + }, + } +} + +// OnInvalidate processes Redis invalidation messages. +func (h *RedisInvalidationHandler) OnInvalidate(messages []rueidis.RedisMessage) { + for _, m := range messages { + key, err := m.ToString() + if err != nil { + h.logger.Error("failed to parse invalidation message", "error", err) + continue + } + entry, loaded := h.locks.m.LoadAndDelete(key) + if loaded { + entry.cancel() // Cancel context, which closes the channel + } + } +} + +// Register registers interest in a key and returns a channel that will be closed +// when the key is invalidated or times out. +// +//nolint:gocognit // Complex due to atomic operations and retry logic +func (h *RedisInvalidationHandler) Register(key string) <-chan struct{} { +retry: + // First check if an entry already exists (common case for concurrent requests) + // This avoids creating a context unnecessarily + if existing, ok := h.locks.m.Load(key); ok { + // Check if the existing context is still active + select { + case <-existing.ctx.Done(): + // Context is done - try to atomically delete it and retry + if h.locks.m.CompareAndDelete(key, existing) { + goto retry + } + // Another goroutine modified it, try loading again + if newEntry, found := h.locks.m.Load(key); found { + return newEntry.ctx.Done() + } + // Entry was deleted, retry + goto retry + default: + // Context is still active, use it + return existing.ctx.Done() + } + } + + // No existing entry or it was expired, create new one + // The extra time allows the invalidation message to arrive (primary flow) + // while still providing a fallback timeout for missed messages. + // We use a proportional buffer (20% of lockTTL) with a minimum of 200ms + // to account for network delays, ensuring the timeout scales appropriately + // with different lock durations. 
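// Worked example (illustrative): with lockTTL = 5s the buffer is 1s, so waiters time out
// after 6s; with lockTTL = 500ms the computed 100ms buffer falls below the floor, so the
// 200ms minimum applies and waiters time out after 700ms.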
+ buffer := h.lockTTL / 5 // 20% + if buffer < 200*time.Millisecond { + buffer = 200 * time.Millisecond + } + ctx, cancel := context.WithTimeout(context.Background(), h.lockTTL+buffer) + + newEntry := &lockEntry{ + ctx: ctx, + cancel: cancel, + } + + // Store or get existing entry atomically + actual, loaded := h.locks.m.LoadOrStore(key, newEntry) + + // If we successfully stored, schedule automatic cleanup on expiration + if !loaded { + // Use context.AfterFunc to clean up expired entry without blocking goroutine + context.AfterFunc(ctx, func() { + h.locks.m.CompareAndDelete(key, newEntry) + }) + return ctx.Done() + } + + // Another goroutine stored first, cancel our context to prevent leak + cancel() + + // Check if their context is still active (not cancelled/timed out) + select { + case <-actual.ctx.Done(): + // Context is done - try to atomically delete it and retry + if h.locks.m.CompareAndDelete(key, actual) { + // We successfully deleted the expired entry, retry + goto retry + } + // CompareAndDelete failed - another goroutine modified it + // Load the new entry and use it + waitEntry, ok := h.locks.m.Load(key) + if !ok { + // Entry was deleted by another goroutine, retry registration + goto retry + } + return waitEntry.ctx.Done() + default: + // Context is still active, use it + return actual.ctx.Done() + } +} + +// RegisterAll registers interest in multiple keys at once. +func (h *RedisInvalidationHandler) RegisterAll(keys iter.Seq[string], length int) map[string]<-chan struct{} { + res := make(map[string]<-chan struct{}, length) + for key := range keys { + res[key] = h.Register(key) + } + return res +} + +// WaitForSingleLock waits for a single lock to be released via invalidation or timeout. +// This is a helper function for waiting on invalidation channels with proper timeout handling. 
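// Illustrative usage sketch (assumed caller code, not part of this diff):
//
//	waitChan := handler.Register(key)
//	if err := invalidation.WaitForSingleLock(ctx, waitChan, lockTTL); err != nil {
//	    return err // parent context cancelled
//	}
//	// A nil return means the key was invalidated or the lock TTL elapsed; re-check the
//	// cache and retry.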
+func WaitForSingleLock(ctx context.Context, waitChan <-chan struct{}, lockTTL time.Duration) error { + timer := time.NewTimer(lockTTL) + defer timer.Stop() + + select { + case <-waitChan: + // Lock released via invalidation + return nil + case <-timer.C: + // Lock TTL expired + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/internal/invalidation/handler_test.go b/internal/invalidation/handler_test.go new file mode 100644 index 0000000..9f26437 --- /dev/null +++ b/internal/invalidation/handler_test.go @@ -0,0 +1,315 @@ +//go:build integration + +package invalidation_test + +import ( + "context" + "testing" + "time" + + "github.com/redis/rueidis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache/internal/invalidation" +) + +// noopLogger is a no-op logger for testing +type noopLogger struct{} + +func (noopLogger) Debug(msg string, args ...any) {} +func (noopLogger) Error(msg string, args ...any) {} + +// makeHandler creates an invalidation handler for testing +func makeHandler(t *testing.T) invalidation.Handler { + t.Helper() + + return invalidation.NewRedisInvalidationHandler(invalidation.Config{ + LockTTL: 5 * time.Second, + Logger: noopLogger{}, + }) +} + +// TestHandler_Register tests single key registration +func TestHandler_Register(t *testing.T) { + handler := makeHandler(t) + + t.Run("returns channel for key", func(t *testing.T) { + ch := handler.Register("key1") + require.NotNil(t, ch) + + // Channel should be open initially + select { + case <-ch: + t.Fatal("Channel should not be closed yet") + default: + // Expected - channel is still open + } + }) + + t.Run("multiple registrations for same key share wait channel", func(t *testing.T) { + ch1 := handler.Register("key2") + ch2 := handler.Register("key2") + + assert.NotNil(t, ch1) + assert.NotNil(t, ch2) + // Both channels should be open initially + select { + case <-ch1: + t.Fatal("ch1 should not be closed yet") + default: + } + select { + case <-ch2: + t.Fatal("ch2 should not be closed yet") + default: + } + }) +} + +// TestHandler_RegisterAll tests batch key registration +func TestHandler_RegisterAll(t *testing.T) { + handler := makeHandler(t) + + t.Run("registers multiple keys", func(t *testing.T) { + keys := []string{"m1", "m2", "m3"} + keysSeq := func(yield func(string) bool) { + for _, k := range keys { + if !yield(k) { + return + } + } + } + + channels := handler.RegisterAll(keysSeq, len(keys)) + assert.Len(t, channels, 3) + + for _, key := range keys { + ch, exists := channels[key] + assert.True(t, exists) + assert.NotNil(t, ch) + } + }) + + t.Run("handles empty keys", func(t *testing.T) { + emptySeq := func(yield func(string) bool) { + // Empty iterator + } + + channels := handler.RegisterAll(emptySeq, 0) + assert.Empty(t, channels) + }) +} + +// TestHandler_OnInvalidate tests invalidation message processing +func TestHandler_OnInvalidate(t *testing.T) { + handler := makeHandler(t) + + t.Run("handles empty invalidation messages", func(t *testing.T) { + // Should not panic + handler.OnInvalidate(nil) + handler.OnInvalidate([]rueidis.RedisMessage{}) + }) + + // Note: We cannot easily test invalidation message processing without + // a real Redis connection, as RedisMessage is an opaque type from rueidis. + // The actual invalidation behavior is tested via integration tests in + // the main package (cacheaside_test.go, primeable_cacheaside_test.go). 
+} + +// TestWaitForSingleLock tests the wait helper function +func TestWaitForSingleLock(t *testing.T) { + t.Run("returns immediately when channel closed", func(t *testing.T) { + ch := make(chan struct{}) + close(ch) + + ctx := t.Context() + err := invalidation.WaitForSingleLock(ctx, ch, 5*time.Second) + assert.NoError(t, err) + }) + + t.Run("waits for channel close", func(t *testing.T) { + ch := make(chan struct{}) + + // Close after delay + go func() { + time.Sleep(50 * time.Millisecond) + close(ch) + }() + + ctx := t.Context() + start := time.Now() + err := invalidation.WaitForSingleLock(ctx, ch, 5*time.Second) + duration := time.Since(start) + + assert.NoError(t, err) + assert.Greater(t, duration, 40*time.Millisecond) + }) + + t.Run("returns on context cancellation", func(t *testing.T) { + ch := make(chan struct{}) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := invalidation.WaitForSingleLock(ctx, ch, 5*time.Second) + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("returns on lock TTL timeout", func(t *testing.T) { + ch := make(chan struct{}) + + ctx := t.Context() + err := invalidation.WaitForSingleLock(ctx, ch, 50*time.Millisecond) + assert.NoError(t, err) // TTL timeout is not an error, just a signal to retry + }) +} + +// TestHandler_TTLExpiration tests channel timeout behavior +func TestHandler_TTLExpiration(t *testing.T) { + // Use longer TTL for testing timeout (buffer is 20% with 200ms minimum) + shortHandler := invalidation.NewRedisInvalidationHandler(invalidation.Config{ + LockTTL: 100 * time.Millisecond, + Logger: noopLogger{}, + }) + + t.Run("channel closes on TTL timeout", func(t *testing.T) { + ch := shortHandler.Register("timeout1") + + // Wait for channel to close due to TTL + // TTL is 100ms, buffer is 200ms (minimum), so total is 300ms + start := time.Now() + select { + case <-ch: + duration := time.Since(start) + // Should close after TTL + buffer (100ms + 200ms = 300ms) + assert.Greater(t, duration, 250*time.Millisecond) + assert.Less(t, duration, 400*time.Millisecond) + case <-time.After(500 * time.Millisecond): + t.Fatal("Channel should have closed due to TTL timeout") + } + }) + + t.Run("multiple registrations share timeout", func(t *testing.T) { + ch1 := shortHandler.Register("timeout2") + // Give it a moment to establish the entry + time.Sleep(10 * time.Millisecond) + ch2 := shortHandler.Register("timeout2") + ch3 := shortHandler.Register("timeout2") + + // All should close at roughly the same time + timeout := time.After(500 * time.Millisecond) + + select { + case <-ch1: + // Expected + case <-timeout: + t.Fatal("ch1 should have closed") + } + + // ch2 and ch3 should close shortly after (or already be closed) + select { + case <-ch2: + // Expected + case <-time.After(50 * time.Millisecond): + t.Fatal("ch2 should have closed") + } + + select { + case <-ch3: + // Expected + case <-time.After(50 * time.Millisecond): + t.Fatal("ch3 should have closed") + } + }) +} + +// TestHandler_ConcurrentAccess tests thread safety +func TestHandler_ConcurrentAccess(t *testing.T) { + handler := makeHandler(t) + + t.Run("concurrent registrations", func(t *testing.T) { + done := make(chan bool) + + // Spawn multiple goroutines registering keys + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + handler.Register("concurrent") + } + done <- true + }(i) + } + + // Wait for all to complete + for i := 0; i < 10; i++ { + <-done + } + + // Should not panic or 
race + }) + + t.Run("concurrent register and onInvalidate", func(t *testing.T) { + done := make(chan bool) + + // Spawn registerers + go func() { + for i := 0; i < 50; i++ { + handler.Register("race") + time.Sleep(1 * time.Millisecond) + } + done <- true + }() + + // Spawn invalidators calling OnInvalidate with empty messages + go func() { + for i := 0; i < 50; i++ { + handler.OnInvalidate(nil) + time.Sleep(1 * time.Millisecond) + } + done <- true + }() + + // Wait for both + <-done + <-done + + // Should not panic or race + }) +} + +// TestHandler_IteratorConversion tests iter.Seq usage +func TestHandler_IteratorConversion(t *testing.T) { + handler := makeHandler(t) + + t.Run("slice to iterator conversion", func(t *testing.T) { + keys := []string{"s1", "s2", "s3"} + + // Convert slice to iter.Seq + keysSeq := func(yield func(string) bool) { + for _, k := range keys { + if !yield(k) { + return + } + } + } + + channels := handler.RegisterAll(keysSeq, len(keys)) + assert.Len(t, channels, 3) + }) + + t.Run("generator iterator", func(t *testing.T) { + // Create iterator that generates keys + genSeq := func(yield func(string) bool) { + for i := 0; i < 5; i++ { + if !yield("gen" + string(rune('0'+i))) { + return + } + } + } + + channels := handler.RegisterAll(genSeq, 5) + assert.Len(t, channels, 5) + }) +} diff --git a/internal/lockmanager/lockmanager.go b/internal/lockmanager/lockmanager.go new file mode 100644 index 0000000..5880880 --- /dev/null +++ b/internal/lockmanager/lockmanager.go @@ -0,0 +1,1014 @@ +// Package lockmanager provides distributed lock management for cache-aside operations. +// This package extracts lock-related responsibilities from CacheAside to follow the +// Single Responsibility Principle (SOLID). +package lockmanager + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redis/rueidis" + + "github.com/dcbickfo/redcache/internal/cmdx" + "github.com/dcbickfo/redcache/internal/invalidation" + "github.com/dcbickfo/redcache/internal/lockpool" + "github.com/dcbickfo/redcache/internal/logger" + "github.com/dcbickfo/redcache/internal/luascript" +) + +// WaitHandle represents a handle for waiting on lock release. +// It encapsulates the invalidation mechanism, hiding implementation details. +type WaitHandle interface { + // Wait blocks until the lock is released or timeout/context cancellation occurs. + Wait(ctx context.Context) error +} + +// invalidationWaitHandle implements WaitHandle using Redis invalidation notifications. +type invalidationWaitHandle struct { + waitChan <-chan struct{} + lockTTL time.Duration +} + +// Wait blocks until the lock is released via invalidation or timeout. +func (h *invalidationWaitHandle) Wait(ctx context.Context) error { + timer := time.NewTimer(h.lockTTL) + defer timer.Stop() + + select { + case <-h.waitChan: + // Lock released via invalidation + return nil + case <-timer.C: + // Lock TTL expired, safe to retry + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// LockManager defines the interface for distributed lock management. +// Implementations handle lock acquisition, release, and cleanup for cache-aside operations. +type LockManager interface { + // TryAcquire attempts to acquire a lock optimistically (pre-registration pattern). + // Pre-registers for invalidations BEFORE trying to acquire to avoid race conditions. 
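// (Registering first means that if the current lock holder finishes and the invalidation
// arrives between a failed SET NX and the subsequent wait, the notification is still
// captured rather than missed.)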
+ // + // Returns: + // - lockValue: non-empty if lock was acquired + // - retry: non-nil WaitHandle if lock contention occurred (caller should wait and retry) + // - err: real error (not lock contention) + // + // Usage pattern (CacheAside optimistic retry): + // retry: + // lockVal, waitHandle, err := lockMgr.TryAcquire(ctx, key) + // if lockVal != "" { /* got lock */ } + // if waitHandle != nil { + // waitHandle.Wait(ctx) + // goto retry + // } + TryAcquire(ctx context.Context, key string) (lockValue string, retry WaitHandle, err error) + + // TryAcquireMulti attempts to acquire locks for multiple keys optimistically. + // Pre-registers for invalidations before attempting acquisition. + // + // Returns: + // - acquired: map of successfully acquired locks + // - retry: map of keys that failed with wait handles for retry + // - err: critical error during acquisition + TryAcquireMulti(ctx context.Context, keys []string) (acquired map[string]string, retry map[string]WaitHandle, err error) + + // AcquireLockBlocking acquires a lock, blocking until successful. + // Handles all retry logic internally using invalidation notifications. + // + // This is the blocking pattern used by write operations that must eventually + // acquire the lock before proceeding. + // + // Returns: + // - lockValue: the acquired lock value + // - err: context cancellation or critical error + AcquireLockBlocking(ctx context.Context, key string) (lockValue string, err error) + + // AcquireMultiLocksBlocking acquires locks for multiple keys, blocking until successful. + // Returns a map of all successfully acquired locks. + // If context is cancelled, releases any acquired locks before returning. + AcquireMultiLocksBlocking(ctx context.Context, keys []string) (map[string]string, error) + + // ReleaseLock releases a previously acquired lock. + // The lock value must match to prevent releasing locks held by others. + ReleaseLock(ctx context.Context, key, lockValue string) error + + // ReleaseMultiLocks releases multiple locks in a batch operation. + ReleaseMultiLocks(ctx context.Context, lockValues map[string]string) + + // CleanupUnusedLocks releases locks that weren't used (e.g., partial failures). + // Takes the full set of acquired locks and the subset that were successfully used. + CleanupUnusedLocks(ctx context.Context, acquiredLocks, usedKeys map[string]string) + + // GenerateLockValue creates a unique lock identifier. + // Uses pooling for performance in high-throughput scenarios. + GenerateLockValue() string + + // CheckKeyLocked checks if a key currently has an active lock. + // Returns true if the key holds a lock value (not a real cached value). + CheckKeyLocked(ctx context.Context, key string) bool + + // CheckMultiKeysLocked checks multiple keys for active locks. + // Returns a slice of keys that currently have locks. + CheckMultiKeysLocked(ctx context.Context, keys []string) []string + + // IsLockValue checks if the given value is a lock (not a real cached value). + // This is useful when you already have the value and want to check if it's a lock. + IsLockValue(val string) bool + + // LockPrefix returns the lock prefix used by this manager. + // This is used by WriteLockManager to ensure cohesive lock checking. + LockPrefix() string + + // OnInvalidate processes Redis invalidation messages. + // This is called by the Redis client when invalidation messages arrive. + // Delegates to the internal invalidation handler. 
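+	//
+	// Illustrative wiring sketch only (assuming the rueidis ClientOption.OnInvalidations
+	// callback; the real call site is in the client construction code, not shown here):
+	//
+	//	var lockMgr LockManager
+	//	client, _ := rueidis.NewClient(rueidis.ClientOption{
+	//		InitAddress: []string{"127.0.0.1:6379"},
+	//		OnInvalidations: func(msgs []rueidis.RedisMessage) {
+	//			if lockMgr != nil {
+	//				lockMgr.OnInvalidate(msgs)
+	//			}
+	//		},
+	//	})
+	//	lockMgr = NewDistributedLockManager(Config{Client: client /* ... */})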
+ OnInvalidate(messages []rueidis.RedisMessage) + + // WaitForKey returns a WaitHandle for waiting on a key's invalidation. + // This is used when you need to wait for a lock to be released without acquiring it yourself. + // The caller should ensure they're subscribed to the key (via DoCache) before waiting. + WaitForKey(key string) WaitHandle + + // WaitForKeyWithSubscription registers for lock invalidations and subscribes to + // Redis client-side cache for the key in the correct order to avoid race conditions. + // + // This method combines two operations: + // 1. Register for invalidations (WaitForKey) - MUST happen first + // 2. Subscribe via DoCache and fetch current value + // + // This ordering ensures no invalidation messages are missed between registration + // and subscription. Returns the wait handle and the current value. + // + // Returns: + // - waitHandle: handle for waiting on key invalidation + // - currentValue: the current value from Redis (may be empty if key doesn't exist) + // - err: error from DoCache operation + WaitForKeyWithSubscription(ctx context.Context, key string, cacheTTL time.Duration) (waitHandle WaitHandle, currentValue string, err error) + + // WaitForKeyWithRetry waits for a key's lock to be released with periodic retry support. + // This subscribes to the key, waits for invalidation, but also returns on ticker timeout + // to allow the caller to retry lock acquisition. + // + // Returns nil if: + // - Lock was released (invalidation received) + // - Ticker fired (time to retry) + // - Context cancelled + WaitForKeyWithRetry(ctx context.Context, key string, ticker *time.Ticker) error + + // CommitReadLocks atomically replaces read lock values with real values using CAS. + // Only succeeds for keys where we still hold the exact lock value. + // This is for READ locks only - no backup/restore capability needed. + // + // Returns: + // - succeeded: slice of keys that were successfully committed + // - needsRetry: slice of keys that lost their locks (CAS failed) + // - err: critical error during execution + // + // Usage: After acquiring read locks and executing callback, use this to atomically + // replace the lock placeholders with actual values. If some keys lost their locks + // (e.g., due to invalidation or ForceSet), they will be returned in needsRetry. + CommitReadLocks(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (succeeded []string, needsRetry []string, err error) + + // CreateImmediateWaitHandle creates a WaitHandle that returns immediately without waiting. + // Used for non-contention errors or when invalidation already occurred. + // This allows callers to retry immediately rather than waiting for lock release. + CreateImmediateWaitHandle() WaitHandle +} + +// DistributedLockManager implements LockManager using Redis SET NX commands. +// This is a single-instance lock pattern suitable for cache coordination. +// +// Lock Safety Notes: +// - Suitable for: Cache coordination, preventing thundering herd, efficiency optimizations +// - NOT suitable for: Financial transactions, inventory management, correctness-critical operations +// - See DISTRIBUTED_LOCK_SAFETY.md for detailed safety analysis +type DistributedLockManager struct { + client rueidis.Client + lockTTL time.Duration + lockPrefix string + lockValPool *lockpool.Pool + logger Logger + invalidationHandler invalidation.Handler +} + +// Logger is the logging interface used by LockManager. 
+// This is a type alias for the shared logger interface. +type Logger = logger.Logger + +// Config holds configuration for creating a DistributedLockManager. +type Config struct { + Client rueidis.Client + LockTTL time.Duration + LockPrefix string + Logger Logger + InvalidationHandler invalidation.Handler +} + +// NewDistributedLockManager creates a new lock manager with the given configuration. +func NewDistributedLockManager(cfg Config) *DistributedLockManager { + return &DistributedLockManager{ + client: cfg.Client, + lockTTL: cfg.LockTTL, + lockPrefix: cfg.LockPrefix, + lockValPool: lockpool.New(cfg.LockPrefix, 10000), + logger: cfg.Logger, + invalidationHandler: cfg.InvalidationHandler, + } +} + +// TryAcquire attempts to acquire a lock optimistically (pre-registration pattern). +// Pre-registers for invalidations BEFORE trying to acquire to avoid race conditions. +// +// Returns: +// - lockValue: non-empty if lock was acquired +// - retry: non-nil WaitHandle if lock contention occurred (caller should wait and retry) +// - err: real error (not lock contention) +// +// Usage pattern (CacheAside optimistic retry): +// +// retry: +// lockVal, waitHandle, err := lockMgr.TryAcquire(ctx, key) +// if lockVal != "" { /* got lock */ } +// if waitHandle != nil { +// waitHandle.Wait(ctx) +// goto retry +// } +func (dlm *DistributedLockManager) TryAcquire(ctx context.Context, key string) (lockValue string, retry WaitHandle, err error) { + // Pre-register for invalidations BEFORE lock attempt to avoid race condition + // where lock is released between our failed attempt and registration + waitChan := dlm.invalidationHandler.Register(key) + + // Subscribe to key using DoCache to ensure we receive invalidation messages + _ = dlm.client.DoCache(ctx, dlm.client.B().Get().Key(key).Cache(), dlm.lockTTL) + + // Generate lock value + lockVal := dlm.GenerateLockValue() + + // Try to acquire lock: SET NX GET PX + err = dlm.client.Do( + ctx, + dlm.client.B().Set().Key(key).Value(lockVal).Nx().Get().Px(dlm.lockTTL).Build(), + ).Error() + + // Success: err == redis.Nil (key didn't exist, lock acquired) + if rueidis.IsRedisNil(err) { + dlm.logger.Debug("lock acquired", "key", key, "lockVal", lockVal) + return lockVal, nil, nil + } + + // Check if this is a Redis error vs lock contention + if err != nil { + dlm.logger.Debug("lock acquisition error", "key", key, "error", err) + return "", nil, fmt.Errorf("failed to acquire lock for key %q: %w", key, err) + } + + // Lock contention (err == nil means key exists) - return wait handle + dlm.logger.Debug("lock contention - failed to acquire lock", "key", key) + return "", &invalidationWaitHandle{waitChan: waitChan, lockTTL: dlm.lockTTL}, nil +} + +// TryAcquireMulti attempts to acquire locks for multiple keys optimistically. +// Pre-registers for invalidations before attempting acquisition. 
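+//
+// Illustrative caller sketch (error handling elided): use the acquired locks,
+// then wait on the retry handles before retrying the contended keys.
+//
+//	acquired, retry, err := lockMgr.TryAcquireMulti(ctx, keys)
+//	// ... work with acquired ...
+//	for _, handle := range retry {
+//		_ = handle.Wait(ctx) // released, TTL elapsed, or ctx cancelled
+//	}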
+// +// Returns: +// - acquired: map of successfully acquired locks +// - retry: map of keys that failed with wait handles for retry +// - err: critical error during acquisition +func (dlm *DistributedLockManager) TryAcquireMulti(ctx context.Context, keys []string) (acquired map[string]string, retry map[string]WaitHandle, err error) { + if len(keys) == 0 { + return nil, nil, nil + } + + // Check for context cancellation upfront + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } + + // Pre-register for all keys + waitChans := dlm.invalidationHandler.RegisterAll(func(yield func(string) bool) { + for _, k := range keys { + if !yield(k) { + return + } + } + }, len(keys)) + + // Subscribe to all keys using DoMultiCache + cmds := make([]rueidis.CacheableTTL, 0, len(keys)) + for _, key := range keys { + cmds = append(cmds, rueidis.CacheableTTL{ + Cmd: dlm.client.B().Get().Key(key).Cache(), + TTL: dlm.lockTTL, + }) + } + _ = dlm.client.DoMultiCache(ctx, cmds...) + + // Build and execute lock commands + lockVals, lockCmds := dlm.buildLockCommands(keys) + resps := dlm.client.DoMulti(ctx, lockCmds...) + + // Process responses + acquired = make(map[string]string) + retry = make(map[string]WaitHandle) + criticalErrors := 0 + + for i, r := range resps { + key := keys[i] + respErr := r.Error() + + // Success: err == redis.Nil (key didn't exist, SET succeeded) + if rueidis.IsRedisNil(respErr) { + acquired[key] = lockVals[key] + continue + } + + // Redis error (not just contention) + if respErr != nil { + dlm.logger.Debug("lock acquisition error in batch", "key", key, "error", respErr) + // Count critical errors (connection failures, context cancellation, etc.) + // but don't fail immediately - try to acquire what we can + criticalErrors++ + continue + } + + // Lock contention (err == nil means key exists) - add to retry map + dlm.logger.Debug("lock contention in batch", "key", key) + retry[key] = &invalidationWaitHandle{ + waitChan: waitChans[key], + lockTTL: dlm.lockTTL, + } + } + + // If ALL operations failed with errors, return an error + if criticalErrors > 0 && len(acquired) == 0 && len(retry) == 0 { + return nil, nil, fmt.Errorf("failed to acquire any locks: %d/%d keys had errors", criticalErrors, len(keys)) + } + + dlm.logger.Debug("acquired locks in batch", "acquired", len(acquired), "retry", len(retry), "requested", len(keys)) + return acquired, retry, nil +} + +// AcquireLockBlocking acquires a lock, blocking until successful. +// Handles all retry logic internally using invalidation notifications. +// +// This is the blocking pattern used by write operations that must eventually +// acquire the lock before proceeding. 
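+//
+// Illustrative call shape:
+//
+//	lockVal, err := lockMgr.AcquireLockBlocking(ctx, key)
+//	if err != nil {
+//		return err
+//	}
+//	defer func() { _ = lockMgr.ReleaseLock(ctx, key, lockVal) }()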
+// +// Returns: +// - lockValue: the acquired lock value +// - err: context cancellation or critical error +func (dlm *DistributedLockManager) AcquireLockBlocking(ctx context.Context, key string) (lockValue string, err error) { + for { + // Generate lock value + lockVal := dlm.GenerateLockValue() + + // Try to acquire lock + lockErr := dlm.client.Do( + ctx, + dlm.client.B().Set().Key(key).Value(lockVal).Nx().Get().Px(dlm.lockTTL).Build(), + ).Error() + + // Success: err == redis.Nil (key didn't exist, lock acquired) + if rueidis.IsRedisNil(lockErr) { + dlm.logger.Debug("lock acquired (blocking)", "key", key, "lockVal", lockVal) + return lockVal, nil + } + + // Check if this is a Redis error vs lock contention + if lockErr != nil { + dlm.logger.Debug("lock acquisition error (blocking)", "key", key, "error", lockErr) + return "", fmt.Errorf("failed to acquire lock for key %q: %w", key, lockErr) + } + + // Lock contention - register for invalidations (post-register pattern) + dlm.logger.Debug("lock contention (blocking) - waiting for release", "key", key) + waitChan := dlm.invalidationHandler.Register(key) + + // Subscribe to key using DoCache to ensure we receive invalidation messages + _ = dlm.client.DoCache(ctx, dlm.client.B().Get().Key(key).Cache(), dlm.lockTTL) + + // Wait for invalidation or timeout + timer := time.NewTimer(dlm.lockTTL) + select { + case <-ctx.Done(): + timer.Stop() + return "", ctx.Err() + case <-waitChan: + timer.Stop() + // Lock was released (invalidation event), retry + continue + case <-timer.C: + // Lock TTL expired, safe to retry + continue + } + } +} + +// AcquireMultiLocksBlocking acquires locks for multiple keys, blocking until successful. +// Returns a map of all successfully acquired locks. +// If context is cancelled, releases any acquired locks before returning. +func (dlm *DistributedLockManager) AcquireMultiLocksBlocking(ctx context.Context, keys []string) (map[string]string, error) { + if len(keys) == 0 { + return nil, nil + } + + acquired := make(map[string]string, len(keys)) + remaining := keys + + for len(remaining) > 0 { + // Check context before each attempt + if err := dlm.checkContextAndCleanup(ctx, acquired); err != nil { + return nil, err + } + + // Pre-register and subscribe to remaining keys + waitChans := dlm.registerAndSubscribe(ctx, remaining) + + // Try to acquire locks for remaining keys + newRemaining, needsWait := dlm.tryAcquireBatch(ctx, remaining, acquired) + remaining = newRemaining + + // If we acquired all locks, we're done + if len(remaining) == 0 { + dlm.logger.Debug("acquired all locks (blocking)", "count", len(acquired)) + return acquired, nil + } + + // Wait for any lock to be released if we had contention + if needsWait { + if err := dlm.waitForAnyLockRelease(ctx, remaining, waitChans, acquired); err != nil { + return nil, err + } + } + } + + dlm.logger.Debug("acquired all locks (blocking)", "count", len(acquired)) + return acquired, nil +} + +// checkContextAndCleanup checks if context is cancelled and releases acquired locks if so. +func (dlm *DistributedLockManager) checkContextAndCleanup(ctx context.Context, acquired map[string]string) error { + if ctx.Err() != nil { + if len(acquired) > 0 { + dlm.ReleaseMultiLocks(ctx, acquired) + } + return ctx.Err() + } + return nil +} + +// registerAndSubscribe pre-registers for invalidations and subscribes to all keys. 
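+// Registration happens before the DoCache subscription so that an invalidation
+// arriving between the two steps is not missed (the same ordering used by
+// TryAcquire and WaitForKeyWithSubscription).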
+func (dlm *DistributedLockManager) registerAndSubscribe(ctx context.Context, keys []string) map[string]<-chan struct{} {
+	// Pre-register for all remaining keys
+	waitChans := dlm.invalidationHandler.RegisterAll(func(yield func(string) bool) {
+		for _, k := range keys {
+			if !yield(k) {
+				return
+			}
+		}
+	}, len(keys))
+
+	// Subscribe to all remaining keys
+	cmds := make([]rueidis.CacheableTTL, 0, len(keys))
+	for _, key := range keys {
+		cmds = append(cmds, rueidis.CacheableTTL{
+			Cmd: dlm.client.B().Get().Key(key).Cache(),
+			TTL: dlm.lockTTL,
+		})
+	}
+	_ = dlm.client.DoMultiCache(ctx, cmds...)
+
+	return waitChans
+}
+
+// tryAcquireBatch attempts to acquire locks for all keys and returns the remaining keys and whether to wait.
+func (dlm *DistributedLockManager) tryAcquireBatch(ctx context.Context, keys []string, acquired map[string]string) (remaining []string, needsWait bool) {
+	// Try to acquire locks for remaining keys
+	lockVals, lockCmds := dlm.buildLockCommands(keys)
+	resps := dlm.client.DoMulti(ctx, lockCmds...)
+
+	// Process responses
+	newRemaining := make([]string, 0)
+
+	for i, r := range resps {
+		key := keys[i]
+		respErr := r.Error()
+
+		// Success: err == redis.Nil (key didn't exist, SET succeeded)
+		if rueidis.IsRedisNil(respErr) {
+			acquired[key] = lockVals[key]
+			continue
+		}
+
+		// Redis error (not just contention)
+		if respErr != nil {
+			dlm.logger.Debug("lock acquisition error in blocking batch", "key", key, "error", respErr)
+			newRemaining = append(newRemaining, key)
+			continue
+		}
+
+		// Lock contention - need to wait and retry
+		needsWait = true
+		newRemaining = append(newRemaining, key)
+	}
+
+	return newRemaining, needsWait
+}
+
+// waitForAnyLockRelease returns once a remaining lock has already been released, the lock TTL
+// elapses, or the context is cancelled. Invalidation channels are re-checked on the next loop
+// iteration of the caller.
+func (dlm *DistributedLockManager) waitForAnyLockRelease(
+	ctx context.Context,
+	remaining []string,
+	waitChans map[string]<-chan struct{},
+	acquired map[string]string,
+) error {
+	dlm.logger.Debug("waiting for locks (blocking)", "remaining", len(remaining))
+
+	// Collect the wait channels for the remaining keys
+	chans := make([]<-chan struct{}, 0, len(remaining))
+	for _, key := range remaining {
+		if ch, ok := waitChans[key]; ok {
+			chans = append(chans, ch)
+		}
+	}
+
+	// Fast path: a lock may already have been released, retry immediately
+	if dlm.checkAnyChannelReady(chans) {
+		return nil
+	}
+
+	// Otherwise block until the lock TTL elapses or the context is cancelled
+	timer := time.NewTimer(dlm.lockTTL)
+	defer timer.Stop()
+
+	select {
+	case <-ctx.Done():
+		if len(acquired) > 0 {
+			dlm.ReleaseMultiLocks(ctx, acquired)
+		}
+		return ctx.Err()
+	case <-timer.C:
+		// Timeout, retry all remaining
+		return nil
+	}
+}
+
+// checkAnyChannelReady checks if any of the channels are ready (non-blocking).
+func (dlm *DistributedLockManager) checkAnyChannelReady(channels []<-chan struct{}) bool {
+	for _, ch := range channels {
+		select {
+		case <-ch:
+			return true
+		default:
+		}
+	}
+	return false
+}
+
+// ReleaseLock releases a single lock using a Lua script to ensure atomicity.
+func (dlm *DistributedLockManager) ReleaseLock(ctx context.Context, key, lockValue string) error {
+	return delKeyLua.Exec(ctx, dlm.client, []string{key}, []string{lockValue}).Error()
+}
+
+// ReleaseMultiLocks releases multiple locks in a batch operation.
+func (dlm *DistributedLockManager) ReleaseMultiLocks(ctx context.Context, lockValues map[string]string) { + if len(lockValues) == 0 { + return + } + + for key, lockVal := range lockValues { + // Fire-and-forget: Don't wait for responses + // Locks will expire via TTL if deletion fails + if err := dlm.ReleaseLock(ctx, key, lockVal); err != nil { + dlm.logger.Debug("failed to release lock", "key", key, "error", err) + } + } +} + +// CleanupUnusedLocks releases locks that were acquired but not used. +// This handles partial failure scenarios where some operations succeed and others fail. +func (dlm *DistributedLockManager) CleanupUnusedLocks( + ctx context.Context, + acquiredLocks map[string]string, + usedKeys map[string]string, +) { + toUnlock := make(map[string]string) + + for key, lockVal := range acquiredLocks { + if _, used := usedKeys[key]; !used { + toUnlock[key] = lockVal + } + } + + if len(toUnlock) > 0 { + dlm.ReleaseMultiLocks(ctx, toUnlock) + } +} + +// GenerateLockValue creates a unique lock identifier using a pool for performance. +func (dlm *DistributedLockManager) GenerateLockValue() string { + // Use pool for better performance (~15% improvement in lock acquisition) + return dlm.lockValPool.Get() +} + +// CheckKeyLocked checks if a key currently has an active lock. +// Uses DoCache to subscribe to invalidations (consistent with CheckMultiKeysLocked). +func (dlm *DistributedLockManager) CheckKeyLocked(ctx context.Context, key string) bool { + resp := dlm.client.DoCache(ctx, dlm.client.B().Get().Key(key).Cache(), dlm.lockTTL) + val, err := resp.ToString() + if err != nil { + return false + } + return dlm.IsLockValue(val) +} + +// CheckMultiKeysLocked checks multiple keys for active locks. +// Returns keys that currently have locks. +// Uses DoMultiCache to ensure invalidation subscriptions for locked keys. +func (dlm *DistributedLockManager) CheckMultiKeysLocked(ctx context.Context, keys []string) []string { + if len(keys) == 0 { + return nil + } + + // Build cacheable commands to check for locks + // Use lockTTL to ensure subscription lasts long enough to receive invalidations + cmds := make([]rueidis.CacheableTTL, 0, len(keys)) + for _, key := range keys { + cmds = append(cmds, rueidis.CacheableTTL{ + Cmd: dlm.client.B().Get().Key(key).Cache(), + TTL: dlm.lockTTL, + }) + } + + lockedKeys := make([]string, 0) + resps := dlm.client.DoMultiCache(ctx, cmds...) + for i, resp := range resps { + val, err := resp.ToString() + if err == nil && dlm.IsLockValue(val) { + lockedKeys = append(lockedKeys, keys[i]) + } + } + + return lockedKeys +} + +// IsLockValue checks if the given value is a lock (has the lock prefix). +func (dlm *DistributedLockManager) IsLockValue(val string) bool { + return strings.HasPrefix(val, dlm.lockPrefix) +} + +// LockPrefix returns the lock prefix used by this manager. +// This is used by WriteLockManager to ensure cohesive lock checking. +func (dlm *DistributedLockManager) LockPrefix() string { + return dlm.lockPrefix +} + +// OnInvalidate processes Redis invalidation messages. +// Delegates to the internal invalidation handler. +func (dlm *DistributedLockManager) OnInvalidate(messages []rueidis.RedisMessage) { + dlm.invalidationHandler.OnInvalidate(messages) +} + +// WaitForKey returns a WaitHandle for waiting on a key's invalidation. +// This is used when you need to wait for a lock to be released without acquiring it yourself. 
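+// The caller must already be subscribed to the key (for example via DoCache, as
+// WaitForKeyWithSubscription does); otherwise the invalidation may never arrive
+// and the handle only returns on lock TTL expiry or context cancellation.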
+func (dlm *DistributedLockManager) WaitForKey(key string) WaitHandle {
+	waitChan := dlm.invalidationHandler.Register(key)
+	return &invalidationWaitHandle{
+		waitChan: waitChan,
+		lockTTL:  dlm.lockTTL,
+	}
+}
+
+// WaitForKeyWithSubscription implements LockManager.WaitForKeyWithSubscription.
+// Registers for invalidations first, then subscribes via DoCache and fetches the current value.
+func (dlm *DistributedLockManager) WaitForKeyWithSubscription(
+	ctx context.Context,
+	key string,
+	cacheTTL time.Duration,
+) (WaitHandle, string, error) {
+	// STEP 1: Register for invalidations FIRST (before DoCache subscription)
+	// This ensures we won't miss any invalidation that arrives between DoCache and registration
+	waitHandle := dlm.WaitForKey(key)
+
+	// STEP 2: Subscribe to Redis client-side cache AND fetch current value
+	// Since we already registered, we won't miss any invalidation messages
+	resp := dlm.client.DoCache(ctx, dlm.client.B().Get().Key(key).Cache(), cacheTTL)
+	val, err := resp.ToString()
+
+	return waitHandle, val, err
+}
+
+// WaitForKeyWithRetry waits for a key's lock to be released with periodic retry support.
+// Subscribes to the key via DoCache, then waits for either invalidation or ticker timeout.
+func (dlm *DistributedLockManager) WaitForKeyWithRetry(ctx context.Context, key string, ticker *time.Ticker) error {
+	// Get wait handle and subscribe to invalidations
+	waitHandle := dlm.WaitForKey(key)
+	_ = dlm.CheckKeyLocked(ctx, key) // Ensures subscription to invalidations
+
+	// Wait with ticker support for periodic retries
+	waitChan := make(chan struct{})
+	go func() {
+		_ = waitHandle.Wait(ctx)
+		close(waitChan)
+	}()
+
+	select {
+	case <-waitChan:
+		// Lock was released (invalidation event)
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-ticker.C:
+		// Periodic retry timeout
+		return nil
+	}
+}
+
+// WaitForAny waits for any of the provided WaitHandles to complete.
+// Returns when the first handle completes or context is cancelled.
+func WaitForAny(ctx context.Context, handles []WaitHandle) error {
+	if len(handles) == 0 {
+		return nil
+	}
+
+	// Buffered so every goroutine can deliver its completion signal without
+	// blocking, even after the first signal has been consumed and this function
+	// has returned. Closing the channel here would risk a panic from late senders.
+	done := make(chan struct{}, len(handles))
+
+	// Start goroutines for each handle
+	for _, h := range handles {
+		go func(handle WaitHandle) {
+			_ = handle.Wait(ctx)
+			done <- struct{}{}
+		}(h)
+	}
+
+	// Wait for first completion or context cancellation
+	select {
+	case <-done:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// WaitForAll waits for all provided WaitHandles to complete concurrently.
+// Returns when all handles complete or context is cancelled.
+// Handles are waited on concurrently to support cases where multiple locks are released
+// simultaneously via a single invalidation event.
+func WaitForAll(ctx context.Context, handles []WaitHandle) error {
+	if len(handles) == 0 {
+		return nil
+	}
+
+	// Channel to collect completion signals
+	done := make(chan error, len(handles))
+
+	// Launch goroutine for each handle
+	for _, h := range handles {
+		go func(handle WaitHandle) {
+			done <- handle.Wait(ctx)
+		}(h)
+	}
+
+	// Wait for all handles to complete or context cancellation
+	for i := 0; i < len(handles); i++ {
+		select {
+		case err := <-done:
+			if err != nil {
+				return err
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+
+	return nil
+}
+
+// buildLockCommands generates lock values and builds SET NX GET commands for the given keys.
+func (dlm *DistributedLockManager) buildLockCommands(keys []string) (map[string]string, rueidis.Commands) { + lockVals := make(map[string]string, len(keys)) + cmds := make(rueidis.Commands, 0, len(keys)) + + for _, k := range keys { + lockVal := dlm.GenerateLockValue() + lockVals[k] = lockVal + // SET NX GET returns the old value if key exists, or nil if SET succeeded + cmds = append(cmds, dlm.client.B().Set().Key(k).Value(lockVal).Nx().Get().Px(dlm.lockTTL).Build()) + } + + return lockVals, cmds +} + +// LockChecker checks if a key has an active lock. +type LockChecker interface { + // HasLock checks if the given value is a lock (not a real cached value). + HasLock(val string) bool + // CheckKeyLocked checks if a key currently has a lock in Redis. + CheckKeyLocked(ctx context.Context, client rueidis.Client, key string) bool +} + +// PrefixLockChecker checks locks based on a prefix. +type PrefixLockChecker struct { + Prefix string +} + +// HasLock checks if the value has the lock prefix. +func (p *PrefixLockChecker) HasLock(val string) bool { + return strings.HasPrefix(val, p.Prefix) +} + +// CheckKeyLocked checks if a key has an active lock. +func (p *PrefixLockChecker) CheckKeyLocked(ctx context.Context, client rueidis.Client, key string) bool { + resp := client.Do(ctx, client.B().Get().Key(key).Build()) + val, err := resp.ToString() + if err != nil { + return false + } + return p.HasLock(val) +} + +// BatchCheckLocks checks multiple keys for locks and returns those with locks. +// Uses DoMultiCache to ensure we're subscribed to invalidations for locked keys. +func BatchCheckLocks(ctx context.Context, client rueidis.Client, keys []string, checker LockChecker) []string { + if len(keys) == 0 { + return nil + } + + // Build cacheable commands to check for locks + cmds := make([]rueidis.CacheableTTL, 0, len(keys)) + for _, key := range keys { + cmds = append(cmds, rueidis.CacheableTTL{ + Cmd: client.B().Get().Key(key).Cache(), + TTL: time.Second, // Short TTL for lock checks + }) + } + + lockedKeys := make([]string, 0) + resps := client.DoMultiCache(ctx, cmds...) + for i, resp := range resps { + val, err := resp.ToString() + if err == nil && checker.HasLock(val) { + lockedKeys = append(lockedKeys, keys[i]) + } + } + + return lockedKeys +} + +// CommitReadLocks atomically replaces read lock values with real values using CAS. +func (dlm *DistributedLockManager) CommitReadLocks( + ctx context.Context, + ttl time.Duration, + lockValues map[string]string, + actualValues map[string]string, +) ([]string, []string, error) { + if len(lockValues) == 0 { + return nil, nil, nil + } + + // Group by slot for Redis Cluster compatibility + stmtsBySlot := dlm.groupCommitsBySlot(lockValues, actualValues, ttl) + + // Execute grouped statements and collect results + succeeded, needsRetry, err := dlm.executeCommitStatements(ctx, stmtsBySlot) + if err != nil { + return nil, nil, err + } + + return succeeded, needsRetry, nil +} + +// CreateImmediateWaitHandle creates a WaitHandle that returns immediately. +func (dlm *DistributedLockManager) CreateImmediateWaitHandle() WaitHandle { + return &immediateWaitHandle{} +} + +// slotCommitStatements holds commit statements grouped by slot. +type slotCommitStatements struct { + keyOrder []string + execStmts []rueidis.LuaExec +} + +// groupCommitsBySlot groups commit operations by Redis cluster slot. 
+func (dlm *DistributedLockManager) groupCommitsBySlot( + lockValues map[string]string, + actualValues map[string]string, + ttl time.Duration, +) map[uint16]slotCommitStatements { + estimatedSlots := len(lockValues) / 8 + if estimatedSlots < 1 { + estimatedSlots = 1 + } + stmts := make(map[uint16]slotCommitStatements, estimatedSlots) + + // Pre-calculate TTL string once + ttlStr := strconv.FormatInt(ttl.Milliseconds(), 10) + + for key, lockVal := range lockValues { + actualVal, ok := actualValues[key] + if !ok { + // Skip keys without actual values (shouldn't happen, but be defensive) + dlm.logger.Error("no actual value for key in CommitReadLocks", "key", key) + continue + } + + slot := cmdx.Slot(key) + stmt := stmts[slot] + + // Pre-allocate slices on first access to this slot + if stmt.keyOrder == nil { + estimatedKeysPerSlot := (len(lockValues) / estimatedSlots) + 1 + stmt.keyOrder = make([]string, 0, estimatedKeysPerSlot) + stmt.execStmts = make([]rueidis.LuaExec, 0, estimatedKeysPerSlot) + } + + stmt.keyOrder = append(stmt.keyOrder, key) + stmt.execStmts = append(stmt.execStmts, rueidis.LuaExec{ + Keys: []string{key}, + Args: []string{lockVal, actualVal, ttlStr}, + }) + stmts[slot] = stmt + } + + return stmts +} + +// executeCommitStatements executes commit statements and returns succeeded/failed keys. +func (dlm *DistributedLockManager) executeCommitStatements( + ctx context.Context, + stmtsBySlot map[uint16]slotCommitStatements, +) ([]string, []string, error) { + succeeded := make([]string, 0) + needsRetry := make([]string, 0) + + // Execute each slot's statements + for _, stmt := range stmtsBySlot { + setResps := commitReadLockScript.ExecMulti(ctx, dlm.client, stmt.execStmts...) + + // Process responses in order + for i, resp := range setResps { + key := stmt.keyOrder[i] + + // Check for Redis errors + if err := resp.Error(); err != nil && !rueidis.IsRedisNil(err) { + return nil, nil, fmt.Errorf("failed to commit lock for key %q: %w", key, err) + } + + // Check the Lua script return value (0 = lock lost) + returnValue, err := resp.AsInt64() + if err != nil || returnValue == 0 { + // Lock was lost for this key + needsRetry = append(needsRetry, key) + continue + } + + succeeded = append(succeeded, key) + } + } + + return succeeded, needsRetry, nil +} + +// immediateWaitHandle is a WaitHandle that returns immediately without waiting. +// Used for non-contention errors or when invalidation already occurred. +type immediateWaitHandle struct{} + +// Wait returns immediately without blocking. +func (i *immediateWaitHandle) Wait(_ context.Context) error { + return nil +} + +// Lua script for committing read locks (replacing lock with actual value). +var commitReadLockScript = luascript.New(` +if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("SET", KEYS[1], ARGV[2], "PX", ARGV[3]) + return 1 +else + return 0 +end +`) + +// Lua script for atomic lock release (only delete if value matches). 
+var delKeyLua = luascript.New(` +if redis.call("GET", KEYS[1]) == ARGV[1] then + return redis.call("DEL", KEYS[1]) +else + return 0 +end +`) diff --git a/internal/lockmanager/lockmanager_test.go b/internal/lockmanager/lockmanager_test.go new file mode 100644 index 0000000..d98d156 --- /dev/null +++ b/internal/lockmanager/lockmanager_test.go @@ -0,0 +1,465 @@ +//go:build integration + +package lockmanager_test + +import ( + "context" + "testing" + "time" + + "github.com/redis/rueidis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache/internal/invalidation" + "github.com/dcbickfo/redcache/internal/lockmanager" +) + +// noopLogger is a no-op logger for testing +type noopLogger struct{} + +func (noopLogger) Debug(msg string, args ...any) {} +func (noopLogger) Error(msg string, args ...any) {} + +// makeClient creates a Redis client for testing +func makeClient(t *testing.T) rueidis.Client { + t.Helper() + client, err := rueidis.NewClient(rueidis.ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + }) + require.NoError(t, err, "Failed to connect to Redis") + return client +} + +// makeLockManager creates a LockManager for testing +func makeLockManager(t *testing.T, client rueidis.Client) lockmanager.LockManager { + t.Helper() + + invHandler := invalidation.NewRedisInvalidationHandler(invalidation.Config{ + LockTTL: 5 * time.Second, + Logger: noopLogger{}, + }) + + return lockmanager.NewDistributedLockManager(lockmanager.Config{ + Client: client, + LockPrefix: "__test:lock:", + LockTTL: 5 * time.Second, + Logger: noopLogger{}, + InvalidationHandler: invHandler, + }) +} + +// TestLockManager_TryAcquire tests optimistic lock acquisition +func TestLockManager_TryAcquire(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("acquires lock for new key", func(t *testing.T) { + lockVal, retry, err := lm.TryAcquire(ctx, "key1") + require.NoError(t, err) + assert.NotEmpty(t, lockVal) + assert.Nil(t, retry) + + // Clean up + err = lm.ReleaseLock(ctx, "key1", lockVal) + assert.NoError(t, err) + }) + + t.Run("returns retry handle when lock exists", func(t *testing.T) { + // Acquire lock first + lockVal1, _, err := lm.TryAcquire(ctx, "key2") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, "key2", lockVal1) + + // Try to acquire again + lockVal2, retry, err := lm.TryAcquire(ctx, "key2") + require.NoError(t, err) + assert.Empty(t, lockVal2) + assert.NotNil(t, retry) + }) + + t.Run("context cancellation", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(t.Context()) + cancel() + + _, _, err := lm.TryAcquire(cancelCtx, "key3") + assert.ErrorIs(t, err, context.Canceled) + }) +} + +// TestLockManager_TryAcquireMulti tests batch optimistic lock acquisition +func TestLockManager_TryAcquireMulti(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("acquires locks for multiple keys", func(t *testing.T) { + keys := []string{"m1", "m2", "m3"} + + acquired, retry, err := lm.TryAcquireMulti(ctx, keys) + require.NoError(t, err) + assert.Len(t, acquired, 3) + assert.Empty(t, retry) + + // Clean up + lm.ReleaseMultiLocks(ctx, acquired) + }) + + t.Run("returns retry handles for contested keys", func(t *testing.T) { + keys := []string{"m4", "m5", "m6"} + + // Acquire some locks first + lock1, _, err := lm.TryAcquire(ctx, "m4") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, 
"m4", lock1) + + // Try to acquire all - should get retry for m4 + acquired, retry, err := lm.TryAcquireMulti(ctx, keys) + require.NoError(t, err) + assert.Len(t, acquired, 2) // m5, m6 + assert.Len(t, retry, 1) // m4 + assert.Contains(t, retry, "m4") + + // Clean up + lm.ReleaseMultiLocks(ctx, acquired) + }) +} + +// TestLockManager_AcquireLockBlocking tests blocking lock acquisition +func TestLockManager_AcquireLockBlocking(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("acquires lock immediately if available", func(t *testing.T) { + lockVal, err := lm.AcquireLockBlocking(ctx, "b1") + require.NoError(t, err) + assert.NotEmpty(t, lockVal) + + err = lm.ReleaseLock(ctx, "b1", lockVal) + assert.NoError(t, err) + }) + + t.Run("waits for lock release", func(t *testing.T) { + // Acquire lock first + lock1, err := lm.AcquireLockBlocking(ctx, "b2") + require.NoError(t, err) + + // Launch goroutine to release after delay + go func() { + time.Sleep(100 * time.Millisecond) + lm.ReleaseLock(context.Background(), "b2", lock1) + }() + + // Try to acquire - should wait and then succeed + start := time.Now() + lock2, err := lm.AcquireLockBlocking(ctx, "b2") + duration := time.Since(start) + + require.NoError(t, err) + assert.NotEmpty(t, lock2) + assert.Greater(t, duration, 50*time.Millisecond) + + err = lm.ReleaseLock(ctx, "b2", lock2) + assert.NoError(t, err) + }) + + t.Run("context cancellation", func(t *testing.T) { + cancelCtx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + + // Hold lock in another goroutine + lock1, err := lm.AcquireLockBlocking(context.Background(), "b3") + require.NoError(t, err) + defer lm.ReleaseLock(context.Background(), "b3", lock1) + + // Try to acquire with short timeout + _, err = lm.AcquireLockBlocking(cancelCtx, "b3") + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) +} + +// TestLockManager_AcquireMultiLocksBlocking tests batch blocking lock acquisition +func TestLockManager_AcquireMultiLocksBlocking(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("acquires all locks", func(t *testing.T) { + keys := []string{"mb1", "mb2", "mb3"} + + acquired, err := lm.AcquireMultiLocksBlocking(ctx, keys) + require.NoError(t, err) + assert.Len(t, acquired, 3) + + lm.ReleaseMultiLocks(ctx, acquired) + }) + + t.Run("waits for all locks to become available", func(t *testing.T) { + keys := []string{"mb4", "mb5"} + + // Hold one lock + lock1, err := lm.AcquireLockBlocking(ctx, "mb4") + require.NoError(t, err) + + // Release after delay + go func() { + time.Sleep(100 * time.Millisecond) + lm.ReleaseLock(context.Background(), "mb4", lock1) + }() + + // Try to acquire both - should wait + start := time.Now() + acquired, err := lm.AcquireMultiLocksBlocking(ctx, keys) + duration := time.Since(start) + + require.NoError(t, err) + assert.Len(t, acquired, 2) + assert.Greater(t, duration, 50*time.Millisecond) + + lm.ReleaseMultiLocks(ctx, acquired) + }) +} + +// TestLockManager_ReleaseLock tests lock release +func TestLockManager_ReleaseLock(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("releases owned lock", func(t *testing.T) { + lockVal, err := lm.AcquireLockBlocking(ctx, "r1") + require.NoError(t, err) + + err = lm.ReleaseLock(ctx, "r1", lockVal) + assert.NoError(t, err) + + // 
Verify lock is released + locked := lm.CheckKeyLocked(ctx, "r1") + assert.False(t, locked) + }) + + t.Run("fails to release with wrong lock value", func(t *testing.T) { + lockVal, err := lm.AcquireLockBlocking(ctx, "r2") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, "r2", lockVal) + + err = lm.ReleaseLock(ctx, "r2", "wrong-value") + assert.NoError(t, err) // Doesn't error, just doesn't release + + // Verify lock is still held + locked := lm.CheckKeyLocked(ctx, "r2") + assert.True(t, locked) + }) +} + +// TestLockManager_CheckKeyLocked tests lock status checking +func TestLockManager_CheckKeyLocked(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("returns false for unlocked key", func(t *testing.T) { + locked := lm.CheckKeyLocked(ctx, "c1") + assert.False(t, locked) + }) + + t.Run("returns true for locked key", func(t *testing.T) { + lockVal, err := lm.AcquireLockBlocking(ctx, "c2") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, "c2", lockVal) + + locked := lm.CheckKeyLocked(ctx, "c2") + assert.True(t, locked) + }) + + t.Run("returns false for real cached value", func(t *testing.T) { + // Set a real value (not a lock) + err := client.Do(ctx, client.B().Set().Key("c3").Value("real-value").Build()).Error() + require.NoError(t, err) + + locked := lm.CheckKeyLocked(ctx, "c3") + assert.False(t, locked) + }) +} + +// TestLockManager_CheckMultiKeysLocked tests batch lock checking +func TestLockManager_CheckMultiKeysLocked(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("identifies locked keys", func(t *testing.T) { + // Set up: lock some keys, leave others unlocked + lock1, err := lm.AcquireLockBlocking(ctx, "cm1") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, "cm1", lock1) + + lock3, err := lm.AcquireLockBlocking(ctx, "cm3") + require.NoError(t, err) + defer lm.ReleaseLock(ctx, "cm3", lock3) + + keys := []string{"cm1", "cm2", "cm3", "cm4"} + locked := lm.CheckMultiKeysLocked(ctx, keys) + + assert.Len(t, locked, 2) + assert.Contains(t, locked, "cm1") + assert.Contains(t, locked, "cm3") + }) +} + +// TestLockManager_CommitReadLocks tests CAS commit for read locks +func TestLockManager_CommitReadLocks(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("commits locks to real values", func(t *testing.T) { + // Acquire locks + lock1, err := lm.AcquireLockBlocking(ctx, "co1") + require.NoError(t, err) + + lock2, err := lm.AcquireLockBlocking(ctx, "co2") + require.NoError(t, err) + + lockValues := map[string]string{ + "co1": lock1, + "co2": lock2, + } + + actualValues := map[string]string{ + "co1": "value1", + "co2": "value2", + } + + succeeded, needsRetry, err := lm.CommitReadLocks(ctx, 10*time.Second, lockValues, actualValues) + require.NoError(t, err) + assert.Len(t, succeeded, 2) + assert.Empty(t, needsRetry) + + // Verify values are set + resp := client.Do(ctx, client.B().Get().Key("co1").Build()) + val, _ := resp.ToString() + assert.Equal(t, "value1", val) + }) + + t.Run("returns needsRetry if lock lost", func(t *testing.T) { + // Generate a lock value manually + lockVal := lm.GenerateLockValue() + + // Set the lock directly (simpler than AcquireLockBlocking which creates CSC subscriptions) + err := client.Do(ctx, client.B().Set().Key("co3").Value(lockVal).Px(5*time.Second).Build()).Error() + require.NoError(t, err) + + // 
Overwrite with a different value to simulate lock being stolen + err = client.Do(ctx, client.B().Set().Key("co3").Value("stolen").Build()).Error() + require.NoError(t, err) + + // Try to commit with the original lock value (should fail CAS) + lockValues := map[string]string{"co3": lockVal} + actualValues := map[string]string{"co3": "new-value"} + + succeeded, needsRetry, err := lm.CommitReadLocks(ctx, 10*time.Second, lockValues, actualValues) + require.NoError(t, err) + assert.Empty(t, succeeded) + assert.Len(t, needsRetry, 1) + assert.Contains(t, needsRetry, "co3") + }) +} + +// TestLockManager_CleanupUnusedLocks tests cleanup of unused locks +func TestLockManager_CleanupUnusedLocks(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + ctx := t.Context() + + t.Run("releases unused locks only", func(t *testing.T) { + // Acquire 3 locks + acquired := make(map[string]string) + for _, key := range []string{"cu1", "cu2", "cu3"} { + lock, err := lm.AcquireLockBlocking(ctx, key) + require.NoError(t, err) + acquired[key] = lock + } + + // Mark only 2 as used + used := map[string]string{ + "cu1": acquired["cu1"], + "cu2": acquired["cu2"], + } + + // Cleanup should release cu3 + lm.CleanupUnusedLocks(ctx, acquired, used) + + // Verify cu1 and cu2 still locked + assert.True(t, lm.CheckKeyLocked(ctx, "cu1")) + assert.True(t, lm.CheckKeyLocked(ctx, "cu2")) + + // Verify cu3 released + assert.False(t, lm.CheckKeyLocked(ctx, "cu3")) + + // Clean up remaining + lm.ReleaseLock(ctx, "cu1", acquired["cu1"]) + lm.ReleaseLock(ctx, "cu2", acquired["cu2"]) + }) +} + +// TestLockManager_IsLockValue tests lock value identification +func TestLockManager_IsLockValue(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + + t.Run("identifies lock values", func(t *testing.T) { + lockVal := lm.GenerateLockValue() + assert.True(t, lm.IsLockValue(lockVal)) + }) + + t.Run("rejects non-lock values", func(t *testing.T) { + assert.False(t, lm.IsLockValue("regular-value")) + assert.False(t, lm.IsLockValue("")) + }) +} + +// TestLockManager_GenerateLockValue tests lock value generation +func TestLockManager_GenerateLockValue(t *testing.T) { + client := makeClient(t) + defer client.Close() + + lm := makeLockManager(t, client) + + t.Run("generates unique lock values", func(t *testing.T) { + lock1 := lm.GenerateLockValue() + lock2 := lm.GenerateLockValue() + + assert.NotEmpty(t, lock1) + assert.NotEmpty(t, lock2) + assert.NotEqual(t, lock1, lock2) + }) + + t.Run("generated values are valid locks", func(t *testing.T) { + lockVal := lm.GenerateLockValue() + assert.True(t, lm.IsLockValue(lockVal)) + }) +} diff --git a/internal/lockutil/lockutil.go b/internal/lockutil/lockutil.go deleted file mode 100644 index 5483adb..0000000 --- a/internal/lockutil/lockutil.go +++ /dev/null @@ -1,83 +0,0 @@ -// Package lockutil provides shared utilities for lock management in cache-aside operations. -package lockutil - -import ( - "context" - "strings" - "time" - - "github.com/redis/rueidis" -) - -// LockChecker checks if a key has an active lock. -type LockChecker interface { - // HasLock checks if the given value is a lock (not a real cached value). - HasLock(val string) bool - // CheckKeyLocked checks if a key currently has a lock in Redis. - CheckKeyLocked(ctx context.Context, client rueidis.Client, key string) bool -} - -// PrefixLockChecker checks locks based on a prefix. 
-type PrefixLockChecker struct { - Prefix string -} - -// HasLock checks if the value has the lock prefix. -func (p *PrefixLockChecker) HasLock(val string) bool { - return strings.HasPrefix(val, p.Prefix) -} - -// CheckKeyLocked checks if a key has an active lock. -func (p *PrefixLockChecker) CheckKeyLocked(ctx context.Context, client rueidis.Client, key string) bool { - resp := client.Do(ctx, client.B().Get().Key(key).Build()) - val, err := resp.ToString() - if err != nil { - return false - } - return p.HasLock(val) -} - -// BatchCheckLocks checks multiple keys for locks and returns those with locks. -// Uses DoMultiCache to ensure we're subscribed to invalidations for locked keys. -func BatchCheckLocks(ctx context.Context, client rueidis.Client, keys []string, checker LockChecker) []string { - if len(keys) == 0 { - return nil - } - - // Build cacheable commands to check for locks - cmds := make([]rueidis.CacheableTTL, 0, len(keys)) - for _, key := range keys { - cmds = append(cmds, rueidis.CacheableTTL{ - Cmd: client.B().Get().Key(key).Cache(), - TTL: time.Second, // Short TTL for lock checks - }) - } - - lockedKeys := make([]string, 0) - resps := client.DoMultiCache(ctx, cmds...) - for i, resp := range resps { - val, err := resp.ToString() - if err == nil && checker.HasLock(val) { - lockedKeys = append(lockedKeys, keys[i]) - } - } - - return lockedKeys -} - -// WaitForSingleLock waits for a single lock to be released via invalidation or timeout. -func WaitForSingleLock(ctx context.Context, waitChan <-chan struct{}, lockTTL time.Duration) error { - timer := time.NewTimer(lockTTL) - defer timer.Stop() - - select { - case <-waitChan: - // Lock released via invalidation - return nil - case <-timer.C: - // Lock TTL expired - return nil - case <-ctx.Done(): - return ctx.Err() - } -} diff --git a/internal/lockutil/lockutil_test.go b/internal/lockutil/lockutil_test.go deleted file mode 100644 index 07096b9..0000000 --- a/internal/lockutil/lockutil_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package lockutil_test - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/dcbickfo/redcache/internal/lockutil" -) - -func TestPrefixLockChecker_HasLock(t *testing.T) { - checker := &lockutil.PrefixLockChecker{Prefix: "__lock:"} - - tests := []struct { - name string - value string - expected bool - }{ - {"lock value", "__lock:abc123", true}, - {"real value", "some-data", false}, - {"empty value", "", false}, - {"partial prefix", "__loc", false}, - {"prefix at end", "data__lock:", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := checker.HasLock(tt.value) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestWaitForSingleLock_ImmediateRelease(t *testing.T) { - ctx := context.Background() - waitChan := make(chan struct{}) - - // Close channel immediately to simulate lock release - close(waitChan) - - err := lockutil.WaitForSingleLock(ctx, waitChan, time.Second) - require.NoError(t, err) -} - -func TestWaitForSingleLock_Timeout(t *testing.T) { - ctx := context.Background() - waitChan := make(chan struct{}) - - // Don't close channel, let it timeout - err := lockutil.WaitForSingleLock(ctx, waitChan, 50*time.Millisecond) - require.NoError(t, err, "timeout should not return error") -} - -func TestWaitForSingleLock_ContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - waitChan := make(chan struct{}) - - // Cancel context immediately - 
cancel() - - err := lockutil.WaitForSingleLock(ctx, waitChan, time.Second) - require.Error(t, err) - assert.Equal(t, context.Canceled, err) -} - -func TestWaitForSingleLock_ContextTimeout(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - waitChan := make(chan struct{}) - - err := lockutil.WaitForSingleLock(ctx, waitChan, time.Second) - require.Error(t, err) - assert.Equal(t, context.DeadlineExceeded, err) -} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..b2c0761 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,14 @@ +// Package logger provides a common logging interface for all redcache components. +package logger + +// Logger defines the logging interface used throughout redcache. +// Implementations must be safe for concurrent use and should handle log levels internally. +type Logger interface { + // Error logs error messages. Should be used for unexpected failures or critical issues. + Error(msg string, args ...any) + + // Debug logs detailed diagnostic information useful for development and troubleshooting. + // Call Debug to record verbose output about internal state, cache operations, or lock handling. + // Debug messages should not include sensitive information and may be omitted in production. + Debug(msg string, args ...any) +} diff --git a/internal/luascript/script.go b/internal/luascript/script.go new file mode 100644 index 0000000..045dc92 --- /dev/null +++ b/internal/luascript/script.go @@ -0,0 +1,41 @@ +// Package luascript provides a common interface for Lua script execution in Redis. +package luascript + +import ( + "context" + + "github.com/redis/rueidis" +) + +// Executor defines the interface for executing Lua scripts in Redis. +// This abstraction allows for consistent script handling across the codebase. +type Executor interface { + // Exec executes the Lua script with the given keys and arguments. + Exec(ctx context.Context, client rueidis.Client, keys, args []string) rueidis.RedisResult + + // ExecMulti executes multiple instances of the script with different parameters. + ExecMulti(ctx context.Context, client rueidis.Client, statements ...rueidis.LuaExec) []rueidis.RedisResult +} + +// New creates a new Lua script executor that wraps rueidis.Lua. +// This provides a consistent interface for all Lua script operations. +func New(script string) Executor { + return &executor{ + script: rueidis.NewLuaScript(script), + } +} + +// executor wraps rueidis.Lua to implement the Executor interface. +type executor struct { + script *rueidis.Lua +} + +// Exec implements Executor. +func (e *executor) Exec(ctx context.Context, client rueidis.Client, keys, args []string) rueidis.RedisResult { + return e.script.Exec(ctx, client, keys, args) +} + +// ExecMulti implements Executor. +func (e *executor) ExecMulti(ctx context.Context, client rueidis.Client, statements ...rueidis.LuaExec) []rueidis.RedisResult { + return e.script.ExecMulti(ctx, client, statements...) +} diff --git a/internal/mapsx/mapsx.go b/internal/mapsx/mapsx.go index 035fc51..5ac43c2 100644 --- a/internal/mapsx/mapsx.go +++ b/internal/mapsx/mapsx.go @@ -28,13 +28,3 @@ func Values[K comparable, V any](m map[K]V) []V { } return values } - -// ToSet converts map keys to a set (map with bool values). -// This is useful for creating exclusion sets or membership tests. 
-func ToSet[K comparable, V any](m map[K]V) map[K]bool { - result := make(map[K]bool, len(m)) - for key := range m { - result[key] = true - } - return result -} diff --git a/internal/mapsx/mapsx_test.go b/internal/mapsx/mapsx_test.go index 70f624d..484b504 100644 --- a/internal/mapsx/mapsx_test.go +++ b/internal/mapsx/mapsx_test.go @@ -39,90 +39,44 @@ func TestKeys(t *testing.T) { }) } -func TestToSet(t *testing.T) { - t.Run("converts string map to set", func(t *testing.T) { - input := map[string]string{ - "key1": "value1", - "key2": "value2", - "key3": "value3", - } - - result := mapsx.ToSet(input) +func TestValues(t *testing.T) { + t.Run("empty map returns empty slice", func(t *testing.T) { + m := make(map[string]int) + values := mapsx.Values(m) + assert.Empty(t, values) + assert.NotNil(t, values) // Should return empty slice, not nil + }) - expected := map[string]bool{ - "key1": true, - "key2": true, - "key3": true, + t.Run("extracts all values from map", func(t *testing.T) { + m := map[string]int{ + "a": 1, + "b": 2, + "c": 3, } - assert.Equal(t, expected, result) + values := mapsx.Values(m) + assert.Len(t, values, 3) + assert.ElementsMatch(t, []int{1, 2, 3}, values) }) - t.Run("converts int map to set", func(t *testing.T) { - input := map[int]string{ + t.Run("works with different types", func(t *testing.T) { + m := map[int]string{ 1: "one", 2: "two", 3: "three", } - - result := mapsx.ToSet(input) - - expected := map[int]bool{ - 1: true, - 2: true, - 3: true, - } - assert.Equal(t, expected, result) - }) - - t.Run("returns empty set for empty map", func(t *testing.T) { - input := map[string]int{} - - result := mapsx.ToSet(input) - - assert.Empty(t, result) - }) - - t.Run("returns empty set for nil map", func(t *testing.T) { - var input map[string]string - - result := mapsx.ToSet(input) - - assert.Empty(t, result) - }) - - t.Run("works with struct values", func(t *testing.T) { - type Value struct { - Name string - Age int - } - input := map[string]Value{ - "alice": {Name: "Alice", Age: 30}, - "bob": {Name: "Bob", Age: 25}, - } - - result := mapsx.ToSet(input) - - expected := map[string]bool{ - "alice": true, - "bob": true, - } - assert.Equal(t, expected, result) + values := mapsx.Values(m) + assert.Len(t, values, 3) + assert.ElementsMatch(t, []string{"one", "two", "three"}, values) }) - t.Run("preserves all keys with different value types", func(t *testing.T) { - input := map[string]interface{}{ - "key1": "string", - "key2": 123, - "key3": true, - } - - result := mapsx.ToSet(input) - - expected := map[string]bool{ - "key1": true, - "key2": true, - "key3": true, + t.Run("handles duplicate values", func(t *testing.T) { + m := map[string]int{ + "a": 1, + "b": 1, + "c": 2, } - assert.Equal(t, expected, result) + values := mapsx.Values(m) + assert.Len(t, values, 3) + assert.ElementsMatch(t, []int{1, 1, 2}, values) }) } diff --git a/internal/syncx/map.go b/internal/syncx/map.go index 11872d4..385df91 100644 --- a/internal/syncx/map.go +++ b/internal/syncx/map.go @@ -6,6 +6,11 @@ type Map[K comparable, V any] struct { m sync.Map } +// NewMap creates a new generic sync.Map wrapper. 
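+// The zero value of Map is also ready to use, since it only wraps sync.Map;
+// NewMap exists for call sites that prefer an explicit constructor.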
+func NewMap[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{} +} + func (sm *Map[K, V]) CompareAndDelete(key K, old V) bool { return sm.m.CompareAndDelete(key, old) } diff --git a/internal/syncx/wait.go b/internal/syncx/wait.go deleted file mode 100644 index cdfacb6..0000000 --- a/internal/syncx/wait.go +++ /dev/null @@ -1,53 +0,0 @@ -package syncx - -import ( - "context" - "sync" -) - -// WaitForAll waits for all channels to close or for context cancellation. -// Uses goroutine-based channel merging which is ~100x faster than reflect.SelectCase. -// -// Implementation uses one goroutine per channel to forward close signals to a merged -// channel, avoiding reflection overhead entirely. -func WaitForAll[C ~<-chan V, V any](ctx context.Context, channels []C) error { - if len(channels) == 0 { - return nil - } - - // Merged channel that collects close signals from all input channels - done := make(chan struct{}, len(channels)) - var wg sync.WaitGroup - - // Launch one goroutine per channel to wait for close - for _, ch := range channels { - wg.Add(1) - go func(c C) { - defer wg.Done() - // Wait for channel to close - for range c { - // Drain channel until closed - } - // Signal completion - done <- struct{}{} - }(ch) - } - - // Close done channel when all goroutines complete - go func() { - wg.Wait() - close(done) - }() - - // Wait for all channels to close or context cancellation - for i := 0; i < len(channels); i++ { - select { - case <-done: - // One channel closed - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil -} diff --git a/internal/syncx/wait_test.go b/internal/syncx/wait_test.go deleted file mode 100644 index ea1c4a1..0000000 --- a/internal/syncx/wait_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package syncx_test - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/dcbickfo/redcache/internal/syncx" -) - -func delayedSend[T any](ch chan T, val T, delay time.Duration) { - go func() { - time.Sleep(delay) - ch <- val - close(ch) - }() -} - -func delayedClose[T any](ch chan T, delay time.Duration) { - go func() { - time.Sleep(delay) - close(ch) - }() -} - -func TestWaitForAll_Success(t *testing.T) { - ctx := context.Background() - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - - delayedSend(ch1, struct{}{}, 100*time.Millisecond) - delayedSend(ch2, struct{}{}, 200*time.Millisecond) - - waitLock := []<-chan struct{}{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.NoErrorf(t, err, "expected no error, got %v", err) -} - -func TestWaitForAll_SuccessClosed(t *testing.T) { - ctx := context.Background() - ch1 := make(chan struct{}) - ch2 := make(chan struct{}) - - delayedClose(ch1, 100*time.Millisecond) - delayedClose(ch2, 200*time.Millisecond) - - waitLock := []<-chan struct{}{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.NoErrorf(t, err, "expected no error, got %v", err) -} - -func TestWaitForAll_ContextCancelled(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) - defer cancel() - - ch1 := make(chan int) - ch2 := make(chan int) - - delayedSend(ch1, 1, 200*time.Millisecond) - delayedSend(ch2, 2, 300*time.Millisecond) - - waitLock := []<-chan int{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.ErrorIsf(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded, got %v", err) -} - -func TestWaitForAll_PartialCompleteContextCancelled(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 
200*time.Millisecond) - defer cancel() - - ch1 := make(chan int) - ch2 := make(chan int) - - delayedSend(ch1, 1, 100*time.Millisecond) - delayedSend(ch2, 2, 300*time.Millisecond) - - waitLock := []<-chan int{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.ErrorIsf(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded, got %v", err) -} - -func TestWaitForAll_NoChannels(t *testing.T) { - ctx := context.Background() - var waitLock []<-chan int - - err := syncx.WaitForAll(ctx, waitLock) - assert.NoErrorf(t, err, "expected no error, got %v", err) -} - -func TestWaitForAll_ImmediateContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - ch1 := make(chan int) - ch2 := make(chan int) - - waitLock := []<-chan int{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.ErrorIsf(t, err, context.Canceled, "expected context.Canceled, got %v", err) -} - -func TestWaitForAll_ChannelAlreadyClosed(t *testing.T) { - ctx := context.Background() - ch1 := make(chan int) - ch2 := make(chan int) - - close(ch1) - close(ch2) - - waitLock := []<-chan int{ch1, ch2} - - err := syncx.WaitForAll(ctx, waitLock) - assert.NoErrorf(t, err, "expected no error, got %v", err) -} diff --git a/internal/writelock/writelock.go b/internal/writelock/writelock.go new file mode 100644 index 0000000..627bd2c --- /dev/null +++ b/internal/writelock/writelock.go @@ -0,0 +1,760 @@ +// Package writelock provides write lock management for cache Set operations with CAS semantics. +// This package handles the complex lock acquisition, rollback, and slot grouping logic +// required for PrimeableCacheAside Set/SetMulti operations. +package writelock + +import ( + "context" + "fmt" + "sort" + "strconv" + "time" + + "github.com/redis/rueidis" + + "github.com/dcbickfo/redcache/internal/cmdx" + "github.com/dcbickfo/redcache/internal/errs" + "github.com/dcbickfo/redcache/internal/lockmanager" + "github.com/dcbickfo/redcache/internal/logger" + "github.com/dcbickfo/redcache/internal/luascript" +) + +// WriteLockManager defines the interface for write lock management with CAS semantics. +// Unlike simple read locks (LockManager), write locks need to: +// - Overwrite real values but not other locks +// - Support rollback via backup/restore +// - Handle partial acquisition failures. +type WriteLockManager interface { + // AcquireWriteLock acquires a single write lock with retry. + // Returns the lock value on success, or error if context cancelled. + AcquireWriteLock(ctx context.Context, key string) (lockValue string, err error) + + // AcquireMultiWriteLocks attempts to acquire write locks for multiple keys. + // Returns: + // - acquired: map of successfully locked keys to lock values + // - savedValues: map of keys to their previous values (for rollback) + // - failed: keys that couldn't be locked + // - error: critical errors (context cancellation, Redis errors) + AcquireMultiWriteLocks(ctx context.Context, keys []string) ( + acquired map[string]string, + savedValues map[string]string, + failed []string, + err error, + ) + + // ReleaseWriteLock releases a single write lock if it matches the expected value. + ReleaseWriteLock(ctx context.Context, key string, lockValue string) error + + // ReleaseWriteLocks releases multiple write locks. + ReleaseWriteLocks(ctx context.Context, lockValues map[string]string) + + // RestoreValues restores backed-up values (rollback on partial failure). 
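//
// A minimal caller-side sketch of the intended Set flow (illustrative only:
// loadValues is a hypothetical callback, key names and TTL are made up, and
// error handling is abbreviated):
//
//    keys := []string{"user:1", "user:2"}
//    ttl := time.Minute
//    locks, saved, err := wlm.AcquireMultiWriteLocksSequential(ctx, keys)
//    if err != nil {
//        return err
//    }
//    values, err := loadValues(ctx, keys)
//    if err != nil {
//        // Roll back: put previous values back where we still hold the lock,
//        // then drop any leftover lock placeholders.
//        wlm.RestoreValues(ctx, saved, locks)
//        wlm.ReleaseWriteLocks(ctx, locks)
//        return err
//    }
//    committed, failed := wlm.CommitWriteLocks(ctx, ttl, locks, values)
//    // failed holds keys whose lock was lost before commit (e.g. to ForceSet);
//    // committed maps the keys that now hold their real values.
//    _ = committed
//    _ = failed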
+ RestoreValues(ctx context.Context, savedValues map[string]string, lockValues map[string]string) + + // TouchLocks refreshes TTLs on multiple locks (for long operations). + TouchLocks(ctx context.Context, lockValues map[string]string) + + // CommitWriteLocks atomically replaces write lock values with real values using CAS. + // Only succeeds for keys where we still hold the exact lock value. + // This is for WRITE locks only - works with backup/restore mechanism. + // + // Returns: + // - succeeded: map of successfully committed keys to their values + // - failed: map of keys that failed with error details (lock lost or Redis errors) + // + // Usage: After acquiring write locks (with backup) and executing callback, use this + // to atomically replace the lock placeholders with actual values. If some keys lost + // their locks (e.g., due to ForceSet), they will be returned in failed map. + // The caller should use RestoreValues() for rollback if needed. + CommitWriteLocks(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (succeeded map[string]string, failed map[string]error) + + // AcquireMultiWriteLocksSequential acquires write locks for multiple keys using sequential + // acquisition with automatic deadlock prevention and rollback. + // + // This method implements a sophisticated lock acquisition strategy: + // 1. Sorts keys for consistent ordering (prevents deadlocks across processes) + // 2. Attempts batch acquisition of all remaining keys + // 3. On partial failure: + // - Keeps locks acquired sequentially (before first failure) + // - Restores values for out-of-order locks (after first failure) + // - Waits for the first failed key, then retries + // 4. Repeats until all locks acquired or context cancelled + // + // This prevents deadlocks by ensuring all processes wait for keys in the same order. + // It prevents cache misses by restoring original values when releasing early. + // + // Returns: + // - lockValues: map of keys to acquired lock values + // - savedValues: map of keys to their previous values (for rollback on error) + // - error: if context cancelled or critical Redis error + AcquireMultiWriteLocksSequential(ctx context.Context, keys []string) (lockValues map[string]string, savedValues map[string]string, err error) +} + +// CASWriteLockManager implements WriteLockManager with Compare-And-Swap semantics. +// It uses LockManager for cohesive lock value generation and prefix management. +type CASWriteLockManager struct { + client rueidis.Client + lockTTL time.Duration + lockManager lockmanager.LockManager + logger Logger +} + +// Logger is the logging interface used for write lock operations. +// This is a type alias for the shared logger interface. +type Logger = logger.Logger + +// Config holds configuration for creating a CASWriteLockManager. +type Config struct { + Client rueidis.Client + LockTTL time.Duration + LockManager lockmanager.LockManager // Required for cohesive lock generation and checking + Logger Logger +} + +// NewCASWriteLockManager creates a new write lock manager with CAS semantics. +// The LockManager is used to ensure consistent lock value generation and prefix checking +// across both read and write lock operations. +func NewCASWriteLockManager(cfg Config) *CASWriteLockManager { + return &CASWriteLockManager{ + client: cfg.Client, + lockTTL: cfg.LockTTL, + lockManager: cfg.LockManager, + logger: cfg.Logger, + } +} + +// AcquireWriteLock attempts to acquire a single write lock. 
+// Returns the lock value on success, or errs.ErrLockFailed if lock is held by another process. +// The caller is responsible for retry logic if needed. +func (wlm *CASWriteLockManager) AcquireWriteLock(ctx context.Context, key string) (string, error) { + lockVal := wlm.lockManager.GenerateLockValue() + + result := acquireWriteLockScript.Exec(ctx, wlm.client, + []string{key}, + []string{lockVal, strconv.FormatInt(wlm.lockTTL.Milliseconds(), 10), wlm.getLockPrefix()}) + + success, err := result.AsInt64() + if err != nil { + return "", fmt.Errorf("failed to execute write lock script for key %q: %w", key, err) + } + + if success != 1 { + wlm.logger.Debug("write lock contention", "key", key) + return "", fmt.Errorf("failed to acquire write lock for key %q: %w", key, errs.ErrLockFailed) + } + + wlm.logger.Debug("write lock acquired", "key", key, "lockVal", lockVal) + return lockVal, nil +} + +// AcquireMultiWriteLocks attempts to acquire write locks for multiple keys in batches. +func (wlm *CASWriteLockManager) AcquireMultiWriteLocks( + ctx context.Context, + keys []string, +) (map[string]string, map[string]string, []string, error) { + if len(keys) == 0 { + return make(map[string]string), make(map[string]string), nil, nil + } + + acquired := make(map[string]string) + savedValues := make(map[string]string) + failed := make([]string, 0) + + // Group by slot and build Lua exec statements + stmtsBySlot := wlm.groupLockAcquisitionsBySlot(keys) + + // Execute all slots and collect results + for _, stmts := range stmtsBySlot { + // Batch execute all lock acquisitions for this slot + resps := acquireWriteLockWithBackupScript.ExecMulti(ctx, wlm.client, stmts.execStmts...) + + // Process responses in order + for i, resp := range resps { + key := stmts.keyOrder[i] + lockVal := stmts.lockVals[i] + + result, err := processLockAcquisitionResponse(resp, key, lockVal) + if err != nil { + return nil, nil, nil, err + } + + if result.acquired { + acquired[key] = result.lockValue + if result.hasSaved { + savedValues[key] = result.savedValue + } + } else { + failed = append(failed, key) + } + } + } + + return acquired, savedValues, failed, nil +} + +// ReleaseWriteLock releases a single write lock if it matches the expected value. +func (wlm *CASWriteLockManager) ReleaseWriteLock(ctx context.Context, key string, lockValue string) error { + result := unlockKeyScript.Exec(ctx, wlm.client, []string{key}, []string{lockValue}) + if err := result.Error(); err != nil { + wlm.logger.Debug("failed to release write lock", "key", key, "error", err) + return err + } + return nil +} + +// ReleaseWriteLocks releases multiple write locks. +func (wlm *CASWriteLockManager) ReleaseWriteLocks(ctx context.Context, lockValues map[string]string) { + if len(lockValues) == 0 { + return + } + + for key, lockVal := range lockValues { + result := unlockKeyScript.Exec(ctx, wlm.client, []string{key}, []string{lockVal}) + if err := result.Error(); err != nil { + wlm.logger.Debug("failed to release write lock", "key", key, "error", err) + } + } +} + +// RestoreValues restores backed-up values for keys (rollback on partial failure). 
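//
// Contention handling is left to the caller; a minimal retry sketch using the
// package's own errs.ErrLockFailed and lockRetryInterval (the polling loop
// itself is illustrative, not part of this patch):
//
//    var lockVal string
//    for {
//        v, err := wlm.AcquireWriteLock(ctx, key)
//        if err == nil {
//            lockVal = v
//            break
//        }
//        if !errors.Is(err, errs.ErrLockFailed) {
//            return err // Redis error or context cancellation
//        }
//        select {
//        case <-ctx.Done():
//            return ctx.Err()
//        case <-time.After(lockRetryInterval):
//        }
//    }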
+func (wlm *CASWriteLockManager) RestoreValues( + ctx context.Context, + savedValues map[string]string, + lockValues map[string]string, +) { + if len(savedValues) == 0 { + return + } + + // Restore values using Lua script (CAS: only if our lock is still there) + for key, val := range savedValues { + if lockVal, hasLock := lockValues[key]; hasLock { + result := setIfLockScript.Exec(ctx, wlm.client, []string{key}, []string{lockVal, val}) + if err := result.Error(); err != nil { + wlm.logger.Debug("failed to restore value during rollback", "key", key, "error", err) + } + } + } +} + +// TouchLocks refreshes TTLs on multiple locks. +func (wlm *CASWriteLockManager) TouchLocks(ctx context.Context, lockValues map[string]string) { + if len(lockValues) == 0 { + return + } + + cmds := make(rueidis.Commands, 0, len(lockValues)) + for key := range lockValues { + cmds = append(cmds, wlm.client.B().Expire().Key(key).Seconds(int64(wlm.lockTTL.Seconds())).Build()) + } + + resps := wlm.client.DoMulti(ctx, cmds...) + for _, resp := range resps { + if err := resp.Error(); err != nil { + wlm.logger.Debug("failed to touch lock", "error", err) + } + } +} + +// slotLockStatements holds lock acquisition statements grouped by slot. +type slotLockStatements struct { + keyOrder []string + lockVals []string + execStmts []rueidis.LuaExec +} + +// groupLockAcquisitionsBySlot groups lock acquisition operations by Redis cluster slot. +func (wlm *CASWriteLockManager) groupLockAcquisitionsBySlot(keys []string) map[uint16]slotLockStatements { + if len(keys) == 0 { + return nil + } + + estimatedSlots, estimatedPerSlot := cmdx.EstimateSlotDistribution(len(keys)) + stmtsBySlot := make(map[uint16]slotLockStatements, estimatedSlots) + + lockTTLStr := strconv.FormatInt(wlm.lockTTL.Milliseconds(), 10) + + for _, key := range keys { + lockVal := wlm.lockManager.GenerateLockValue() + slot := cmdx.Slot(key) + stmts := stmtsBySlot[slot] + + if stmts.keyOrder == nil { + stmts.keyOrder = make([]string, 0, estimatedPerSlot) + stmts.lockVals = make([]string, 0, estimatedPerSlot) + stmts.execStmts = make([]rueidis.LuaExec, 0, estimatedPerSlot) + } + + stmts.keyOrder = append(stmts.keyOrder, key) + stmts.lockVals = append(stmts.lockVals, lockVal) + stmts.execStmts = append(stmts.execStmts, rueidis.LuaExec{ + Keys: []string{key}, + Args: []string{lockVal, lockTTLStr, wlm.getLockPrefix()}, + }) + + stmtsBySlot[slot] = stmts + } + + return stmtsBySlot +} + +// lockAcquisitionResult holds the result of a lock acquisition attempt. +type lockAcquisitionResult struct { + acquired bool + lockValue string + hasSaved bool + savedValue string +} + +// processLockAcquisitionResponse processes the response from acquireWriteLockWithBackupScript. 
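//
// The backup-acquisition script (defined further below) replies with a
// two-element array, which can take three shapes:
//
//    {1, false}       acquired, key was empty, nothing to back up
//    {1, <previous>}  acquired, previous real value saved for rollback
//    {0, <lock>}      not acquired, key currently holds another client's lock
//
// Lua false is converted to a RESP nil, which is why a failed ToString() on
// the second element is treated as "no saved value" rather than an error.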
+func processLockAcquisitionResponse(resp rueidis.RedisResult, key, lockVal string) (lockAcquisitionResult, error) { + arr, err := resp.ToArray() + if err != nil { + return lockAcquisitionResult{}, fmt.Errorf("failed to parse lock response for key %q: %w", key, err) + } + + if len(arr) != 2 { + return lockAcquisitionResult{}, fmt.Errorf("unexpected lock response length for key %q: got %d, want 2", key, len(arr)) + } + + success, err := arr[0].AsInt64() + if err != nil { + return lockAcquisitionResult{}, fmt.Errorf("failed to parse lock success for key %q: %w", key, err) + } + + if success != 1 { + return lockAcquisitionResult{acquired: false}, nil + } + + // Successfully acquired - check if we have a saved value + savedVal, err := arr[1].ToString() + if err != nil { + // No saved value (was false in Lua) + return lockAcquisitionResult{ + acquired: true, + lockValue: lockVal, + hasSaved: false, + }, nil + } + + return lockAcquisitionResult{ + acquired: true, + lockValue: lockVal, + hasSaved: true, + savedValue: savedVal, + }, nil +} + +// CommitWriteLocks atomically replaces write lock values with real values using CAS. +func (wlm *CASWriteLockManager) CommitWriteLocks( + ctx context.Context, + ttl time.Duration, + lockValues map[string]string, + actualValues map[string]string, +) (map[string]string, map[string]error) { + if len(lockValues) == 0 { + return make(map[string]string), make(map[string]error) + } + + succeeded := make(map[string]string) + failed := make(map[string]error) + + // Group by slot for Redis Cluster compatibility + stmtsBySlot := wlm.groupCommitsBySlot(lockValues, actualValues, ttl) + + // Execute each slot's statements + for slot, stmt := range stmtsBySlot { + setResps := commitWriteLockScript.ExecMulti(ctx, wlm.client, stmt.execStmts...) + + // Process responses in order + for i, resp := range setResps { + key := stmt.keyOrder[i] + value := actualValues[key] + + // Check for Redis errors + if err := resp.Error(); err != nil { + wlm.logger.Debug("commit write lock failed for key", "key", key, "slot", slot, "error", err) + failed[key] = fmt.Errorf("failed to commit write lock: %w", err) + continue + } + + // Check the Lua script return value (0 = lock lost) + setSuccess, err := resp.AsInt64() + if err != nil || setSuccess == 0 { + wlm.logger.Debug("commit write lock CAS failed for key", "key", key, "slot", slot) + failed[key] = fmt.Errorf("%w", errs.ErrLockLost) + continue + } + + succeeded[key] = value + } + } + + // Populate client-side cache for all successfully set values + // This ensures other clients can see the values via CSC + // Critical for empty strings and all cached values + if len(succeeded) > 0 { + cacheCommands := make([]rueidis.CacheableTTL, 0, len(succeeded)) + for key := range succeeded { + cacheCommands = append(cacheCommands, rueidis.CacheableTTL{ + Cmd: wlm.client.B().Get().Key(key).Cache(), + TTL: ttl, + }) + } + _ = wlm.client.DoMultiCache(ctx, cacheCommands...) + } + + return succeeded, failed +} + +// slotCommitStatements holds commit statements grouped by slot. +type slotCommitStatements struct { + keyOrder []string + execStmts []rueidis.LuaExec +} + +// groupCommitsBySlot groups commit operations by Redis cluster slot. 
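//
// On Redis Cluster, keys hash to one of 16384 slots and different slots may be
// owned by different nodes, so the per-key commit scripts are grouped by slot
// and each group is sent as one ExecMulti batch. A rough illustration (slot
// numbers are made up):
//
//    cmdx.Slot("user:1") // e.g. 8106
//    cmdx.Slot("user:2") // e.g. 12539 -> lands in a different batch
//    cmdx.Slot("{u}:1")  // a {hash tag} forces keys into the same slot
//    cmdx.Slot("{u}:2")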
+func (wlm *CASWriteLockManager) groupCommitsBySlot( + lockValues map[string]string, + actualValues map[string]string, + ttl time.Duration, +) map[uint16]slotCommitStatements { + estimatedSlots, estimatedPerSlot := cmdx.EstimateSlotDistribution(len(lockValues)) + stmtsBySlot := make(map[uint16]slotCommitStatements, estimatedSlots) + + // Pre-calculate TTL string once + ttlStr := strconv.FormatInt(ttl.Milliseconds(), 10) + + for key, lockVal := range lockValues { + actualVal, ok := actualValues[key] + if !ok { + // Skip keys without actual values (shouldn't happen, but be defensive) + wlm.logger.Error("no actual value for key in CommitWriteLocks", "key", key) + continue + } + + slot := cmdx.Slot(key) + stmt := stmtsBySlot[slot] + + // Pre-allocate slices on first access to this slot + if stmt.keyOrder == nil { + stmt.keyOrder = make([]string, 0, estimatedPerSlot) + stmt.execStmts = make([]rueidis.LuaExec, 0, estimatedPerSlot) + } + + stmt.keyOrder = append(stmt.keyOrder, key) + stmt.execStmts = append(stmt.execStmts, rueidis.LuaExec{ + Keys: []string{key}, + Args: []string{actualVal, ttlStr, lockVal}, + }) + stmtsBySlot[slot] = stmt + } + + return stmtsBySlot +} + +// Lua scripts for write lock operations. +var ( + acquireWriteLockScript = luascript.New(` + local key = KEYS[1] + local lock_value = ARGV[1] + local ttl = ARGV[2] + local lock_prefix = ARGV[3] + + local current = redis.call("GET", key) + + -- If key is empty, we can set our lock + if current == false then + redis.call("SET", key, lock_value, "PX", ttl) + return 1 + end + + -- If current value is a lock (has lock prefix), we cannot acquire + if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then + return 0 + end + + -- Current value is a real value (not a lock), we can overwrite with our lock + redis.call("SET", key, lock_value, "PX", ttl) + return 1 + `) + + acquireWriteLockWithBackupScript = luascript.New(` + local key = KEYS[1] + local lock_value = ARGV[1] + local ttl = ARGV[2] + local lock_prefix = ARGV[3] + + local current = redis.call("GET", key) + + -- If key is empty, acquire and return false (nothing to restore) + if current == false then + redis.call("SET", key, lock_value, "PX", ttl) + return {1, false} + end + + -- If current value is a lock, cannot acquire + if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then + return {0, current} + end + + -- Current value is real, acquire and save it for rollback + redis.call("SET", key, lock_value, "PX", ttl) + return {1, current} + `) + + unlockKeyScript = luascript.New(` + local key = KEYS[1] + local expected = ARGV[1] + if redis.call("GET", key) == expected then + return redis.call("DEL", key) + else + return 0 + end + `) + + setIfLockScript = luascript.New(` + local key = KEYS[1] + local expected_lock = ARGV[1] + local new_value = ARGV[2] + if redis.call("GET", key) == expected_lock then + return redis.call("SET", key, new_value) + else + return 0 + end + `) + + commitWriteLockScript = luascript.New(` + local key = KEYS[1] + local value = ARGV[1] + local ttl = ARGV[2] + local expected_lock = ARGV[3] + + local current = redis.call("GET", key) + + -- STRICT CAS: We can ONLY set if we still hold our exact lock value + -- This prevents Set from overwriting ForceSet values that stole our lock + if current == expected_lock then + redis.call("SET", key, value, "PX", ttl) + return 1 + else + return 0 + end + `) +) + +// getLockPrefix returns the lock prefix from the LockManager. 
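//
// Lock placeholders are ordinary string values that start with this prefix, so
// both the Lua scripts (string.sub(current, 1, string.len(prefix)) == prefix)
// and the Go side (lockManager.IsLockValue) can tell a lock apart from a real
// cached value. Illustrative values only; the actual prefix comes from the
// configured LockManager:
//
//    prefix := wlm.getLockPrefix()           // e.g. "__redcache:lock:"
//    isLock := strings.HasPrefix(val, prefix)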
+// This ensures WriteLockManager and LockManager use the same prefix for cohesive lock checking. +func (wlm *CASWriteLockManager) getLockPrefix() string { + return wlm.lockManager.LockPrefix() +} + +const ( + // lockRetryInterval is the interval for periodic lock acquisition retries. + // Used during sequential acquisition when waiting for contested locks. + lockRetryInterval = 50 * time.Millisecond +) + +// AcquireMultiWriteLocksSequential acquires write locks for multiple keys using sequential +// acquisition with automatic deadlock prevention and rollback. +// See WriteLockManager interface documentation for full details. +func (wlm *CASWriteLockManager) AcquireMultiWriteLocksSequential(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { + if len(keys) == 0 { + return make(map[string]string), make(map[string]string), nil + } + + // Sort keys for consistent lock ordering (prevents deadlocks) + sortedKeys := wlm.sortKeys(keys) + lockValues := make(map[string]string) + allSavedValues := make(map[string]string) // Track all saved values for final restoration + ticker := time.NewTicker(lockRetryInterval) + defer ticker.Stop() + + remainingKeys := sortedKeys + for len(remainingKeys) > 0 { + done, savedValues, err := wlm.tryAcquireBatchAndProcess(ctx, sortedKeys, remainingKeys, lockValues, ticker) + if err != nil { + return nil, nil, err + } + // Merge saved values from this iteration + for k, v := range savedValues { + allSavedValues[k] = v + } + if done { + return lockValues, allSavedValues, nil + } + // Update remaining keys for next iteration + remainingKeys = wlm.keysNotIn(sortedKeys, lockValues) + } + + return lockValues, allSavedValues, nil +} + +// tryAcquireBatchAndProcess attempts to acquire locks for remaining keys and processes the result. +// Returns: +// - done: true if all locks were acquired +// - savedValues: map of keys to their previous values from this batch +// - error: if acquisition failed +func (wlm *CASWriteLockManager) tryAcquireBatchAndProcess( + ctx context.Context, + sortedKeys []string, + remainingKeys []string, + lockValues map[string]string, + ticker *time.Ticker, +) (bool, map[string]string, error) { + // Try to batch-acquire all remaining keys with backup + acquired, savedValues, failed, err := wlm.AcquireMultiWriteLocks(ctx, remainingKeys) + if err != nil { + // Critical error - release all locks we've acquired so far + wlm.ReleaseWriteLocks(ctx, lockValues) + return false, nil, err + } + + // Success: All remaining keys acquired + if len(failed) == 0 { + // Add newly acquired locks to our collection + for k, v := range acquired { + lockValues[k] = v + } + wlm.logger.Debug("AcquireMultiWriteLocksSequential completed", "keys", sortedKeys, "count", len(lockValues)) + return true, savedValues, nil + } + + // Handle partial failure + err = wlm.handlePartialLockFailure(ctx, remainingKeys, acquired, savedValues, failed, lockValues, ticker) + // Return saved values from keys we kept (sequential ones) + keptSavedValues := make(map[string]string) + for k := range lockValues { + if v, ok := savedValues[k]; ok { + keptSavedValues[k] = v + } + } + return false, keptSavedValues, err +} + +// handlePartialLockFailure processes a partial lock acquisition failure. +// It keeps locks acquired in sequential order and restores out-of-order locks. 
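//
// A concrete walk-through (keys and timing are illustrative): suppose keys
// {"a", "b", "c"} are requested while another client already holds "b".
//
//    1. The batch attempt tries a, b, c: a and c succeed, b fails.
//    2. "a" precedes the first failure, so its lock is kept; "c" follows the
//       gap, so its previous value is restored rather than leaving a stale
//       lock placeholder behind while we wait.
//    3. The kept locks get their TTLs refreshed, then we block on "b" via
//       WaitForKeyWithRetry.
//    4. Once "b" frees up, the next iteration re-acquires b and c.
//
// Because every client waits on the lowest contested key in sorted order, two
// clients locking overlapping key sets cannot deadlock on each other.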
+func (wlm *CASWriteLockManager) handlePartialLockFailure(
+	ctx context.Context,
+	remainingKeys []string,
+	acquired map[string]string,
+	savedValues map[string]string,
+	failed []string,
+	lockValues map[string]string,
+	ticker *time.Ticker,
+) error {
+	wlm.logger.Debug("partial acquisition, analyzing sequential locks",
+		"acquired_this_batch", len(acquired),
+		"failed", len(failed),
+		"total_acquired_so_far", len(lockValues))
+
+	// Find the first failed key in sorted order
+	firstFailedKey := wlm.findFirstKey(remainingKeys, failed)
+
+	// Determine which acquired keys to keep vs restore
+	toKeep, toRestore := wlm.splitAcquiredBySequence(remainingKeys, acquired, firstFailedKey)
+
+	// Restore keys that were acquired out of order. Argument order matters:
+	// savedValues carries the previous values and toRestore carries the locks
+	// we still hold for those keys.
+	if len(toRestore) > 0 {
+		wlm.logger.Debug("restoring out-of-order locks",
+			"restore_count", len(toRestore),
+			"first_failed", firstFailedKey)
+		wlm.RestoreValues(ctx, savedValues, toRestore)
+	}
+
+	// Add sequential locks to our permanent collection
+	for k, v := range toKeep {
+		lockValues[k] = v
+	}
+
+	// Touch/refresh TTL on all locks we're keeping to prevent expiration
+	if len(lockValues) > 0 {
+		wlm.TouchLocks(ctx, lockValues)
+	}
+
+	// Wait for the first failed key to be released
+	err := wlm.lockManager.WaitForKeyWithRetry(ctx, firstFailedKey, ticker)
+	if err != nil {
+		// Context cancelled or timeout - release all locks
+		wlm.ReleaseWriteLocks(ctx, lockValues)
+		return err
+	}
+
+	return nil
+}
+
+// sortKeys creates a sorted copy of the keys to ensure consistent lock ordering.
+func (wlm *CASWriteLockManager) sortKeys(keys []string) []string {
+	sorted := make([]string, len(keys))
+	copy(sorted, keys)
+	sort.Strings(sorted)
+	return sorted
+}
+
+// findFirstKey finds the first key from sortedKeys that appears in targetKeys.
+// This ensures sequential lock acquisition by always waiting for the first failed key.
+func (wlm *CASWriteLockManager) findFirstKey(sortedKeys []string, targetKeys []string) string {
+	// Convert targetKeys to a set for O(1) lookup
+	targetSet := make(map[string]bool, len(targetKeys))
+	for _, k := range targetKeys {
+		targetSet[k] = true
+	}
+
+	// Find first key in sorted order
+	for _, k := range sortedKeys {
+		if targetSet[k] {
+			return k
+		}
+	}
+
+	// Should never happen if targetKeys is non-empty and derived from sortedKeys
+	return targetKeys[0]
+}
+
+// splitAcquiredBySequence splits acquired keys into those that should be kept (sequential)
+// vs those that should be restored (out of order after a gap).
+//
+// Example: sortedKeys=[key1,key2,key3], acquired=[key1,key3], firstFailedKey=key2.
+// Result: keep=[key1], restore=[key3].
+func (wlm *CASWriteLockManager) splitAcquiredBySequence(
+	sortedKeys []string,
+	acquired map[string]string,
+	firstFailedKey string,
+) (keep map[string]string, restore map[string]string) {
+	keep = make(map[string]string)
+	restore = make(map[string]string)
+
+	foundFailedKey := false
+	for _, key := range sortedKeys {
+		if key == firstFailedKey {
+			foundFailedKey = true
+			continue
+		}
+
+		lockVal, wasAcquired := acquired[key]
+		if !wasAcquired {
+			continue
+		}
+
+		if foundFailedKey {
+			// This key comes after the failed key - out of order, must restore
+			restore[key] = lockVal
+		} else {
+			// This key comes before the failed key - sequential, can keep
+			keep[key] = lockVal
+		}
+	}
+
+	return keep, restore
+}
+
+// keysNotIn returns keys from sortedKeys that are not in the acquired map.
+func (wlm *CASWriteLockManager) keysNotIn(sortedKeys []string, acquired map[string]string) []string { + remaining := make([]string, 0, len(sortedKeys)) + for _, key := range sortedKeys { + if _, ok := acquired[key]; !ok { + remaining = append(remaining, key) + } + } + return remaining +} diff --git a/internal/writelock/writelock_test.go b/internal/writelock/writelock_test.go new file mode 100644 index 0000000..bb56bd3 --- /dev/null +++ b/internal/writelock/writelock_test.go @@ -0,0 +1,275 @@ +//go:build integration + +package writelock_test + +import ( + "context" + "testing" + "time" + + "github.com/redis/rueidis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache/internal/lockmanager" + "github.com/dcbickfo/redcache/internal/writelock" +) + +// noopLogger is a no-op logger for testing +type noopLogger struct{} + +func (noopLogger) Debug(msg string, args ...any) {} +func (noopLogger) Error(msg string, args ...any) {} + +// makeClient creates a Redis client for testing +func makeClient(t *testing.T) rueidis.Client { + t.Helper() + client, err := rueidis.NewClient(rueidis.ClientOption{ + InitAddress: []string{"127.0.0.1:6379"}, + }) + require.NoError(t, err, "Failed to connect to Redis") + return client +} + +// makeWriteLockManager creates a WriteLockManager for testing +func makeWriteLockManager(t *testing.T, client rueidis.Client) writelock.WriteLockManager { + t.Helper() + + lockMgr := lockmanager.NewDistributedLockManager(lockmanager.Config{ + Client: client, + LockPrefix: "__test:lock:", + LockTTL: 5 * time.Second, + Logger: noopLogger{}, + }) + + return writelock.NewCASWriteLockManager(writelock.Config{ + Client: client, + LockTTL: 5 * time.Second, + LockManager: lockMgr, + Logger: noopLogger{}, + }) +} + +// TestWriteLockManager_AcquireWriteLock tests single write lock acquisition +func TestWriteLockManager_AcquireWriteLock(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("acquires lock for new key", func(t *testing.T) { + lockVal, err := wlm.AcquireWriteLock(ctx, "key1") + require.NoError(t, err) + assert.NotEmpty(t, lockVal) + + // Clean up + err = wlm.ReleaseWriteLock(ctx, "key1", lockVal) + assert.NoError(t, err) + }) + + t.Run("context cancellation", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := wlm.AcquireWriteLock(cancelCtx, "key2") + assert.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) +} + +// TestWriteLockManager_AcquireMultiWriteLocks tests batch write lock acquisition +func TestWriteLockManager_AcquireMultiWriteLocks(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("acquires locks for multiple keys", func(t *testing.T) { + keys := []string{"k1", "k2", "k3"} + + acquired, savedValues, failed, err := wlm.AcquireMultiWriteLocks(ctx, keys) + require.NoError(t, err) + assert.Len(t, acquired, 3) + assert.Empty(t, failed) + assert.NotNil(t, savedValues) + + // Clean up + wlm.ReleaseWriteLocks(ctx, acquired) + }) + + t.Run("handles partial acquisition failure", func(t *testing.T) { + keys := []string{"p1", "p2", "p3"} + + // Acquire first batch + acquired1, _, _, err := wlm.AcquireMultiWriteLocks(ctx, keys) + require.NoError(t, err) + defer wlm.ReleaseWriteLocks(ctx, acquired1) + + // Try to acquire again - should fail + acquired2, _, failed2, err := 
wlm.AcquireMultiWriteLocks(ctx, keys) + require.NoError(t, err) + assert.Empty(t, acquired2) + assert.Len(t, failed2, 3) + }) +} + +// TestWriteLockManager_ReleaseWriteLock tests lock release +func TestWriteLockManager_ReleaseWriteLock(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("releases owned lock", func(t *testing.T) { + lockVal, err := wlm.AcquireWriteLock(ctx, "rel1") + require.NoError(t, err) + + err = wlm.ReleaseWriteLock(ctx, "rel1", lockVal) + assert.NoError(t, err) + }) + + t.Run("fails to release with wrong lock value", func(t *testing.T) { + lockVal, err := wlm.AcquireWriteLock(ctx, "rel2") + require.NoError(t, err) + defer wlm.ReleaseWriteLock(ctx, "rel2", lockVal) + + err = wlm.ReleaseWriteLock(ctx, "rel2", "wrong-value") + assert.NoError(t, err) // Doesn't error, just doesn't release + }) +} + +// TestWriteLockManager_CommitWriteLocks tests CAS commit operation +func TestWriteLockManager_CommitWriteLocks(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("commits locks to real values", func(t *testing.T) { + keys := []string{"c1", "c2"} + acquired, _, _, err := wlm.AcquireMultiWriteLocks(ctx, keys) + require.NoError(t, err) + + actualValues := map[string]string{ + "c1": "value1", + "c2": "value2", + } + + succeeded, failed := wlm.CommitWriteLocks(ctx, 10*time.Second, acquired, actualValues) + assert.Len(t, succeeded, 2) + assert.Empty(t, failed) + + // Verify values are set + resp := client.Do(ctx, client.B().Get().Key("c1").Build()) + val, _ := resp.ToString() + assert.Equal(t, "value1", val) + }) + + t.Run("fails commit if lock lost", func(t *testing.T) { + lockVal, err := wlm.AcquireWriteLock(ctx, "c3") + require.NoError(t, err) + + // Manually overwrite the lock + client.Do(ctx, client.B().Set().Key("c3").Value("stolen").Build()) + + lockValues := map[string]string{"c3": lockVal} + actualValues := map[string]string{"c3": "new-value"} + + succeeded, failed := wlm.CommitWriteLocks(ctx, 10*time.Second, lockValues, actualValues) + assert.Empty(t, succeeded) + assert.Len(t, failed, 1) + }) +} + +// TestWriteLockManager_AcquireMultiWriteLocksSequential tests sequential lock acquisition +func TestWriteLockManager_AcquireMultiWriteLocksSequential(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("acquires all locks sequentially", func(t *testing.T) { + keys := []string{"s1", "s2", "s3"} + + lockValues, savedValues, err := wlm.AcquireMultiWriteLocksSequential(ctx, keys) + require.NoError(t, err) + assert.Len(t, lockValues, 3) + assert.NotNil(t, savedValues) + + // Verify locks are held + for key := range lockValues { + assert.Contains(t, keys, key) + } + + // Clean up + wlm.ReleaseWriteLocks(ctx, lockValues) + }) + + t.Run("context cancellation", func(t *testing.T) { + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := wlm.AcquireMultiWriteLocksSequential(cancelCtx, []string{"s4", "s5"}) + assert.Error(t, err) + }) +} + +// TestWriteLockManager_TouchLocks tests lock TTL refresh +func TestWriteLockManager_TouchLocks(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("refreshes lock TTL", func(t *testing.T) { + lockVal, err := wlm.AcquireWriteLock(ctx, "touch1") + require.NoError(t, err) + 
defer wlm.ReleaseWriteLock(ctx, "touch1", lockVal) + + lockValues := map[string]string{"touch1": lockVal} + + // Touch the lock + wlm.TouchLocks(ctx, lockValues) + + // Verify lock still exists + resp := client.Do(ctx, client.B().Ttl().Key("touch1").Build()) + ttl, err := resp.AsInt64() + require.NoError(t, err) + assert.Greater(t, ttl, int64(0)) + }) +} + +// TestWriteLockManager_RestoreValues tests backup restoration +func TestWriteLockManager_RestoreValues(t *testing.T) { + client := makeClient(t) + defer client.Close() + + wlm := makeWriteLockManager(t, client) + ctx := t.Context() + + t.Run("restores backed up values", func(t *testing.T) { + // Set initial value + client.Do(ctx, client.B().Set().Key("restore1").Value("original").Build()) + + keys := []string{"restore1"} + lockValues, savedValues, err := wlm.AcquireMultiWriteLocksSequential(ctx, keys) + require.NoError(t, err) + + // Verify original value was saved + assert.Equal(t, "original", savedValues["restore1"]) + + // Restore the value + wlm.RestoreValues(ctx, savedValues, lockValues) + + // Verify value was restored + resp := client.Do(ctx, client.B().Get().Key("restore1").Build()) + val, _ := resp.ToString() + assert.Equal(t, "original", val) + }) +} diff --git a/mocks/invalidation/mock_Handler.go b/mocks/invalidation/mock_Handler.go new file mode 100644 index 0000000..602bd68 --- /dev/null +++ b/mocks/invalidation/mock_Handler.go @@ -0,0 +1,173 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mockinvalidation + +import ( + "iter" + "sync" + + "github.com/dcbickfo/redcache/internal/invalidation" + "github.com/redis/rueidis" +) + +// Ensure that MockHandler does implement invalidation.Handler. +// If this is not the case, regenerate this file with mockery. +var _ invalidation.Handler = &MockHandler{} + +// MockHandler is a mock implementation of invalidation.Handler. +// +// func TestSomethingThatUsesHandler(t *testing.T) { +// +// // make and configure a mocked invalidation.Handler +// mockedHandler := &MockHandler{ +// OnInvalidateFunc: func(messages []rueidis.RedisMessage) { +// panic("mock out the OnInvalidate method") +// }, +// RegisterFunc: func(key string) <-chan struct{} { +// panic("mock out the Register method") +// }, +// RegisterAllFunc: func(keys iter.Seq[string], length int) map[string]<-chan struct{} { +// panic("mock out the RegisterAll method") +// }, +// } +// +// // use mockedHandler in code that requires invalidation.Handler +// // and then make assertions. +// +// } +type MockHandler struct { + // OnInvalidateFunc mocks the OnInvalidate method. + OnInvalidateFunc func(messages []rueidis.RedisMessage) + + // RegisterFunc mocks the Register method. + RegisterFunc func(key string) <-chan struct{} + + // RegisterAllFunc mocks the RegisterAll method. + RegisterAllFunc func(keys iter.Seq[string], length int) map[string]<-chan struct{} + + // calls tracks calls to the methods. + calls struct { + // OnInvalidate holds details about calls to the OnInvalidate method. + OnInvalidate []struct { + // Messages is the messages argument value. + Messages []rueidis.RedisMessage + } + // Register holds details about calls to the Register method. + Register []struct { + // Key is the key argument value. + Key string + } + // RegisterAll holds details about calls to the RegisterAll method. + RegisterAll []struct { + // Keys is the keys argument value. + Keys iter.Seq[string] + // Length is the length argument value. 
+ Length int + } + } + lockOnInvalidate sync.RWMutex + lockRegister sync.RWMutex + lockRegisterAll sync.RWMutex +} + +// OnInvalidate calls OnInvalidateFunc. +func (mock *MockHandler) OnInvalidate(messages []rueidis.RedisMessage) { + if mock.OnInvalidateFunc == nil { + panic("MockHandler.OnInvalidateFunc: method is nil but Handler.OnInvalidate was just called") + } + callInfo := struct { + Messages []rueidis.RedisMessage + }{ + Messages: messages, + } + mock.lockOnInvalidate.Lock() + mock.calls.OnInvalidate = append(mock.calls.OnInvalidate, callInfo) + mock.lockOnInvalidate.Unlock() + mock.OnInvalidateFunc(messages) +} + +// OnInvalidateCalls gets all the calls that were made to OnInvalidate. +// Check the length with: +// +// len(mockedHandler.OnInvalidateCalls()) +func (mock *MockHandler) OnInvalidateCalls() []struct { + Messages []rueidis.RedisMessage +} { + var calls []struct { + Messages []rueidis.RedisMessage + } + mock.lockOnInvalidate.RLock() + calls = mock.calls.OnInvalidate + mock.lockOnInvalidate.RUnlock() + return calls +} + +// Register calls RegisterFunc. +func (mock *MockHandler) Register(key string) <-chan struct{} { + if mock.RegisterFunc == nil { + panic("MockHandler.RegisterFunc: method is nil but Handler.Register was just called") + } + callInfo := struct { + Key string + }{ + Key: key, + } + mock.lockRegister.Lock() + mock.calls.Register = append(mock.calls.Register, callInfo) + mock.lockRegister.Unlock() + return mock.RegisterFunc(key) +} + +// RegisterCalls gets all the calls that were made to Register. +// Check the length with: +// +// len(mockedHandler.RegisterCalls()) +func (mock *MockHandler) RegisterCalls() []struct { + Key string +} { + var calls []struct { + Key string + } + mock.lockRegister.RLock() + calls = mock.calls.Register + mock.lockRegister.RUnlock() + return calls +} + +// RegisterAll calls RegisterAllFunc. +func (mock *MockHandler) RegisterAll(keys iter.Seq[string], length int) map[string]<-chan struct{} { + if mock.RegisterAllFunc == nil { + panic("MockHandler.RegisterAllFunc: method is nil but Handler.RegisterAll was just called") + } + callInfo := struct { + Keys iter.Seq[string] + Length int + }{ + Keys: keys, + Length: length, + } + mock.lockRegisterAll.Lock() + mock.calls.RegisterAll = append(mock.calls.RegisterAll, callInfo) + mock.lockRegisterAll.Unlock() + return mock.RegisterAllFunc(keys, length) +} + +// RegisterAllCalls gets all the calls that were made to RegisterAll. +// Check the length with: +// +// len(mockedHandler.RegisterAllCalls()) +func (mock *MockHandler) RegisterAllCalls() []struct { + Keys iter.Seq[string] + Length int +} { + var calls []struct { + Keys iter.Seq[string] + Length int + } + mock.lockRegisterAll.RLock() + calls = mock.calls.RegisterAll + mock.lockRegisterAll.RUnlock() + return calls +} diff --git a/mocks/invalidation/mock_Logger.go b/mocks/invalidation/mock_Logger.go new file mode 100644 index 0000000..386dcdd --- /dev/null +++ b/mocks/invalidation/mock_Logger.go @@ -0,0 +1,133 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mockinvalidation + +import ( + "sync" + + "github.com/dcbickfo/redcache/internal/invalidation" +) + +// Ensure that MockLogger does implement invalidation.Logger. +// If this is not the case, regenerate this file with mockery. +var _ invalidation.Logger = &MockLogger{} + +// MockLogger is a mock implementation of invalidation.Logger. 
+// +// func TestSomethingThatUsesLogger(t *testing.T) { +// +// // make and configure a mocked invalidation.Logger +// mockedLogger := &MockLogger{ +// DebugFunc: func(msg string, args ...any) { +// panic("mock out the Debug method") +// }, +// ErrorFunc: func(msg string, args ...any) { +// panic("mock out the Error method") +// }, +// } +// +// // use mockedLogger in code that requires invalidation.Logger +// // and then make assertions. +// +// } +type MockLogger struct { + // DebugFunc mocks the Debug method. + DebugFunc func(msg string, args ...any) + + // ErrorFunc mocks the Error method. + ErrorFunc func(msg string, args ...any) + + // calls tracks calls to the methods. + calls struct { + // Debug holds details about calls to the Debug method. + Debug []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + // Error holds details about calls to the Error method. + Error []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + } + lockDebug sync.RWMutex + lockError sync.RWMutex +} + +// Debug calls DebugFunc. +func (mock *MockLogger) Debug(msg string, args ...any) { + if mock.DebugFunc == nil { + panic("MockLogger.DebugFunc: method is nil but Logger.Debug was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockDebug.Lock() + mock.calls.Debug = append(mock.calls.Debug, callInfo) + mock.lockDebug.Unlock() + mock.DebugFunc(msg, args...) +} + +// DebugCalls gets all the calls that were made to Debug. +// Check the length with: +// +// len(mockedLogger.DebugCalls()) +func (mock *MockLogger) DebugCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockDebug.RLock() + calls = mock.calls.Debug + mock.lockDebug.RUnlock() + return calls +} + +// Error calls ErrorFunc. +func (mock *MockLogger) Error(msg string, args ...any) { + if mock.ErrorFunc == nil { + panic("MockLogger.ErrorFunc: method is nil but Logger.Error was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockError.Lock() + mock.calls.Error = append(mock.calls.Error, callInfo) + mock.lockError.Unlock() + mock.ErrorFunc(msg, args...) +} + +// ErrorCalls gets all the calls that were made to Error. +// Check the length with: +// +// len(mockedLogger.ErrorCalls()) +func (mock *MockLogger) ErrorCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockError.RLock() + calls = mock.calls.Error + mock.lockError.RUnlock() + return calls +} diff --git a/mocks/lockmanager/mock_LockChecker.go b/mocks/lockmanager/mock_LockChecker.go new file mode 100644 index 0000000..2793621 --- /dev/null +++ b/mocks/lockmanager/mock_LockChecker.go @@ -0,0 +1,135 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mocklockmanager + +import ( + "context" + "sync" + + "github.com/dcbickfo/redcache/internal/lockmanager" + "github.com/redis/rueidis" +) + +// Ensure that MockLockChecker does implement lockmanager.LockChecker. +// If this is not the case, regenerate this file with mockery. +var _ lockmanager.LockChecker = &MockLockChecker{} + +// MockLockChecker is a mock implementation of lockmanager.LockChecker. 
+// +// func TestSomethingThatUsesLockChecker(t *testing.T) { +// +// // make and configure a mocked lockmanager.LockChecker +// mockedLockChecker := &MockLockChecker{ +// CheckKeyLockedFunc: func(ctx context.Context, client rueidis.Client, key string) bool { +// panic("mock out the CheckKeyLocked method") +// }, +// HasLockFunc: func(val string) bool { +// panic("mock out the HasLock method") +// }, +// } +// +// // use mockedLockChecker in code that requires lockmanager.LockChecker +// // and then make assertions. +// +// } +type MockLockChecker struct { + // CheckKeyLockedFunc mocks the CheckKeyLocked method. + CheckKeyLockedFunc func(ctx context.Context, client rueidis.Client, key string) bool + + // HasLockFunc mocks the HasLock method. + HasLockFunc func(val string) bool + + // calls tracks calls to the methods. + calls struct { + // CheckKeyLocked holds details about calls to the CheckKeyLocked method. + CheckKeyLocked []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Client is the client argument value. + Client rueidis.Client + // Key is the key argument value. + Key string + } + // HasLock holds details about calls to the HasLock method. + HasLock []struct { + // Val is the val argument value. + Val string + } + } + lockCheckKeyLocked sync.RWMutex + lockHasLock sync.RWMutex +} + +// CheckKeyLocked calls CheckKeyLockedFunc. +func (mock *MockLockChecker) CheckKeyLocked(ctx context.Context, client rueidis.Client, key string) bool { + if mock.CheckKeyLockedFunc == nil { + panic("MockLockChecker.CheckKeyLockedFunc: method is nil but LockChecker.CheckKeyLocked was just called") + } + callInfo := struct { + Ctx context.Context + Client rueidis.Client + Key string + }{ + Ctx: ctx, + Client: client, + Key: key, + } + mock.lockCheckKeyLocked.Lock() + mock.calls.CheckKeyLocked = append(mock.calls.CheckKeyLocked, callInfo) + mock.lockCheckKeyLocked.Unlock() + return mock.CheckKeyLockedFunc(ctx, client, key) +} + +// CheckKeyLockedCalls gets all the calls that were made to CheckKeyLocked. +// Check the length with: +// +// len(mockedLockChecker.CheckKeyLockedCalls()) +func (mock *MockLockChecker) CheckKeyLockedCalls() []struct { + Ctx context.Context + Client rueidis.Client + Key string +} { + var calls []struct { + Ctx context.Context + Client rueidis.Client + Key string + } + mock.lockCheckKeyLocked.RLock() + calls = mock.calls.CheckKeyLocked + mock.lockCheckKeyLocked.RUnlock() + return calls +} + +// HasLock calls HasLockFunc. +func (mock *MockLockChecker) HasLock(val string) bool { + if mock.HasLockFunc == nil { + panic("MockLockChecker.HasLockFunc: method is nil but LockChecker.HasLock was just called") + } + callInfo := struct { + Val string + }{ + Val: val, + } + mock.lockHasLock.Lock() + mock.calls.HasLock = append(mock.calls.HasLock, callInfo) + mock.lockHasLock.Unlock() + return mock.HasLockFunc(val) +} + +// HasLockCalls gets all the calls that were made to HasLock. +// Check the length with: +// +// len(mockedLockChecker.HasLockCalls()) +func (mock *MockLockChecker) HasLockCalls() []struct { + Val string +} { + var calls []struct { + Val string + } + mock.lockHasLock.RLock() + calls = mock.calls.HasLock + mock.lockHasLock.RUnlock() + return calls +} diff --git a/mocks/lockmanager/mock_LockManager.go b/mocks/lockmanager/mock_LockManager.go new file mode 100644 index 0000000..e05a745 --- /dev/null +++ b/mocks/lockmanager/mock_LockManager.go @@ -0,0 +1,915 @@ +// Code generated by mockery; DO NOT EDIT. 
+// github.com/vektra/mockery +// template: matryer + +package mocklockmanager + +import ( + "context" + "sync" + "time" + + "github.com/dcbickfo/redcache/internal/lockmanager" + "github.com/redis/rueidis" +) + +// Ensure that MockLockManager does implement lockmanager.LockManager. +// If this is not the case, regenerate this file with mockery. +var _ lockmanager.LockManager = &MockLockManager{} + +// MockLockManager is a mock implementation of lockmanager.LockManager. +// +// func TestSomethingThatUsesLockManager(t *testing.T) { +// +// // make and configure a mocked lockmanager.LockManager +// mockedLockManager := &MockLockManager{ +// AcquireLockBlockingFunc: func(ctx context.Context, key string) (string, error) { +// panic("mock out the AcquireLockBlocking method") +// }, +// AcquireMultiLocksBlockingFunc: func(ctx context.Context, keys []string) (map[string]string, error) { +// panic("mock out the AcquireMultiLocksBlocking method") +// }, +// CheckKeyLockedFunc: func(ctx context.Context, key string) bool { +// panic("mock out the CheckKeyLocked method") +// }, +// CheckMultiKeysLockedFunc: func(ctx context.Context, keys []string) []string { +// panic("mock out the CheckMultiKeysLocked method") +// }, +// CleanupUnusedLocksFunc: func(ctx context.Context, acquiredLocks map[string]string, usedKeys map[string]string) { +// panic("mock out the CleanupUnusedLocks method") +// }, +// CommitReadLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) ([]string, []string, error) { +// panic("mock out the CommitReadLocks method") +// }, +// CreateImmediateWaitHandleFunc: func() lockmanager.WaitHandle { +// panic("mock out the CreateImmediateWaitHandle method") +// }, +// GenerateLockValueFunc: func() string { +// panic("mock out the GenerateLockValue method") +// }, +// IsLockValueFunc: func(val string) bool { +// panic("mock out the IsLockValue method") +// }, +// LockPrefixFunc: func() string { +// panic("mock out the LockPrefix method") +// }, +// OnInvalidateFunc: func(messages []rueidis.RedisMessage) { +// panic("mock out the OnInvalidate method") +// }, +// ReleaseLockFunc: func(ctx context.Context, key string, lockValue string) error { +// panic("mock out the ReleaseLock method") +// }, +// ReleaseMultiLocksFunc: func(ctx context.Context, lockValues map[string]string) { +// panic("mock out the ReleaseMultiLocks method") +// }, +// TryAcquireFunc: func(ctx context.Context, key string) (string, lockmanager.WaitHandle, error) { +// panic("mock out the TryAcquire method") +// }, +// TryAcquireMultiFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]lockmanager.WaitHandle, error) { +// panic("mock out the TryAcquireMulti method") +// }, +// WaitForKeyFunc: func(key string) lockmanager.WaitHandle { +// panic("mock out the WaitForKey method") +// }, +// WaitForKeyWithRetryFunc: func(ctx context.Context, key string, ticker *time.Ticker) error { +// panic("mock out the WaitForKeyWithRetry method") +// }, +// WaitForKeyWithSubscriptionFunc: func(ctx context.Context, key string, cacheTTL time.Duration) (lockmanager.WaitHandle, string, error) { +// panic("mock out the WaitForKeyWithSubscription method") +// }, +// } +// +// // use mockedLockManager in code that requires lockmanager.LockManager +// // and then make assertions. +// +// } +type MockLockManager struct { + // AcquireLockBlockingFunc mocks the AcquireLockBlocking method. 
+ AcquireLockBlockingFunc func(ctx context.Context, key string) (string, error) + + // AcquireMultiLocksBlockingFunc mocks the AcquireMultiLocksBlocking method. + AcquireMultiLocksBlockingFunc func(ctx context.Context, keys []string) (map[string]string, error) + + // CheckKeyLockedFunc mocks the CheckKeyLocked method. + CheckKeyLockedFunc func(ctx context.Context, key string) bool + + // CheckMultiKeysLockedFunc mocks the CheckMultiKeysLocked method. + CheckMultiKeysLockedFunc func(ctx context.Context, keys []string) []string + + // CleanupUnusedLocksFunc mocks the CleanupUnusedLocks method. + CleanupUnusedLocksFunc func(ctx context.Context, acquiredLocks map[string]string, usedKeys map[string]string) + + // CommitReadLocksFunc mocks the CommitReadLocks method. + CommitReadLocksFunc func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) ([]string, []string, error) + + // CreateImmediateWaitHandleFunc mocks the CreateImmediateWaitHandle method. + CreateImmediateWaitHandleFunc func() lockmanager.WaitHandle + + // GenerateLockValueFunc mocks the GenerateLockValue method. + GenerateLockValueFunc func() string + + // IsLockValueFunc mocks the IsLockValue method. + IsLockValueFunc func(val string) bool + + // LockPrefixFunc mocks the LockPrefix method. + LockPrefixFunc func() string + + // OnInvalidateFunc mocks the OnInvalidate method. + OnInvalidateFunc func(messages []rueidis.RedisMessage) + + // ReleaseLockFunc mocks the ReleaseLock method. + ReleaseLockFunc func(ctx context.Context, key string, lockValue string) error + + // ReleaseMultiLocksFunc mocks the ReleaseMultiLocks method. + ReleaseMultiLocksFunc func(ctx context.Context, lockValues map[string]string) + + // TryAcquireFunc mocks the TryAcquire method. + TryAcquireFunc func(ctx context.Context, key string) (string, lockmanager.WaitHandle, error) + + // TryAcquireMultiFunc mocks the TryAcquireMulti method. + TryAcquireMultiFunc func(ctx context.Context, keys []string) (map[string]string, map[string]lockmanager.WaitHandle, error) + + // WaitForKeyFunc mocks the WaitForKey method. + WaitForKeyFunc func(key string) lockmanager.WaitHandle + + // WaitForKeyWithRetryFunc mocks the WaitForKeyWithRetry method. + WaitForKeyWithRetryFunc func(ctx context.Context, key string, ticker *time.Ticker) error + + // WaitForKeyWithSubscriptionFunc mocks the WaitForKeyWithSubscription method. + WaitForKeyWithSubscriptionFunc func(ctx context.Context, key string, cacheTTL time.Duration) (lockmanager.WaitHandle, string, error) + + // calls tracks calls to the methods. + calls struct { + // AcquireLockBlocking holds details about calls to the AcquireLockBlocking method. + AcquireLockBlocking []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // AcquireMultiLocksBlocking holds details about calls to the AcquireMultiLocksBlocking method. + AcquireMultiLocksBlocking []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. + Keys []string + } + // CheckKeyLocked holds details about calls to the CheckKeyLocked method. + CheckKeyLocked []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // CheckMultiKeysLocked holds details about calls to the CheckMultiKeysLocked method. + CheckMultiKeysLocked []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. 
+ Keys []string + } + // CleanupUnusedLocks holds details about calls to the CleanupUnusedLocks method. + CleanupUnusedLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // AcquiredLocks is the acquiredLocks argument value. + AcquiredLocks map[string]string + // UsedKeys is the usedKeys argument value. + UsedKeys map[string]string + } + // CommitReadLocks holds details about calls to the CommitReadLocks method. + CommitReadLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // TTL is the ttl argument value. + TTL time.Duration + // LockValues is the lockValues argument value. + LockValues map[string]string + // ActualValues is the actualValues argument value. + ActualValues map[string]string + } + // CreateImmediateWaitHandle holds details about calls to the CreateImmediateWaitHandle method. + CreateImmediateWaitHandle []struct { + } + // GenerateLockValue holds details about calls to the GenerateLockValue method. + GenerateLockValue []struct { + } + // IsLockValue holds details about calls to the IsLockValue method. + IsLockValue []struct { + // Val is the val argument value. + Val string + } + // LockPrefix holds details about calls to the LockPrefix method. + LockPrefix []struct { + } + // OnInvalidate holds details about calls to the OnInvalidate method. + OnInvalidate []struct { + // Messages is the messages argument value. + Messages []rueidis.RedisMessage + } + // ReleaseLock holds details about calls to the ReleaseLock method. + ReleaseLock []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // LockValue is the lockValue argument value. + LockValue string + } + // ReleaseMultiLocks holds details about calls to the ReleaseMultiLocks method. + ReleaseMultiLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // LockValues is the lockValues argument value. + LockValues map[string]string + } + // TryAcquire holds details about calls to the TryAcquire method. + TryAcquire []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // TryAcquireMulti holds details about calls to the TryAcquireMulti method. + TryAcquireMulti []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. + Keys []string + } + // WaitForKey holds details about calls to the WaitForKey method. + WaitForKey []struct { + // Key is the key argument value. + Key string + } + // WaitForKeyWithRetry holds details about calls to the WaitForKeyWithRetry method. + WaitForKeyWithRetry []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // Ticker is the ticker argument value. + Ticker *time.Ticker + } + // WaitForKeyWithSubscription holds details about calls to the WaitForKeyWithSubscription method. + WaitForKeyWithSubscription []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // CacheTTL is the cacheTTL argument value. 
+ CacheTTL time.Duration + } + } + lockAcquireLockBlocking sync.RWMutex + lockAcquireMultiLocksBlocking sync.RWMutex + lockCheckKeyLocked sync.RWMutex + lockCheckMultiKeysLocked sync.RWMutex + lockCleanupUnusedLocks sync.RWMutex + lockCommitReadLocks sync.RWMutex + lockCreateImmediateWaitHandle sync.RWMutex + lockGenerateLockValue sync.RWMutex + lockIsLockValue sync.RWMutex + lockLockPrefix sync.RWMutex + lockOnInvalidate sync.RWMutex + lockReleaseLock sync.RWMutex + lockReleaseMultiLocks sync.RWMutex + lockTryAcquire sync.RWMutex + lockTryAcquireMulti sync.RWMutex + lockWaitForKey sync.RWMutex + lockWaitForKeyWithRetry sync.RWMutex + lockWaitForKeyWithSubscription sync.RWMutex +} + +// AcquireLockBlocking calls AcquireLockBlockingFunc. +func (mock *MockLockManager) AcquireLockBlocking(ctx context.Context, key string) (string, error) { + if mock.AcquireLockBlockingFunc == nil { + panic("MockLockManager.AcquireLockBlockingFunc: method is nil but LockManager.AcquireLockBlocking was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockAcquireLockBlocking.Lock() + mock.calls.AcquireLockBlocking = append(mock.calls.AcquireLockBlocking, callInfo) + mock.lockAcquireLockBlocking.Unlock() + return mock.AcquireLockBlockingFunc(ctx, key) +} + +// AcquireLockBlockingCalls gets all the calls that were made to AcquireLockBlocking. +// Check the length with: +// +// len(mockedLockManager.AcquireLockBlockingCalls()) +func (mock *MockLockManager) AcquireLockBlockingCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockAcquireLockBlocking.RLock() + calls = mock.calls.AcquireLockBlocking + mock.lockAcquireLockBlocking.RUnlock() + return calls +} + +// AcquireMultiLocksBlocking calls AcquireMultiLocksBlockingFunc. +func (mock *MockLockManager) AcquireMultiLocksBlocking(ctx context.Context, keys []string) (map[string]string, error) { + if mock.AcquireMultiLocksBlockingFunc == nil { + panic("MockLockManager.AcquireMultiLocksBlockingFunc: method is nil but LockManager.AcquireMultiLocksBlocking was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockAcquireMultiLocksBlocking.Lock() + mock.calls.AcquireMultiLocksBlocking = append(mock.calls.AcquireMultiLocksBlocking, callInfo) + mock.lockAcquireMultiLocksBlocking.Unlock() + return mock.AcquireMultiLocksBlockingFunc(ctx, keys) +} + +// AcquireMultiLocksBlockingCalls gets all the calls that were made to AcquireMultiLocksBlocking. +// Check the length with: +// +// len(mockedLockManager.AcquireMultiLocksBlockingCalls()) +func (mock *MockLockManager) AcquireMultiLocksBlockingCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockAcquireMultiLocksBlocking.RLock() + calls = mock.calls.AcquireMultiLocksBlocking + mock.lockAcquireMultiLocksBlocking.RUnlock() + return calls +} + +// CheckKeyLocked calls CheckKeyLockedFunc. 
+func (mock *MockLockManager) CheckKeyLocked(ctx context.Context, key string) bool { + if mock.CheckKeyLockedFunc == nil { + panic("MockLockManager.CheckKeyLockedFunc: method is nil but LockManager.CheckKeyLocked was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockCheckKeyLocked.Lock() + mock.calls.CheckKeyLocked = append(mock.calls.CheckKeyLocked, callInfo) + mock.lockCheckKeyLocked.Unlock() + return mock.CheckKeyLockedFunc(ctx, key) +} + +// CheckKeyLockedCalls gets all the calls that were made to CheckKeyLocked. +// Check the length with: +// +// len(mockedLockManager.CheckKeyLockedCalls()) +func (mock *MockLockManager) CheckKeyLockedCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockCheckKeyLocked.RLock() + calls = mock.calls.CheckKeyLocked + mock.lockCheckKeyLocked.RUnlock() + return calls +} + +// CheckMultiKeysLocked calls CheckMultiKeysLockedFunc. +func (mock *MockLockManager) CheckMultiKeysLocked(ctx context.Context, keys []string) []string { + if mock.CheckMultiKeysLockedFunc == nil { + panic("MockLockManager.CheckMultiKeysLockedFunc: method is nil but LockManager.CheckMultiKeysLocked was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockCheckMultiKeysLocked.Lock() + mock.calls.CheckMultiKeysLocked = append(mock.calls.CheckMultiKeysLocked, callInfo) + mock.lockCheckMultiKeysLocked.Unlock() + return mock.CheckMultiKeysLockedFunc(ctx, keys) +} + +// CheckMultiKeysLockedCalls gets all the calls that were made to CheckMultiKeysLocked. +// Check the length with: +// +// len(mockedLockManager.CheckMultiKeysLockedCalls()) +func (mock *MockLockManager) CheckMultiKeysLockedCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockCheckMultiKeysLocked.RLock() + calls = mock.calls.CheckMultiKeysLocked + mock.lockCheckMultiKeysLocked.RUnlock() + return calls +} + +// CleanupUnusedLocks calls CleanupUnusedLocksFunc. +func (mock *MockLockManager) CleanupUnusedLocks(ctx context.Context, acquiredLocks map[string]string, usedKeys map[string]string) { + if mock.CleanupUnusedLocksFunc == nil { + panic("MockLockManager.CleanupUnusedLocksFunc: method is nil but LockManager.CleanupUnusedLocks was just called") + } + callInfo := struct { + Ctx context.Context + AcquiredLocks map[string]string + UsedKeys map[string]string + }{ + Ctx: ctx, + AcquiredLocks: acquiredLocks, + UsedKeys: usedKeys, + } + mock.lockCleanupUnusedLocks.Lock() + mock.calls.CleanupUnusedLocks = append(mock.calls.CleanupUnusedLocks, callInfo) + mock.lockCleanupUnusedLocks.Unlock() + mock.CleanupUnusedLocksFunc(ctx, acquiredLocks, usedKeys) +} + +// CleanupUnusedLocksCalls gets all the calls that were made to CleanupUnusedLocks. +// Check the length with: +// +// len(mockedLockManager.CleanupUnusedLocksCalls()) +func (mock *MockLockManager) CleanupUnusedLocksCalls() []struct { + Ctx context.Context + AcquiredLocks map[string]string + UsedKeys map[string]string +} { + var calls []struct { + Ctx context.Context + AcquiredLocks map[string]string + UsedKeys map[string]string + } + mock.lockCleanupUnusedLocks.RLock() + calls = mock.calls.CleanupUnusedLocks + mock.lockCleanupUnusedLocks.RUnlock() + return calls +} + +// CommitReadLocks calls CommitReadLocksFunc. 
+func (mock *MockLockManager) CommitReadLocks(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) ([]string, []string, error) { + if mock.CommitReadLocksFunc == nil { + panic("MockLockManager.CommitReadLocksFunc: method is nil but LockManager.CommitReadLocks was just called") + } + callInfo := struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string + }{ + Ctx: ctx, + TTL: ttl, + LockValues: lockValues, + ActualValues: actualValues, + } + mock.lockCommitReadLocks.Lock() + mock.calls.CommitReadLocks = append(mock.calls.CommitReadLocks, callInfo) + mock.lockCommitReadLocks.Unlock() + return mock.CommitReadLocksFunc(ctx, ttl, lockValues, actualValues) +} + +// CommitReadLocksCalls gets all the calls that were made to CommitReadLocks. +// Check the length with: +// +// len(mockedLockManager.CommitReadLocksCalls()) +func (mock *MockLockManager) CommitReadLocksCalls() []struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string +} { + var calls []struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string + } + mock.lockCommitReadLocks.RLock() + calls = mock.calls.CommitReadLocks + mock.lockCommitReadLocks.RUnlock() + return calls +} + +// CreateImmediateWaitHandle calls CreateImmediateWaitHandleFunc. +func (mock *MockLockManager) CreateImmediateWaitHandle() lockmanager.WaitHandle { + if mock.CreateImmediateWaitHandleFunc == nil { + panic("MockLockManager.CreateImmediateWaitHandleFunc: method is nil but LockManager.CreateImmediateWaitHandle was just called") + } + callInfo := struct { + }{} + mock.lockCreateImmediateWaitHandle.Lock() + mock.calls.CreateImmediateWaitHandle = append(mock.calls.CreateImmediateWaitHandle, callInfo) + mock.lockCreateImmediateWaitHandle.Unlock() + return mock.CreateImmediateWaitHandleFunc() +} + +// CreateImmediateWaitHandleCalls gets all the calls that were made to CreateImmediateWaitHandle. +// Check the length with: +// +// len(mockedLockManager.CreateImmediateWaitHandleCalls()) +func (mock *MockLockManager) CreateImmediateWaitHandleCalls() []struct { +} { + var calls []struct { + } + mock.lockCreateImmediateWaitHandle.RLock() + calls = mock.calls.CreateImmediateWaitHandle + mock.lockCreateImmediateWaitHandle.RUnlock() + return calls +} + +// GenerateLockValue calls GenerateLockValueFunc. +func (mock *MockLockManager) GenerateLockValue() string { + if mock.GenerateLockValueFunc == nil { + panic("MockLockManager.GenerateLockValueFunc: method is nil but LockManager.GenerateLockValue was just called") + } + callInfo := struct { + }{} + mock.lockGenerateLockValue.Lock() + mock.calls.GenerateLockValue = append(mock.calls.GenerateLockValue, callInfo) + mock.lockGenerateLockValue.Unlock() + return mock.GenerateLockValueFunc() +} + +// GenerateLockValueCalls gets all the calls that were made to GenerateLockValue. +// Check the length with: +// +// len(mockedLockManager.GenerateLockValueCalls()) +func (mock *MockLockManager) GenerateLockValueCalls() []struct { +} { + var calls []struct { + } + mock.lockGenerateLockValue.RLock() + calls = mock.calls.GenerateLockValue + mock.lockGenerateLockValue.RUnlock() + return calls +} + +// IsLockValue calls IsLockValueFunc. 
+func (mock *MockLockManager) IsLockValue(val string) bool { + if mock.IsLockValueFunc == nil { + panic("MockLockManager.IsLockValueFunc: method is nil but LockManager.IsLockValue was just called") + } + callInfo := struct { + Val string + }{ + Val: val, + } + mock.lockIsLockValue.Lock() + mock.calls.IsLockValue = append(mock.calls.IsLockValue, callInfo) + mock.lockIsLockValue.Unlock() + return mock.IsLockValueFunc(val) +} + +// IsLockValueCalls gets all the calls that were made to IsLockValue. +// Check the length with: +// +// len(mockedLockManager.IsLockValueCalls()) +func (mock *MockLockManager) IsLockValueCalls() []struct { + Val string +} { + var calls []struct { + Val string + } + mock.lockIsLockValue.RLock() + calls = mock.calls.IsLockValue + mock.lockIsLockValue.RUnlock() + return calls +} + +// LockPrefix calls LockPrefixFunc. +func (mock *MockLockManager) LockPrefix() string { + if mock.LockPrefixFunc == nil { + panic("MockLockManager.LockPrefixFunc: method is nil but LockManager.LockPrefix was just called") + } + callInfo := struct { + }{} + mock.lockLockPrefix.Lock() + mock.calls.LockPrefix = append(mock.calls.LockPrefix, callInfo) + mock.lockLockPrefix.Unlock() + return mock.LockPrefixFunc() +} + +// LockPrefixCalls gets all the calls that were made to LockPrefix. +// Check the length with: +// +// len(mockedLockManager.LockPrefixCalls()) +func (mock *MockLockManager) LockPrefixCalls() []struct { +} { + var calls []struct { + } + mock.lockLockPrefix.RLock() + calls = mock.calls.LockPrefix + mock.lockLockPrefix.RUnlock() + return calls +} + +// OnInvalidate calls OnInvalidateFunc. +func (mock *MockLockManager) OnInvalidate(messages []rueidis.RedisMessage) { + if mock.OnInvalidateFunc == nil { + panic("MockLockManager.OnInvalidateFunc: method is nil but LockManager.OnInvalidate was just called") + } + callInfo := struct { + Messages []rueidis.RedisMessage + }{ + Messages: messages, + } + mock.lockOnInvalidate.Lock() + mock.calls.OnInvalidate = append(mock.calls.OnInvalidate, callInfo) + mock.lockOnInvalidate.Unlock() + mock.OnInvalidateFunc(messages) +} + +// OnInvalidateCalls gets all the calls that were made to OnInvalidate. +// Check the length with: +// +// len(mockedLockManager.OnInvalidateCalls()) +func (mock *MockLockManager) OnInvalidateCalls() []struct { + Messages []rueidis.RedisMessage +} { + var calls []struct { + Messages []rueidis.RedisMessage + } + mock.lockOnInvalidate.RLock() + calls = mock.calls.OnInvalidate + mock.lockOnInvalidate.RUnlock() + return calls +} + +// ReleaseLock calls ReleaseLockFunc. +func (mock *MockLockManager) ReleaseLock(ctx context.Context, key string, lockValue string) error { + if mock.ReleaseLockFunc == nil { + panic("MockLockManager.ReleaseLockFunc: method is nil but LockManager.ReleaseLock was just called") + } + callInfo := struct { + Ctx context.Context + Key string + LockValue string + }{ + Ctx: ctx, + Key: key, + LockValue: lockValue, + } + mock.lockReleaseLock.Lock() + mock.calls.ReleaseLock = append(mock.calls.ReleaseLock, callInfo) + mock.lockReleaseLock.Unlock() + return mock.ReleaseLockFunc(ctx, key, lockValue) +} + +// ReleaseLockCalls gets all the calls that were made to ReleaseLock. 
+// Check the length with: +// +// len(mockedLockManager.ReleaseLockCalls()) +func (mock *MockLockManager) ReleaseLockCalls() []struct { + Ctx context.Context + Key string + LockValue string +} { + var calls []struct { + Ctx context.Context + Key string + LockValue string + } + mock.lockReleaseLock.RLock() + calls = mock.calls.ReleaseLock + mock.lockReleaseLock.RUnlock() + return calls +} + +// ReleaseMultiLocks calls ReleaseMultiLocksFunc. +func (mock *MockLockManager) ReleaseMultiLocks(ctx context.Context, lockValues map[string]string) { + if mock.ReleaseMultiLocksFunc == nil { + panic("MockLockManager.ReleaseMultiLocksFunc: method is nil but LockManager.ReleaseMultiLocks was just called") + } + callInfo := struct { + Ctx context.Context + LockValues map[string]string + }{ + Ctx: ctx, + LockValues: lockValues, + } + mock.lockReleaseMultiLocks.Lock() + mock.calls.ReleaseMultiLocks = append(mock.calls.ReleaseMultiLocks, callInfo) + mock.lockReleaseMultiLocks.Unlock() + mock.ReleaseMultiLocksFunc(ctx, lockValues) +} + +// ReleaseMultiLocksCalls gets all the calls that were made to ReleaseMultiLocks. +// Check the length with: +// +// len(mockedLockManager.ReleaseMultiLocksCalls()) +func (mock *MockLockManager) ReleaseMultiLocksCalls() []struct { + Ctx context.Context + LockValues map[string]string +} { + var calls []struct { + Ctx context.Context + LockValues map[string]string + } + mock.lockReleaseMultiLocks.RLock() + calls = mock.calls.ReleaseMultiLocks + mock.lockReleaseMultiLocks.RUnlock() + return calls +} + +// TryAcquire calls TryAcquireFunc. +func (mock *MockLockManager) TryAcquire(ctx context.Context, key string) (string, lockmanager.WaitHandle, error) { + if mock.TryAcquireFunc == nil { + panic("MockLockManager.TryAcquireFunc: method is nil but LockManager.TryAcquire was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockTryAcquire.Lock() + mock.calls.TryAcquire = append(mock.calls.TryAcquire, callInfo) + mock.lockTryAcquire.Unlock() + return mock.TryAcquireFunc(ctx, key) +} + +// TryAcquireCalls gets all the calls that were made to TryAcquire. +// Check the length with: +// +// len(mockedLockManager.TryAcquireCalls()) +func (mock *MockLockManager) TryAcquireCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockTryAcquire.RLock() + calls = mock.calls.TryAcquire + mock.lockTryAcquire.RUnlock() + return calls +} + +// TryAcquireMulti calls TryAcquireMultiFunc. +func (mock *MockLockManager) TryAcquireMulti(ctx context.Context, keys []string) (map[string]string, map[string]lockmanager.WaitHandle, error) { + if mock.TryAcquireMultiFunc == nil { + panic("MockLockManager.TryAcquireMultiFunc: method is nil but LockManager.TryAcquireMulti was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockTryAcquireMulti.Lock() + mock.calls.TryAcquireMulti = append(mock.calls.TryAcquireMulti, callInfo) + mock.lockTryAcquireMulti.Unlock() + return mock.TryAcquireMultiFunc(ctx, keys) +} + +// TryAcquireMultiCalls gets all the calls that were made to TryAcquireMulti. 
+// Check the length with: +// +// len(mockedLockManager.TryAcquireMultiCalls()) +func (mock *MockLockManager) TryAcquireMultiCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockTryAcquireMulti.RLock() + calls = mock.calls.TryAcquireMulti + mock.lockTryAcquireMulti.RUnlock() + return calls +} + +// WaitForKey calls WaitForKeyFunc. +func (mock *MockLockManager) WaitForKey(key string) lockmanager.WaitHandle { + if mock.WaitForKeyFunc == nil { + panic("MockLockManager.WaitForKeyFunc: method is nil but LockManager.WaitForKey was just called") + } + callInfo := struct { + Key string + }{ + Key: key, + } + mock.lockWaitForKey.Lock() + mock.calls.WaitForKey = append(mock.calls.WaitForKey, callInfo) + mock.lockWaitForKey.Unlock() + return mock.WaitForKeyFunc(key) +} + +// WaitForKeyCalls gets all the calls that were made to WaitForKey. +// Check the length with: +// +// len(mockedLockManager.WaitForKeyCalls()) +func (mock *MockLockManager) WaitForKeyCalls() []struct { + Key string +} { + var calls []struct { + Key string + } + mock.lockWaitForKey.RLock() + calls = mock.calls.WaitForKey + mock.lockWaitForKey.RUnlock() + return calls +} + +// WaitForKeyWithRetry calls WaitForKeyWithRetryFunc. +func (mock *MockLockManager) WaitForKeyWithRetry(ctx context.Context, key string, ticker *time.Ticker) error { + if mock.WaitForKeyWithRetryFunc == nil { + panic("MockLockManager.WaitForKeyWithRetryFunc: method is nil but LockManager.WaitForKeyWithRetry was just called") + } + callInfo := struct { + Ctx context.Context + Key string + Ticker *time.Ticker + }{ + Ctx: ctx, + Key: key, + Ticker: ticker, + } + mock.lockWaitForKeyWithRetry.Lock() + mock.calls.WaitForKeyWithRetry = append(mock.calls.WaitForKeyWithRetry, callInfo) + mock.lockWaitForKeyWithRetry.Unlock() + return mock.WaitForKeyWithRetryFunc(ctx, key, ticker) +} + +// WaitForKeyWithRetryCalls gets all the calls that were made to WaitForKeyWithRetry. +// Check the length with: +// +// len(mockedLockManager.WaitForKeyWithRetryCalls()) +func (mock *MockLockManager) WaitForKeyWithRetryCalls() []struct { + Ctx context.Context + Key string + Ticker *time.Ticker +} { + var calls []struct { + Ctx context.Context + Key string + Ticker *time.Ticker + } + mock.lockWaitForKeyWithRetry.RLock() + calls = mock.calls.WaitForKeyWithRetry + mock.lockWaitForKeyWithRetry.RUnlock() + return calls +} + +// WaitForKeyWithSubscription calls WaitForKeyWithSubscriptionFunc. +func (mock *MockLockManager) WaitForKeyWithSubscription(ctx context.Context, key string, cacheTTL time.Duration) (lockmanager.WaitHandle, string, error) { + if mock.WaitForKeyWithSubscriptionFunc == nil { + panic("MockLockManager.WaitForKeyWithSubscriptionFunc: method is nil but LockManager.WaitForKeyWithSubscription was just called") + } + callInfo := struct { + Ctx context.Context + Key string + CacheTTL time.Duration + }{ + Ctx: ctx, + Key: key, + CacheTTL: cacheTTL, + } + mock.lockWaitForKeyWithSubscription.Lock() + mock.calls.WaitForKeyWithSubscription = append(mock.calls.WaitForKeyWithSubscription, callInfo) + mock.lockWaitForKeyWithSubscription.Unlock() + return mock.WaitForKeyWithSubscriptionFunc(ctx, key, cacheTTL) +} + +// WaitForKeyWithSubscriptionCalls gets all the calls that were made to WaitForKeyWithSubscription. 
+// Check the length with: +// +// len(mockedLockManager.WaitForKeyWithSubscriptionCalls()) +func (mock *MockLockManager) WaitForKeyWithSubscriptionCalls() []struct { + Ctx context.Context + Key string + CacheTTL time.Duration +} { + var calls []struct { + Ctx context.Context + Key string + CacheTTL time.Duration + } + mock.lockWaitForKeyWithSubscription.RLock() + calls = mock.calls.WaitForKeyWithSubscription + mock.lockWaitForKeyWithSubscription.RUnlock() + return calls +} diff --git a/mocks/lockmanager/mock_Logger.go b/mocks/lockmanager/mock_Logger.go new file mode 100644 index 0000000..91513ff --- /dev/null +++ b/mocks/lockmanager/mock_Logger.go @@ -0,0 +1,133 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mocklockmanager + +import ( + "sync" + + "github.com/dcbickfo/redcache/internal/lockmanager" +) + +// Ensure that MockLogger does implement lockmanager.Logger. +// If this is not the case, regenerate this file with mockery. +var _ lockmanager.Logger = &MockLogger{} + +// MockLogger is a mock implementation of lockmanager.Logger. +// +// func TestSomethingThatUsesLogger(t *testing.T) { +// +// // make and configure a mocked lockmanager.Logger +// mockedLogger := &MockLogger{ +// DebugFunc: func(msg string, args ...any) { +// panic("mock out the Debug method") +// }, +// ErrorFunc: func(msg string, args ...any) { +// panic("mock out the Error method") +// }, +// } +// +// // use mockedLogger in code that requires lockmanager.Logger +// // and then make assertions. +// +// } +type MockLogger struct { + // DebugFunc mocks the Debug method. + DebugFunc func(msg string, args ...any) + + // ErrorFunc mocks the Error method. + ErrorFunc func(msg string, args ...any) + + // calls tracks calls to the methods. + calls struct { + // Debug holds details about calls to the Debug method. + Debug []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + // Error holds details about calls to the Error method. + Error []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + } + lockDebug sync.RWMutex + lockError sync.RWMutex +} + +// Debug calls DebugFunc. +func (mock *MockLogger) Debug(msg string, args ...any) { + if mock.DebugFunc == nil { + panic("MockLogger.DebugFunc: method is nil but Logger.Debug was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockDebug.Lock() + mock.calls.Debug = append(mock.calls.Debug, callInfo) + mock.lockDebug.Unlock() + mock.DebugFunc(msg, args...) +} + +// DebugCalls gets all the calls that were made to Debug. +// Check the length with: +// +// len(mockedLogger.DebugCalls()) +func (mock *MockLogger) DebugCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockDebug.RLock() + calls = mock.calls.Debug + mock.lockDebug.RUnlock() + return calls +} + +// Error calls ErrorFunc. +func (mock *MockLogger) Error(msg string, args ...any) { + if mock.ErrorFunc == nil { + panic("MockLogger.ErrorFunc: method is nil but Logger.Error was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockError.Lock() + mock.calls.Error = append(mock.calls.Error, callInfo) + mock.lockError.Unlock() + mock.ErrorFunc(msg, args...) +} + +// ErrorCalls gets all the calls that were made to Error. 
+// Check the length with: +// +// len(mockedLogger.ErrorCalls()) +func (mock *MockLogger) ErrorCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockError.RLock() + calls = mock.calls.Error + mock.lockError.RUnlock() + return calls +} diff --git a/mocks/lockmanager/mock_WaitHandle.go b/mocks/lockmanager/mock_WaitHandle.go new file mode 100644 index 0000000..bc742d6 --- /dev/null +++ b/mocks/lockmanager/mock_WaitHandle.go @@ -0,0 +1,78 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mocklockmanager + +import ( + "context" + "sync" + + "github.com/dcbickfo/redcache/internal/lockmanager" +) + +// Ensure that MockWaitHandle does implement lockmanager.WaitHandle. +// If this is not the case, regenerate this file with mockery. +var _ lockmanager.WaitHandle = &MockWaitHandle{} + +// MockWaitHandle is a mock implementation of lockmanager.WaitHandle. +// +// func TestSomethingThatUsesWaitHandle(t *testing.T) { +// +// // make and configure a mocked lockmanager.WaitHandle +// mockedWaitHandle := &MockWaitHandle{ +// WaitFunc: func(ctx context.Context) error { +// panic("mock out the Wait method") +// }, +// } +// +// // use mockedWaitHandle in code that requires lockmanager.WaitHandle +// // and then make assertions. +// +// } +type MockWaitHandle struct { + // WaitFunc mocks the Wait method. + WaitFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // Wait holds details about calls to the Wait method. + Wait []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockWait sync.RWMutex +} + +// Wait calls WaitFunc. +func (mock *MockWaitHandle) Wait(ctx context.Context) error { + if mock.WaitFunc == nil { + panic("MockWaitHandle.WaitFunc: method is nil but WaitHandle.Wait was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockWait.Lock() + mock.calls.Wait = append(mock.calls.Wait, callInfo) + mock.lockWait.Unlock() + return mock.WaitFunc(ctx) +} + +// WaitCalls gets all the calls that were made to Wait. +// Check the length with: +// +// len(mockedWaitHandle.WaitCalls()) +func (mock *MockWaitHandle) WaitCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockWait.RLock() + calls = mock.calls.Wait + mock.lockWait.RUnlock() + return calls +} diff --git a/mocks/logger/mock_Logger.go b/mocks/logger/mock_Logger.go new file mode 100644 index 0000000..030a269 --- /dev/null +++ b/mocks/logger/mock_Logger.go @@ -0,0 +1,133 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mocklogger + +import ( + "sync" + + "github.com/dcbickfo/redcache/internal/logger" +) + +// Ensure that MockLogger does implement logger.Logger. +// If this is not the case, regenerate this file with mockery. +var _ logger.Logger = &MockLogger{} + +// MockLogger is a mock implementation of logger.Logger. +// +// func TestSomethingThatUsesLogger(t *testing.T) { +// +// // make and configure a mocked logger.Logger +// mockedLogger := &MockLogger{ +// DebugFunc: func(msg string, args ...any) { +// panic("mock out the Debug method") +// }, +// ErrorFunc: func(msg string, args ...any) { +// panic("mock out the Error method") +// }, +// } +// +// // use mockedLogger in code that requires logger.Logger +// // and then make assertions. +// +// } +type MockLogger struct { + // DebugFunc mocks the Debug method. 
+ DebugFunc func(msg string, args ...any) + + // ErrorFunc mocks the Error method. + ErrorFunc func(msg string, args ...any) + + // calls tracks calls to the methods. + calls struct { + // Debug holds details about calls to the Debug method. + Debug []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + // Error holds details about calls to the Error method. + Error []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + } + lockDebug sync.RWMutex + lockError sync.RWMutex +} + +// Debug calls DebugFunc. +func (mock *MockLogger) Debug(msg string, args ...any) { + if mock.DebugFunc == nil { + panic("MockLogger.DebugFunc: method is nil but Logger.Debug was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockDebug.Lock() + mock.calls.Debug = append(mock.calls.Debug, callInfo) + mock.lockDebug.Unlock() + mock.DebugFunc(msg, args...) +} + +// DebugCalls gets all the calls that were made to Debug. +// Check the length with: +// +// len(mockedLogger.DebugCalls()) +func (mock *MockLogger) DebugCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockDebug.RLock() + calls = mock.calls.Debug + mock.lockDebug.RUnlock() + return calls +} + +// Error calls ErrorFunc. +func (mock *MockLogger) Error(msg string, args ...any) { + if mock.ErrorFunc == nil { + panic("MockLogger.ErrorFunc: method is nil but Logger.Error was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockError.Lock() + mock.calls.Error = append(mock.calls.Error, callInfo) + mock.lockError.Unlock() + mock.ErrorFunc(msg, args...) +} + +// ErrorCalls gets all the calls that were made to Error. +// Check the length with: +// +// len(mockedLogger.ErrorCalls()) +func (mock *MockLogger) ErrorCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockError.RLock() + calls = mock.calls.Error + mock.lockError.RUnlock() + return calls +} diff --git a/mocks/luascript/mock_Executor.go b/mocks/luascript/mock_Executor.go new file mode 100644 index 0000000..65cd2d6 --- /dev/null +++ b/mocks/luascript/mock_Executor.go @@ -0,0 +1,153 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: matryer + +package mockluascript + +import ( + "context" + "sync" + + "github.com/dcbickfo/redcache/internal/luascript" + "github.com/redis/rueidis" +) + +// Ensure that MockExecutor does implement luascript.Executor. +// If this is not the case, regenerate this file with mockery. +var _ luascript.Executor = &MockExecutor{} + +// MockExecutor is a mock implementation of luascript.Executor. +// +// func TestSomethingThatUsesExecutor(t *testing.T) { +// +// // make and configure a mocked luascript.Executor +// mockedExecutor := &MockExecutor{ +// ExecFunc: func(ctx context.Context, client rueidis.Client, keys []string, args []string) rueidis.RedisResult { +// panic("mock out the Exec method") +// }, +// ExecMultiFunc: func(ctx context.Context, client rueidis.Client, statements ...rueidis.LuaExec) []rueidis.RedisResult { +// panic("mock out the ExecMulti method") +// }, +// } +// +// // use mockedExecutor in code that requires luascript.Executor +// // and then make assertions. +// +// } +type MockExecutor struct { + // ExecFunc mocks the Exec method. 
+ ExecFunc func(ctx context.Context, client rueidis.Client, keys []string, args []string) rueidis.RedisResult + + // ExecMultiFunc mocks the ExecMulti method. + ExecMultiFunc func(ctx context.Context, client rueidis.Client, statements ...rueidis.LuaExec) []rueidis.RedisResult + + // calls tracks calls to the methods. + calls struct { + // Exec holds details about calls to the Exec method. + Exec []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Client is the client argument value. + Client rueidis.Client + // Keys is the keys argument value. + Keys []string + // Args is the args argument value. + Args []string + } + // ExecMulti holds details about calls to the ExecMulti method. + ExecMulti []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Client is the client argument value. + Client rueidis.Client + // Statements is the statements argument value. + Statements []rueidis.LuaExec + } + } + lockExec sync.RWMutex + lockExecMulti sync.RWMutex +} + +// Exec calls ExecFunc. +func (mock *MockExecutor) Exec(ctx context.Context, client rueidis.Client, keys []string, args []string) rueidis.RedisResult { + if mock.ExecFunc == nil { + panic("MockExecutor.ExecFunc: method is nil but Executor.Exec was just called") + } + callInfo := struct { + Ctx context.Context + Client rueidis.Client + Keys []string + Args []string + }{ + Ctx: ctx, + Client: client, + Keys: keys, + Args: args, + } + mock.lockExec.Lock() + mock.calls.Exec = append(mock.calls.Exec, callInfo) + mock.lockExec.Unlock() + return mock.ExecFunc(ctx, client, keys, args) +} + +// ExecCalls gets all the calls that were made to Exec. +// Check the length with: +// +// len(mockedExecutor.ExecCalls()) +func (mock *MockExecutor) ExecCalls() []struct { + Ctx context.Context + Client rueidis.Client + Keys []string + Args []string +} { + var calls []struct { + Ctx context.Context + Client rueidis.Client + Keys []string + Args []string + } + mock.lockExec.RLock() + calls = mock.calls.Exec + mock.lockExec.RUnlock() + return calls +} + +// ExecMulti calls ExecMultiFunc. +func (mock *MockExecutor) ExecMulti(ctx context.Context, client rueidis.Client, statements ...rueidis.LuaExec) []rueidis.RedisResult { + if mock.ExecMultiFunc == nil { + panic("MockExecutor.ExecMultiFunc: method is nil but Executor.ExecMulti was just called") + } + callInfo := struct { + Ctx context.Context + Client rueidis.Client + Statements []rueidis.LuaExec + }{ + Ctx: ctx, + Client: client, + Statements: statements, + } + mock.lockExecMulti.Lock() + mock.calls.ExecMulti = append(mock.calls.ExecMulti, callInfo) + mock.lockExecMulti.Unlock() + return mock.ExecMultiFunc(ctx, client, statements...) +} + +// ExecMultiCalls gets all the calls that were made to ExecMulti. +// Check the length with: +// +// len(mockedExecutor.ExecMultiCalls()) +func (mock *MockExecutor) ExecMultiCalls() []struct { + Ctx context.Context + Client rueidis.Client + Statements []rueidis.LuaExec +} { + var calls []struct { + Ctx context.Context + Client rueidis.Client + Statements []rueidis.LuaExec + } + mock.lockExecMulti.RLock() + calls = mock.calls.ExecMulti + mock.lockExecMulti.RUnlock() + return calls +} diff --git a/mocks/writelock/mock_Logger.go b/mocks/writelock/mock_Logger.go new file mode 100644 index 0000000..52f83c5 --- /dev/null +++ b/mocks/writelock/mock_Logger.go @@ -0,0 +1,133 @@ +// Code generated by mockery; DO NOT EDIT. 
+// github.com/vektra/mockery +// template: matryer + +package mockwritelock + +import ( + "sync" + + "github.com/dcbickfo/redcache/internal/writelock" +) + +// Ensure that MockLogger does implement writelock.Logger. +// If this is not the case, regenerate this file with mockery. +var _ writelock.Logger = &MockLogger{} + +// MockLogger is a mock implementation of writelock.Logger. +// +// func TestSomethingThatUsesLogger(t *testing.T) { +// +// // make and configure a mocked writelock.Logger +// mockedLogger := &MockLogger{ +// DebugFunc: func(msg string, args ...any) { +// panic("mock out the Debug method") +// }, +// ErrorFunc: func(msg string, args ...any) { +// panic("mock out the Error method") +// }, +// } +// +// // use mockedLogger in code that requires writelock.Logger +// // and then make assertions. +// +// } +type MockLogger struct { + // DebugFunc mocks the Debug method. + DebugFunc func(msg string, args ...any) + + // ErrorFunc mocks the Error method. + ErrorFunc func(msg string, args ...any) + + // calls tracks calls to the methods. + calls struct { + // Debug holds details about calls to the Debug method. + Debug []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + // Error holds details about calls to the Error method. + Error []struct { + // Msg is the msg argument value. + Msg string + // Args is the args argument value. + Args []any + } + } + lockDebug sync.RWMutex + lockError sync.RWMutex +} + +// Debug calls DebugFunc. +func (mock *MockLogger) Debug(msg string, args ...any) { + if mock.DebugFunc == nil { + panic("MockLogger.DebugFunc: method is nil but Logger.Debug was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockDebug.Lock() + mock.calls.Debug = append(mock.calls.Debug, callInfo) + mock.lockDebug.Unlock() + mock.DebugFunc(msg, args...) +} + +// DebugCalls gets all the calls that were made to Debug. +// Check the length with: +// +// len(mockedLogger.DebugCalls()) +func (mock *MockLogger) DebugCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockDebug.RLock() + calls = mock.calls.Debug + mock.lockDebug.RUnlock() + return calls +} + +// Error calls ErrorFunc. +func (mock *MockLogger) Error(msg string, args ...any) { + if mock.ErrorFunc == nil { + panic("MockLogger.ErrorFunc: method is nil but Logger.Error was just called") + } + callInfo := struct { + Msg string + Args []any + }{ + Msg: msg, + Args: args, + } + mock.lockError.Lock() + mock.calls.Error = append(mock.calls.Error, callInfo) + mock.lockError.Unlock() + mock.ErrorFunc(msg, args...) +} + +// ErrorCalls gets all the calls that were made to Error. +// Check the length with: +// +// len(mockedLogger.ErrorCalls()) +func (mock *MockLogger) ErrorCalls() []struct { + Msg string + Args []any +} { + var calls []struct { + Msg string + Args []any + } + mock.lockError.RLock() + calls = mock.calls.Error + mock.lockError.RUnlock() + return calls +} diff --git a/mocks/writelock/mock_WriteLockManager.go b/mocks/writelock/mock_WriteLockManager.go new file mode 100644 index 0000000..43aff46 --- /dev/null +++ b/mocks/writelock/mock_WriteLockManager.go @@ -0,0 +1,459 @@ +// Code generated by mockery; DO NOT EDIT. 
+// github.com/vektra/mockery +// template: matryer + +package mockwritelock + +import ( + "context" + "sync" + "time" + + "github.com/dcbickfo/redcache/internal/writelock" +) + +// Ensure that MockWriteLockManager does implement writelock.WriteLockManager. +// If this is not the case, regenerate this file with mockery. +var _ writelock.WriteLockManager = &MockWriteLockManager{} + +// MockWriteLockManager is a mock implementation of writelock.WriteLockManager. +// +// func TestSomethingThatUsesWriteLockManager(t *testing.T) { +// +// // make and configure a mocked writelock.WriteLockManager +// mockedWriteLockManager := &MockWriteLockManager{ +// AcquireMultiWriteLocksFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]string, []string, error) { +// panic("mock out the AcquireMultiWriteLocks method") +// }, +// AcquireMultiWriteLocksSequentialFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { +// panic("mock out the AcquireMultiWriteLocksSequential method") +// }, +// AcquireWriteLockFunc: func(ctx context.Context, key string) (string, error) { +// panic("mock out the AcquireWriteLock method") +// }, +// CommitWriteLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (map[string]string, map[string]error) { +// panic("mock out the CommitWriteLocks method") +// }, +// ReleaseWriteLockFunc: func(ctx context.Context, key string, lockValue string) error { +// panic("mock out the ReleaseWriteLock method") +// }, +// ReleaseWriteLocksFunc: func(ctx context.Context, lockValues map[string]string) { +// panic("mock out the ReleaseWriteLocks method") +// }, +// RestoreValuesFunc: func(ctx context.Context, savedValues map[string]string, lockValues map[string]string) { +// panic("mock out the RestoreValues method") +// }, +// TouchLocksFunc: func(ctx context.Context, lockValues map[string]string) { +// panic("mock out the TouchLocks method") +// }, +// } +// +// // use mockedWriteLockManager in code that requires writelock.WriteLockManager +// // and then make assertions. +// +// } +type MockWriteLockManager struct { + // AcquireMultiWriteLocksFunc mocks the AcquireMultiWriteLocks method. + AcquireMultiWriteLocksFunc func(ctx context.Context, keys []string) (map[string]string, map[string]string, []string, error) + + // AcquireMultiWriteLocksSequentialFunc mocks the AcquireMultiWriteLocksSequential method. + AcquireMultiWriteLocksSequentialFunc func(ctx context.Context, keys []string) (map[string]string, map[string]string, error) + + // AcquireWriteLockFunc mocks the AcquireWriteLock method. + AcquireWriteLockFunc func(ctx context.Context, key string) (string, error) + + // CommitWriteLocksFunc mocks the CommitWriteLocks method. + CommitWriteLocksFunc func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (map[string]string, map[string]error) + + // ReleaseWriteLockFunc mocks the ReleaseWriteLock method. + ReleaseWriteLockFunc func(ctx context.Context, key string, lockValue string) error + + // ReleaseWriteLocksFunc mocks the ReleaseWriteLocks method. + ReleaseWriteLocksFunc func(ctx context.Context, lockValues map[string]string) + + // RestoreValuesFunc mocks the RestoreValues method. + RestoreValuesFunc func(ctx context.Context, savedValues map[string]string, lockValues map[string]string) + + // TouchLocksFunc mocks the TouchLocks method. 
+ TouchLocksFunc func(ctx context.Context, lockValues map[string]string) + + // calls tracks calls to the methods. + calls struct { + // AcquireMultiWriteLocks holds details about calls to the AcquireMultiWriteLocks method. + AcquireMultiWriteLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. + Keys []string + } + // AcquireMultiWriteLocksSequential holds details about calls to the AcquireMultiWriteLocksSequential method. + AcquireMultiWriteLocksSequential []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. + Keys []string + } + // AcquireWriteLock holds details about calls to the AcquireWriteLock method. + AcquireWriteLock []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // CommitWriteLocks holds details about calls to the CommitWriteLocks method. + CommitWriteLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // TTL is the ttl argument value. + TTL time.Duration + // LockValues is the lockValues argument value. + LockValues map[string]string + // ActualValues is the actualValues argument value. + ActualValues map[string]string + } + // ReleaseWriteLock holds details about calls to the ReleaseWriteLock method. + ReleaseWriteLock []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // LockValue is the lockValue argument value. + LockValue string + } + // ReleaseWriteLocks holds details about calls to the ReleaseWriteLocks method. + ReleaseWriteLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // LockValues is the lockValues argument value. + LockValues map[string]string + } + // RestoreValues holds details about calls to the RestoreValues method. + RestoreValues []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // SavedValues is the savedValues argument value. + SavedValues map[string]string + // LockValues is the lockValues argument value. + LockValues map[string]string + } + // TouchLocks holds details about calls to the TouchLocks method. + TouchLocks []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // LockValues is the lockValues argument value. + LockValues map[string]string + } + } + lockAcquireMultiWriteLocks sync.RWMutex + lockAcquireMultiWriteLocksSequential sync.RWMutex + lockAcquireWriteLock sync.RWMutex + lockCommitWriteLocks sync.RWMutex + lockReleaseWriteLock sync.RWMutex + lockReleaseWriteLocks sync.RWMutex + lockRestoreValues sync.RWMutex + lockTouchLocks sync.RWMutex +} + +// AcquireMultiWriteLocks calls AcquireMultiWriteLocksFunc. +func (mock *MockWriteLockManager) AcquireMultiWriteLocks(ctx context.Context, keys []string) (map[string]string, map[string]string, []string, error) { + if mock.AcquireMultiWriteLocksFunc == nil { + panic("MockWriteLockManager.AcquireMultiWriteLocksFunc: method is nil but WriteLockManager.AcquireMultiWriteLocks was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockAcquireMultiWriteLocks.Lock() + mock.calls.AcquireMultiWriteLocks = append(mock.calls.AcquireMultiWriteLocks, callInfo) + mock.lockAcquireMultiWriteLocks.Unlock() + return mock.AcquireMultiWriteLocksFunc(ctx, keys) +} + +// AcquireMultiWriteLocksCalls gets all the calls that were made to AcquireMultiWriteLocks. 
+// Check the length with: +// +// len(mockedWriteLockManager.AcquireMultiWriteLocksCalls()) +func (mock *MockWriteLockManager) AcquireMultiWriteLocksCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockAcquireMultiWriteLocks.RLock() + calls = mock.calls.AcquireMultiWriteLocks + mock.lockAcquireMultiWriteLocks.RUnlock() + return calls +} + +// AcquireMultiWriteLocksSequential calls AcquireMultiWriteLocksSequentialFunc. +func (mock *MockWriteLockManager) AcquireMultiWriteLocksSequential(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { + if mock.AcquireMultiWriteLocksSequentialFunc == nil { + panic("MockWriteLockManager.AcquireMultiWriteLocksSequentialFunc: method is nil but WriteLockManager.AcquireMultiWriteLocksSequential was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockAcquireMultiWriteLocksSequential.Lock() + mock.calls.AcquireMultiWriteLocksSequential = append(mock.calls.AcquireMultiWriteLocksSequential, callInfo) + mock.lockAcquireMultiWriteLocksSequential.Unlock() + return mock.AcquireMultiWriteLocksSequentialFunc(ctx, keys) +} + +// AcquireMultiWriteLocksSequentialCalls gets all the calls that were made to AcquireMultiWriteLocksSequential. +// Check the length with: +// +// len(mockedWriteLockManager.AcquireMultiWriteLocksSequentialCalls()) +func (mock *MockWriteLockManager) AcquireMultiWriteLocksSequentialCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockAcquireMultiWriteLocksSequential.RLock() + calls = mock.calls.AcquireMultiWriteLocksSequential + mock.lockAcquireMultiWriteLocksSequential.RUnlock() + return calls +} + +// AcquireWriteLock calls AcquireWriteLockFunc. +func (mock *MockWriteLockManager) AcquireWriteLock(ctx context.Context, key string) (string, error) { + if mock.AcquireWriteLockFunc == nil { + panic("MockWriteLockManager.AcquireWriteLockFunc: method is nil but WriteLockManager.AcquireWriteLock was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockAcquireWriteLock.Lock() + mock.calls.AcquireWriteLock = append(mock.calls.AcquireWriteLock, callInfo) + mock.lockAcquireWriteLock.Unlock() + return mock.AcquireWriteLockFunc(ctx, key) +} + +// AcquireWriteLockCalls gets all the calls that were made to AcquireWriteLock. +// Check the length with: +// +// len(mockedWriteLockManager.AcquireWriteLockCalls()) +func (mock *MockWriteLockManager) AcquireWriteLockCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockAcquireWriteLock.RLock() + calls = mock.calls.AcquireWriteLock + mock.lockAcquireWriteLock.RUnlock() + return calls +} + +// CommitWriteLocks calls CommitWriteLocksFunc. 
+func (mock *MockWriteLockManager) CommitWriteLocks(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (map[string]string, map[string]error) { + if mock.CommitWriteLocksFunc == nil { + panic("MockWriteLockManager.CommitWriteLocksFunc: method is nil but WriteLockManager.CommitWriteLocks was just called") + } + callInfo := struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string + }{ + Ctx: ctx, + TTL: ttl, + LockValues: lockValues, + ActualValues: actualValues, + } + mock.lockCommitWriteLocks.Lock() + mock.calls.CommitWriteLocks = append(mock.calls.CommitWriteLocks, callInfo) + mock.lockCommitWriteLocks.Unlock() + return mock.CommitWriteLocksFunc(ctx, ttl, lockValues, actualValues) +} + +// CommitWriteLocksCalls gets all the calls that were made to CommitWriteLocks. +// Check the length with: +// +// len(mockedWriteLockManager.CommitWriteLocksCalls()) +func (mock *MockWriteLockManager) CommitWriteLocksCalls() []struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string +} { + var calls []struct { + Ctx context.Context + TTL time.Duration + LockValues map[string]string + ActualValues map[string]string + } + mock.lockCommitWriteLocks.RLock() + calls = mock.calls.CommitWriteLocks + mock.lockCommitWriteLocks.RUnlock() + return calls +} + +// ReleaseWriteLock calls ReleaseWriteLockFunc. +func (mock *MockWriteLockManager) ReleaseWriteLock(ctx context.Context, key string, lockValue string) error { + if mock.ReleaseWriteLockFunc == nil { + panic("MockWriteLockManager.ReleaseWriteLockFunc: method is nil but WriteLockManager.ReleaseWriteLock was just called") + } + callInfo := struct { + Ctx context.Context + Key string + LockValue string + }{ + Ctx: ctx, + Key: key, + LockValue: lockValue, + } + mock.lockReleaseWriteLock.Lock() + mock.calls.ReleaseWriteLock = append(mock.calls.ReleaseWriteLock, callInfo) + mock.lockReleaseWriteLock.Unlock() + return mock.ReleaseWriteLockFunc(ctx, key, lockValue) +} + +// ReleaseWriteLockCalls gets all the calls that were made to ReleaseWriteLock. +// Check the length with: +// +// len(mockedWriteLockManager.ReleaseWriteLockCalls()) +func (mock *MockWriteLockManager) ReleaseWriteLockCalls() []struct { + Ctx context.Context + Key string + LockValue string +} { + var calls []struct { + Ctx context.Context + Key string + LockValue string + } + mock.lockReleaseWriteLock.RLock() + calls = mock.calls.ReleaseWriteLock + mock.lockReleaseWriteLock.RUnlock() + return calls +} + +// ReleaseWriteLocks calls ReleaseWriteLocksFunc. +func (mock *MockWriteLockManager) ReleaseWriteLocks(ctx context.Context, lockValues map[string]string) { + if mock.ReleaseWriteLocksFunc == nil { + panic("MockWriteLockManager.ReleaseWriteLocksFunc: method is nil but WriteLockManager.ReleaseWriteLocks was just called") + } + callInfo := struct { + Ctx context.Context + LockValues map[string]string + }{ + Ctx: ctx, + LockValues: lockValues, + } + mock.lockReleaseWriteLocks.Lock() + mock.calls.ReleaseWriteLocks = append(mock.calls.ReleaseWriteLocks, callInfo) + mock.lockReleaseWriteLocks.Unlock() + mock.ReleaseWriteLocksFunc(ctx, lockValues) +} + +// ReleaseWriteLocksCalls gets all the calls that were made to ReleaseWriteLocks. 
+// Check the length with: +// +// len(mockedWriteLockManager.ReleaseWriteLocksCalls()) +func (mock *MockWriteLockManager) ReleaseWriteLocksCalls() []struct { + Ctx context.Context + LockValues map[string]string +} { + var calls []struct { + Ctx context.Context + LockValues map[string]string + } + mock.lockReleaseWriteLocks.RLock() + calls = mock.calls.ReleaseWriteLocks + mock.lockReleaseWriteLocks.RUnlock() + return calls +} + +// RestoreValues calls RestoreValuesFunc. +func (mock *MockWriteLockManager) RestoreValues(ctx context.Context, savedValues map[string]string, lockValues map[string]string) { + if mock.RestoreValuesFunc == nil { + panic("MockWriteLockManager.RestoreValuesFunc: method is nil but WriteLockManager.RestoreValues was just called") + } + callInfo := struct { + Ctx context.Context + SavedValues map[string]string + LockValues map[string]string + }{ + Ctx: ctx, + SavedValues: savedValues, + LockValues: lockValues, + } + mock.lockRestoreValues.Lock() + mock.calls.RestoreValues = append(mock.calls.RestoreValues, callInfo) + mock.lockRestoreValues.Unlock() + mock.RestoreValuesFunc(ctx, savedValues, lockValues) +} + +// RestoreValuesCalls gets all the calls that were made to RestoreValues. +// Check the length with: +// +// len(mockedWriteLockManager.RestoreValuesCalls()) +func (mock *MockWriteLockManager) RestoreValuesCalls() []struct { + Ctx context.Context + SavedValues map[string]string + LockValues map[string]string +} { + var calls []struct { + Ctx context.Context + SavedValues map[string]string + LockValues map[string]string + } + mock.lockRestoreValues.RLock() + calls = mock.calls.RestoreValues + mock.lockRestoreValues.RUnlock() + return calls +} + +// TouchLocks calls TouchLocksFunc. +func (mock *MockWriteLockManager) TouchLocks(ctx context.Context, lockValues map[string]string) { + if mock.TouchLocksFunc == nil { + panic("MockWriteLockManager.TouchLocksFunc: method is nil but WriteLockManager.TouchLocks was just called") + } + callInfo := struct { + Ctx context.Context + LockValues map[string]string + }{ + Ctx: ctx, + LockValues: lockValues, + } + mock.lockTouchLocks.Lock() + mock.calls.TouchLocks = append(mock.calls.TouchLocks, callInfo) + mock.lockTouchLocks.Unlock() + mock.TouchLocksFunc(ctx, lockValues) +} + +// TouchLocksCalls gets all the calls that were made to TouchLocks. +// Check the length with: +// +// len(mockedWriteLockManager.TouchLocksCalls()) +func (mock *MockWriteLockManager) TouchLocksCalls() []struct { + Ctx context.Context + LockValues map[string]string +} { + var calls []struct { + Ctx context.Context + LockValues map[string]string + } + mock.lockTouchLocks.RLock() + calls = mock.calls.TouchLocks + mock.lockTouchLocks.RUnlock() + return calls +} diff --git a/primeable_cacheaside.go b/primeable_cacheaside.go index 0fc8dc7..88197e7 100644 --- a/primeable_cacheaside.go +++ b/primeable_cacheaside.go @@ -2,165 +2,15 @@ package redcache import ( "context" + "errors" "fmt" - "sort" - "strconv" "time" "github.com/redis/rueidis" - "github.com/dcbickfo/redcache/internal/cmdx" - "github.com/dcbickfo/redcache/internal/lockutil" + "github.com/dcbickfo/redcache/internal/lockmanager" "github.com/dcbickfo/redcache/internal/mapsx" - "github.com/dcbickfo/redcache/internal/syncx" -) - -const ( - // lockRetryInterval is the interval for periodic lock acquisition retries. - // Used when waiting for locks to be released. - lockRetryInterval = 50 * time.Millisecond -) - -// Pre-compiled Lua scripts for atomic operations. 
-var ( - unlockKeyScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local expected = ARGV[1] - if redis.call("GET", key) == expected then - return redis.call("DEL", key) - else - return 0 - end - `) - - // acquireWriteLockScript atomically acquires a write lock for Set operations. - // - // IMPORTANT: We CANNOT use SET NX here like CacheAside does for Get operations. - // - // Why SET NX doesn't work for Set operations: - // - SET NX only succeeds if the key doesn't exist at all - // - Set operations need to overwrite existing real values (not locks) - // - If we used SET NX, Set would fail when trying to update an existing cached value - // - // This Lua script provides the correct behavior: - // - Acquires lock if key is empty (like SET NX) - // - Acquires lock if key contains a real value (overwrites to prepare for Set) - // - REFUSES to acquire if there's an active lock from a Get operation - // - // This prevents the race condition where Set would overwrite Get's lock, - // causing Get to lose its lock and retry, ultimately seeing Set's value - // instead of its own callback result. - // - // Returns 0 if there's an existing lock (cannot acquire). - // Returns 1 if lock was successfully acquired. - acquireWriteLockScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local lock_value = ARGV[1] - local ttl = ARGV[2] - local lock_prefix = ARGV[3] - - local current = redis.call("GET", key) - - -- If key is empty, we can set our lock - if current == false then - redis.call("SET", key, lock_value, "PX", ttl) - return 1 - end - - -- If current value is a lock (has lock prefix), we cannot acquire - if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then - return 0 - end - - -- Current value is a real value (not a lock), we can overwrite with our lock - redis.call("SET", key, lock_value, "PX", ttl) - return 1 - `) - - // acquireWriteLockWithBackupScript acquires a lock and returns the previous value. - // This is used for sequential lock acquisition where we need to restore values - // if we can't acquire all locks. - // - // Returns: [success (0 or 1), previous_value or false] - // - [0, current] if lock exists (cannot acquire) - // - [1, false] if key was empty (acquired from nothing) - // - [1, current] if key had real value (acquired, can restore) - acquireWriteLockWithBackupScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local lock_value = ARGV[1] - local ttl = ARGV[2] - local lock_prefix = ARGV[3] - - local current = redis.call("GET", key) - - -- If key is empty, acquire and return false (nothing to restore) - if current == false then - redis.call("SET", key, lock_value, "PX", ttl) - return {1, false} - end - - -- If current value is a lock, cannot acquire - if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then - return {0, current} - end - - -- Current value is real data - save it before overwriting - redis.call("SET", key, lock_value, "PX", ttl) - return {1, current} - `) - - // restoreValueOrDeleteScript restores a saved value or deletes the key. - // Used when releasing locks during sequential acquisition rollback. 
- // - // ARGV[2] can be: - // - false/nil: delete the key (was empty before) - // - string: restore the original value - restoreValueOrDeleteScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local expected_lock = ARGV[1] - local restore_value = ARGV[2] - - -- Only restore if we still hold our lock - if redis.call("GET", key) == expected_lock then - if restore_value and restore_value ~= "" then - -- Restore original value - redis.call("SET", key, restore_value) - else - -- Was empty before, delete - redis.call("DEL", key) - end - return 1 - else - -- Someone else has the key now, don't touch it - return 0 - end - `) - - setWithLockScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local value = ARGV[1] - local ttl = ARGV[2] - local expected_lock = ARGV[3] - - local current = redis.call("GET", key) - - -- STRICT CAS: We can ONLY set if we still hold our exact lock value - -- This prevents Set from overwriting ForceSet values that stole our lock - -- - -- If current is nil (false in Lua), our lock expired or was deleted - -- If current is different, either: - -- - Another Set operation acquired a different lock - -- - A ForceSet operation overwrote our lock with a real value - -- - Our lock naturally expired - -- - -- In all cases where we don't hold our exact lock, we return 0 (failure) - if current == expected_lock then - redis.call("SET", key, value, "PX", ttl) - return 1 - else - return 0 - end - `) + "github.com/dcbickfo/redcache/internal/writelock" ) // PrimeableCacheAside extends CacheAside with explicit Set operations for cache priming @@ -200,61 +50,41 @@ var ( // for detailed analysis and recommendations. type PrimeableCacheAside struct { *CacheAside - lockChecker lockutil.LockChecker // Shared lock checking logic (interface) -} - -// waitForLockRelease waits for a lock to be released via invalidation or timeout. -// This is a common pattern used throughout the code to wait for distributed locks. -// NOTE: The caller should have already subscribed to invalidations using DoCache. -func (pca *PrimeableCacheAside) waitForLockRelease(ctx context.Context, key string) error { - // Don't subscribe again - the caller should have already used DoCache - // This avoids duplicate subscriptions and ensures we get the invalidation - - // Register locally to wait for the invalidation - waitStart := time.Now() - waitChan := pca.register(key) - pca.logger.Debug("waiting for lock release", "key", key, "start", waitStart) - - // Wait for lock release using shared utility - if err := lockutil.WaitForSingleLock(ctx, waitChan, pca.lockTTL); err != nil { - pca.logger.Debug("lock wait failed", "key", key, "duration", time.Since(waitStart), "error", err) - return err - } - - waitDuration := time.Since(waitStart) - pca.logger.Debug("lock released", "key", key, "duration", waitDuration) - - // If we waited for nearly the full lockTTL, it means we timed out rather than got an invalidation - if waitDuration > pca.lockTTL-100*time.Millisecond { - pca.logger.Error("lock release likely timed out rather than received invalidation", - "key", key, "duration", waitDuration, "lockTTL", pca.lockTTL) - } - - return nil + writeLockManager writelock.WriteLockManager // Write lock coordination (interface) } // waitForReadLock waits for any active read lock on the key to be released. +// Uses lockManager to handle registration and subscription in the correct order. 
func (pca *PrimeableCacheAside) waitForReadLock(ctx context.Context, key string) error { startTime := time.Now() - // Check if there's a read lock on this key using DoCache - // This ensures we're subscribed to invalidations if a lock exists - resp := pca.client.DoCache(ctx, pca.client.B().Get().Key(key).Cache(), time.Second) - val, err := resp.ToString() + // Register for invalidations and subscribe to Redis client-side cache in correct order + // This prevents the race condition between registration and subscription + waitHandle, val, err := pca.lockManager.WaitForKeyWithSubscription(ctx, key, time.Second) pca.logger.Debug("waitForReadLock check", "key", key, "hasValue", err == nil, "value", val, - "isLock", err == nil && pca.lockChecker.HasLock(val)) + "isLock", err == nil && pca.lockManager.IsLockValue(val)) - if err == nil && pca.lockChecker.HasLock(val) { + // Only wait if there's actually a lock + if err == nil && pca.lockManager.IsLockValue(val) { pca.logger.Debug("read lock exists, waiting for it to complete", "key", key, "lockValue", val) - // Since we used DoCache, we're guaranteed to get an invalidation when the lock is released - if waitErr := pca.waitForLockRelease(ctx, key); waitErr != nil { + + // Wait for lock release (already registered and subscribed above) + if waitErr := waitHandle.Wait(ctx); waitErr != nil { + pca.logger.Debug("lock wait failed", "key", key, "duration", time.Since(startTime), "error", waitErr) return waitErr } - pca.logger.Debug("read lock cleared", "key", key, "duration", time.Since(startTime)) + + waitDuration := time.Since(startTime) + pca.logger.Debug("read lock cleared", "key", key, "duration", waitDuration) + + // Log if timeout likely occurred + if waitDuration > pca.lockTTL-100*time.Millisecond { + pca.logger.Error("lock release likely timed out", "key", key, "duration", waitDuration, "lockTTL", pca.lockTTL) + } } return nil @@ -262,78 +92,59 @@ func (pca *PrimeableCacheAside) waitForReadLock(ctx context.Context, key string) // trySetKeyFuncForWrite performs coordinated cache update operation with distributed locking. // Cache locks provide both write-write coordination and CAS protection. +// Uses WriteLockManager which can overwrite existing values (but not other locks). func (pca *PrimeableCacheAside) trySetKeyFuncForWrite(ctx context.Context, ttl time.Duration, key string, fn func(ctx context.Context, key string) (string, error)) (val string, err error) { // Wait for any existing read locks to complete if waitErr := pca.waitForReadLock(ctx, key); waitErr != nil { return "", waitErr } - // Acquire cache lock using acquireWriteLockScript - // The Lua script provides write-write coordination by checking for existing locks - // and refusing to acquire if another Set operation holds the lock. - // - // IMPORTANT: Do NOT use SET NX here (see acquireWriteLockScript comment for details). - // The script ensures Set can overwrite real values but won't overwrite active locks. 
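The registration-before-subscription ordering noted in `waitForReadLock` above is the crux of the race fix: if the waiter is registered only after the cached GET, an invalidation delivered in between is silently lost. A minimal sketch of the pattern in isolation, with hypothetical helper signatures standing in for the lockmanager API:

```go
package example

import "context"

// waitIfLocked sketches the register-then-subscribe-then-wait ordering.
// register, fetchWithCSC and isLock are placeholders, not the real API.
func waitIfLocked(
	ctx context.Context,
	key string,
	register func(key string) <-chan struct{}, // arm the invalidation waiter first
	fetchWithCSC func(ctx context.Context, key string) (string, error), // GET via client-side cache (subscribes)
	isLock func(val string) bool,
) error {
	released := register(key) // 1. register before inspecting the key
	val, err := fetchWithCSC(ctx, key)
	if err != nil || !isLock(val) {
		return err // no active lock, nothing to wait for
	}
	select { // 2. an invalidation arriving after step 1 cannot be missed
	case <-released:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
```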
- lockVal := pca.generateLockValue() - ticker := time.NewTicker(lockRetryInterval) - defer ticker.Stop() - + // Acquire write lock with retry logic (WriteLockManager can overwrite values, but not locks) + var lockVal string for { - result := acquireWriteLockScript.Exec(ctx, pca.client, - []string{key}, - []string{lockVal, strconv.FormatInt(pca.lockTTL.Milliseconds(), 10), pca.lockPrefix}) - - success, execErr := result.AsInt64() - if execErr != nil { - return "", fmt.Errorf("failed to acquire cache lock for key %q: %w", key, execErr) + lockVal, err = pca.writeLockManager.AcquireWriteLock(ctx, key) + if err == nil { + break // Successfully acquired lock } - if success == 1 { - // Successfully acquired the lock - break + // Check if error is lock contention (retryable) or a real error + if !errors.Is(err, ErrLockFailed) { + return "", err } - // There's an active lock (from Get or another Set) - // Wait for it to be released via invalidation - waitChan := pca.register(key) - _ = pca.client.DoCache(ctx, pca.client.B().Get().Key(key).Cache(), pca.lockTTL) - - select { - case <-ctx.Done(): - return "", ctx.Err() - case <-waitChan: - // Lock was released, retry acquisition - continue - case <-ticker.C: - // Periodic retry - continue + // Lock contention - wait for the lock to be released + pca.logger.Debug("write lock contention, waiting for release", "key", key) + // Register for invalidations and subscribe in correct order + waitHandle, _, _ := pca.lockManager.WaitForKeyWithSubscription(ctx, key, time.Second) + if waitErr := waitHandle.Wait(ctx); waitErr != nil { + return "", waitErr } + // Retry after wait } - pca.logger.Debug("acquired cache key lock", "key", key) + pca.logger.Debug("acquired cache key lock (blocking)", "key", key) // We have the lock, now execute the callback val, err = fn(ctx, key) if err != nil { // Release the cache lock on error - pca.unlockKey(ctx, key, lockVal) + _ = pca.writeLockManager.ReleaseWriteLock(ctx, key, lockVal) return "", err } - // Set the value in Redis using a Lua script that verifies we still hold the lock + // Set the value in Redis using a CAS to verify we still hold the lock // This uses strict compare-and-swap (CAS): only sets if we hold our exact lock value // This prevents Set from overwriting ForceSet values that may have stolen our lock - setResult := setWithLockScript.Exec(ctx, pca.client, - []string{key}, - []string{val, strconv.FormatInt(ttl.Milliseconds(), 10), lockVal}) + succeeded, failed := pca.writeLockManager.CommitWriteLocks(ctx, ttl, + map[string]string{key: lockVal}, + map[string]string{key: val}) - setSuccess, err := setResult.AsInt64() - if err != nil { - return "", fmt.Errorf("failed to set value for key %q: %w", key, err) + if len(failed) > 0 { + return "", failed[key] } - if setSuccess == 0 { - return "", fmt.Errorf("%w for key %q", ErrLockLost, key) + if len(succeeded) == 0 { + return "", fmt.Errorf("failed to set value for key %q", key) } // Note: No DoCache needed here. Redis automatically sends invalidation messages to all @@ -345,561 +156,6 @@ func (pca *PrimeableCacheAside) trySetKeyFuncForWrite(ctx context.Context, ttl t return val, nil } -// unlockKey releases a cache lock if it matches the expected value. -func (pca *PrimeableCacheAside) unlockKey(ctx context.Context, key string, lockVal string) { - _ = unlockKeyScript.Exec(ctx, pca.client, []string{key}, []string{lockVal}).Error() -} - -// lockAcquisitionResult holds the result of processing a single lock acquisition response. 
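For orientation, the rewritten write path above is driven from the caller side roughly as follows. The `Set` signature matches how the tests later in this diff call it; `loadUser` and the import path are illustrative assumptions.

```go
package example

import (
	"context"
	"time"

	"github.com/dcbickfo/redcache"
)

// loadUser stands in for fetching the authoritative value from the source of truth.
func loadUser(ctx context.Context, key string) (string, error) {
	return "payload-for-" + key, nil
}

// primeUser shows the single-key Set flow: the callback runs only while the
// write lock is held, and its result is committed with CAS against that lock,
// so a lost lock surfaces as an error instead of a silent overwrite.
func primeUser(ctx context.Context, pca *redcache.PrimeableCacheAside, key string) error {
	return pca.Set(ctx, time.Minute, key, func(ctx context.Context, key string) (string, error) {
		return loadUser(ctx, key)
	})
}
```

The removed multi-key helpers that follow handled the same locking, restore, and CAS concerns for SetMulti before that logic moved behind the writelock.WriteLockManager interface.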
-type lockAcquisitionResult struct { - acquired bool - lockValue string - savedValue string - hasSaved bool -} - -// processLockAcquisitionResponse processes a single lock acquisition response. -// Returns the result or an error if the response is invalid. -func processLockAcquisitionResponse(resp rueidis.RedisResult, key, lockVal string) (lockAcquisitionResult, error) { - // Response is [success, previous_value] - result, err := resp.ToArray() - if err != nil { - return lockAcquisitionResult{}, fmt.Errorf("failed to acquire cache lock for key %q: %w", key, err) - } - - if len(result) != 2 { - return lockAcquisitionResult{}, fmt.Errorf("unexpected response length for key %q: got %d, expected 2", key, len(result)) - } - - success, err := result[0].AsInt64() - if err != nil { - return lockAcquisitionResult{}, fmt.Errorf("failed to parse success for key %q: %w", key, err) - } - - if success == 1 { - // Acquired successfully - check if there's a previous value to save - prevValue, prevErr := result[1].ToString() - if prevErr == nil && prevValue != "" { - // Previous value exists and is not empty - return lockAcquisitionResult{ - acquired: true, - lockValue: lockVal, - savedValue: prevValue, - hasSaved: true, - }, nil - } - // No previous value (key was empty) - return lockAcquisitionResult{ - acquired: true, - lockValue: lockVal, - }, nil - } - - // Failed to acquire - return lockAcquisitionResult{acquired: false}, nil -} - -// tryAcquireMultiCacheLocksBatched attempts to acquire cache locks for multiple keys in batches. -// Uses acquireWriteLockWithBackupScript which saves the previous value before acquiring the lock. -// -// Returns: -// - acquired: map of keys to their lock values (successfully acquired) -// - savedValues: map of keys to their previous values (to restore on rollback) -// - failed: list of keys that failed to acquire -// - error: if a critical error occurred -func (pca *PrimeableCacheAside) tryAcquireMultiCacheLocksBatched( - ctx context.Context, - keys []string, -) (acquired map[string]string, savedValues map[string]string, failed []string, err error) { - if len(keys) == 0 { - return make(map[string]string), make(map[string]string), nil, nil - } - - acquired = make(map[string]string) - savedValues = make(map[string]string) - failed = make([]string, 0) - - // Group by slot and build Lua exec statements - stmtsBySlot := pca.groupLockAcquisitionsBySlot(keys) - - // Execute all slots and collect results - for _, stmts := range stmtsBySlot { - // Batch execute all lock acquisitions for this slot - resps := acquireWriteLockWithBackupScript.ExecMulti(ctx, pca.client, stmts.execStmts...) - - // Process responses in order - for i, resp := range resps { - key := stmts.keyOrder[i] - lockVal := stmts.lockVals[i] - - result, respErr := processLockAcquisitionResponse(resp, key, lockVal) - if respErr != nil { - return nil, nil, nil, respErr - } - - if result.acquired { - acquired[key] = result.lockValue - if result.hasSaved { - savedValues[key] = result.savedValue - } - } else { - failed = append(failed, key) - } - } - } - - return acquired, savedValues, failed, nil -} - -// slotLockStatements holds lock acquisition statements grouped by slot. -type slotLockStatements struct { - keyOrder []string - lockVals []string - execStmts []rueidis.LuaExec -} - -// estimateSlotDistribution estimates the number of Redis cluster slots and keys per slot -// for efficient pre-allocation when grouping operations by slot. -// Uses a heuristic of ~8 slots for typical Redis Cluster distributions. 
-func estimateSlotDistribution(itemCount int) (estimatedSlots, estimatedPerSlot int) { - estimatedSlots = itemCount / 8 - if estimatedSlots < 1 { - estimatedSlots = 1 - } - estimatedPerSlot = (itemCount / estimatedSlots) + 1 - return -} - -// groupLockAcquisitionsBySlot groups lock acquisition operations by Redis cluster slot. -// This is necessary for Lua scripts which must execute on a single node in Redis Cluster. -// Unlike regular commands, Lua scripts (LuaExec) require manual slot grouping. -func (pca *PrimeableCacheAside) groupLockAcquisitionsBySlot(keys []string) map[uint16]slotLockStatements { - if len(keys) == 0 { - return nil - } - - // Pre-allocate with estimated capacity - estimatedSlots, estimatedPerSlot := estimateSlotDistribution(len(keys)) - stmtsBySlot := make(map[uint16]slotLockStatements, estimatedSlots) - - // Pre-calculate lock TTL string once - lockTTLStr := strconv.FormatInt(pca.lockTTL.Milliseconds(), 10) - - for _, key := range keys { - lockVal := pca.generateLockValue() - slot := cmdx.Slot(key) - stmts := stmtsBySlot[slot] - - // Pre-allocate slices on first access to this slot - if stmts.keyOrder == nil { - stmts.keyOrder = make([]string, 0, estimatedPerSlot) - stmts.lockVals = make([]string, 0, estimatedPerSlot) - stmts.execStmts = make([]rueidis.LuaExec, 0, estimatedPerSlot) - } - - stmts.keyOrder = append(stmts.keyOrder, key) - stmts.lockVals = append(stmts.lockVals, lockVal) - stmts.execStmts = append(stmts.execStmts, rueidis.LuaExec{ - Keys: []string{key}, - Args: []string{lockVal, lockTTLStr, pca.lockPrefix}, - }) - stmtsBySlot[slot] = stmts - } - - return stmtsBySlot -} - -// releaseMultiCacheLocks releases multiple cache locks. -func (pca *PrimeableCacheAside) releaseMultiCacheLocks(ctx context.Context, lockValues map[string]string) { - for key, lockVal := range lockValues { - pca.unlockKey(ctx, key, lockVal) - } -} - -// acquireMultiCacheLocks acquires cache locks for multiple keys using sequential acquisition -// with value preservation to prevent deadlocks and cache misses. -// -// Strategy: -// 1. Sort keys for consistent ordering (prevents deadlocks) -// 2. Try batch-acquire all remaining keys -// 3. On partial failure: -// a. Restore original values for acquired keys (using saved values from backup script) -// b. Wait for FIRST failed key in sorted order (sequential, not any) -// c. Acquire that specific key and continue -// 4. Repeat until all locks acquired -// -// This approach prevents deadlocks by ensuring clients always wait for keys in the same order, -// and prevents cache misses by restoring original values when releasing locks. -// -// Returns a map of keys to their lock values, or an error if acquisition fails. 
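The sequential, sorted-order strategy documented above is what rules out circular waits between clients. The same ordering argument in miniature, using in-process mutexes rather than Redis (illustration only, not the library's implementation):

```go
package example

import (
	"sort"
	"sync"
)

// lockAllSorted locks keys in sorted order and returns an unlock function.
// Two callers asking for {a, b} and {b, a} both contend on "a" first, so
// neither can hold one key while waiting on the other: no circular wait.
func lockAllSorted(locks map[string]*sync.Mutex, keys []string) (unlock func()) {
	sorted := append([]string(nil), keys...)
	sort.Strings(sorted)
	for _, k := range sorted {
		locks[k].Lock()
	}
	return func() {
		for i := len(sorted) - 1; i >= 0; i-- {
			locks[sorted[i]].Unlock()
		}
	}
}
```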
-func (pca *PrimeableCacheAside) acquireMultiCacheLocks(ctx context.Context, keys []string) (map[string]string, error) { - if len(keys) == 0 { - return make(map[string]string), nil - } - - // Sort keys for consistent lock ordering (prevents deadlocks) - sortedKeys := pca.sortKeys(keys) - lockValues := make(map[string]string) - ticker := time.NewTicker(lockRetryInterval) - defer ticker.Stop() - - remainingKeys := sortedKeys - for len(remainingKeys) > 0 { - done, err := pca.tryAcquireBatchAndProcess(ctx, sortedKeys, remainingKeys, lockValues, ticker) - if err != nil { - return nil, err - } - if done { - return lockValues, nil - } - // Update remaining keys for next iteration - remainingKeys = pca.keysNotIn(sortedKeys, lockValues) - } - - return lockValues, nil -} - -// tryAcquireBatchAndProcess attempts to acquire locks for remaining keys and processes the result. -// Returns true if all locks were acquired (done), or an error if acquisition failed. -func (pca *PrimeableCacheAside) tryAcquireBatchAndProcess( - ctx context.Context, - sortedKeys []string, - remainingKeys []string, - lockValues map[string]string, - ticker *time.Ticker, -) (bool, error) { - // Try to batch-acquire all remaining keys with backup - acquired, savedValues, failed, err := pca.tryAcquireMultiCacheLocksBatched(ctx, remainingKeys) - if err != nil { - // Critical error - release all locks we've acquired so far - pca.releaseMultiCacheLocks(ctx, lockValues) - return false, err - } - - // Success: All remaining keys acquired - if len(failed) == 0 { - // Add newly acquired locks to our collection - for k, v := range acquired { - lockValues[k] = v - } - pca.logger.Debug("acquireMultiCacheLocks completed", "keys", sortedKeys, "count", len(lockValues)) - return true, nil - } - - // Handle partial failure - return false, pca.handlePartialLockFailure(ctx, remainingKeys, acquired, savedValues, failed, lockValues, ticker) -} - -// handlePartialLockFailure processes a partial lock acquisition failure. -// It keeps locks acquired in sequential order and restores out-of-order locks. 
-func (pca *PrimeableCacheAside) handlePartialLockFailure( - ctx context.Context, - remainingKeys []string, - acquired map[string]string, - savedValues map[string]string, - failed []string, - lockValues map[string]string, - ticker *time.Ticker, -) error { - pca.logger.Debug("partial acquisition, analyzing sequential locks", - "acquired_this_batch", len(acquired), - "failed", len(failed), - "total_acquired_so_far", len(lockValues)) - - // Find the first failed key in sorted order - firstFailedKey := pca.findFirstKey(remainingKeys, failed) - - // Determine which acquired keys to keep vs restore - toKeep, toRestore := pca.splitAcquiredBySequence(remainingKeys, acquired, firstFailedKey) - - // Restore keys that were acquired out of order - if len(toRestore) > 0 { - pca.logger.Debug("restoring out-of-order locks", - "restore_count", len(toRestore), - "first_failed", firstFailedKey) - pca.restoreMultiValues(ctx, toRestore, savedValues) - } - - // Add sequential locks to our permanent collection - for k, v := range toKeep { - lockValues[k] = v - } - - // Touch/refresh TTL on all locks we're keeping to prevent expiration - if len(lockValues) > 0 { - pca.touchMultiLocks(ctx, lockValues) - } - - // Wait for the first failed key to be released - err := pca.waitForSingleLock(ctx, firstFailedKey, ticker) - if err != nil { - // Context cancelled or timeout - release all locks - pca.releaseMultiCacheLocks(ctx, lockValues) - return err - } - - return nil -} - -// sortKeys creates a sorted copy of the keys to ensure consistent lock ordering. -func (pca *PrimeableCacheAside) sortKeys(keys []string) []string { - sorted := make([]string, len(keys)) - copy(sorted, keys) - sort.Strings(sorted) - return sorted -} - -// restoreMultiValues restores original values for keys that were acquired. -// Uses the restoreValueOrDeleteScript to atomically restore values or delete keys. -// This is critical for preventing cache misses when releasing locks on partial failure. -func (pca *PrimeableCacheAside) restoreMultiValues( - ctx context.Context, - lockValues map[string]string, - savedValues map[string]string, -) { - if len(lockValues) == 0 { - return - } - - for key, lockVal := range lockValues { - // Get the saved value (empty string if key didn't exist before) - savedVal := savedValues[key] // Empty string if not in map - - // Use Lua script to restore value or delete key atomically - _ = restoreValueOrDeleteScript.Exec(ctx, pca.client, - []string{key}, - []string{lockVal, savedVal}, - ).Error() - } -} - -// findFirstKey finds the first key from sortedKeys that appears in targetKeys. -// This ensures sequential lock acquisition by always waiting for the first failed key. -func (pca *PrimeableCacheAside) findFirstKey(sortedKeys []string, targetKeys []string) string { - // Convert targetKeys to a set for O(1) lookup - targetSet := make(map[string]bool, len(targetKeys)) - for _, k := range targetKeys { - targetSet[k] = true - } - - // Find first key in sorted order - for _, k := range sortedKeys { - if targetSet[k] { - return k - } - } - - // Should never happen if targetKeys is non-empty and derived from sortedKeys - return targetKeys[0] -} - -// splitAcquiredBySequence splits acquired keys into those that should be kept (sequential) -// vs those that should be restored (out of order after a gap). -// -// Example: sortedKeys=[key1,key2,key3], acquired=[key1,key3], firstFailedKey=key2. -// Result: keep=[key1], restore=[key3]. 
-func (pca *PrimeableCacheAside) splitAcquiredBySequence( - sortedKeys []string, - acquired map[string]string, - firstFailedKey string, -) (keep map[string]string, restore map[string]string) { - keep = make(map[string]string) - restore = make(map[string]string) - - foundFailedKey := false - for _, key := range sortedKeys { - if key == firstFailedKey { - foundFailedKey = true - continue - } - - lockVal, wasAcquired := acquired[key] - if !wasAcquired { - continue - } - - if foundFailedKey { - // This key comes after the failed key - out of order, must restore - restore[key] = lockVal - } else { - // This key comes before the failed key - sequential, can keep - keep[key] = lockVal - } - } - - return keep, restore -} - -// touchMultiLocks refreshes the TTL on multiple locks to prevent expiration. -// This is critical when waiting for locks, as we need to maintain our holds. -func (pca *PrimeableCacheAside) touchMultiLocks(ctx context.Context, lockValues map[string]string) { - if len(lockValues) == 0 { - return - } - - // Build SET commands to refresh TTL on each lock - cmds := make([]rueidis.Completed, 0, len(lockValues)) - for key, lockVal := range lockValues { - cmds = append(cmds, pca.client.B().Set(). - Key(key). - Value(lockVal). - Px(pca.lockTTL). - Build()) - } - - // Execute all SET commands in parallel - _ = pca.client.DoMulti(ctx, cmds...) -} - -// keysNotIn returns keys from sortedKeys that are not in the acquired map. -func (pca *PrimeableCacheAside) keysNotIn(sortedKeys []string, acquired map[string]string) []string { - remaining := make([]string, 0, len(sortedKeys)) - for _, key := range sortedKeys { - if _, ok := acquired[key]; !ok { - remaining = append(remaining, key) - } - } - return remaining -} - -// waitForSingleLock waits for a specific key's lock to be released. -// Fetches the key using DoCache to register for invalidation notifications, -// then waits for the invalidation event or ticker timeout. -func (pca *PrimeableCacheAside) waitForSingleLock(ctx context.Context, key string, ticker *time.Ticker) error { - // Fetch key using DoCache to register for invalidation notifications - // This ensures we get notified when the lock is released - waitChan := pca.register(key) - _ = pca.client.DoCache(ctx, pca.client.B().Get().Key(key).Cache(), pca.lockTTL) - - select { - case <-waitChan: - // Lock was released (invalidation event) - return nil - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - // Periodic retry timeout - return nil - } -} - -// setMultiValuesWithCAS sets multiple values using CAS to verify we still hold the locks. -// Returns maps of succeeded and failed keys. -// Uses batched Lua script execution grouped by Redis cluster slot for optimal performance. -func (pca *PrimeableCacheAside) setMultiValuesWithCAS( - ctx context.Context, - ttl time.Duration, - values map[string]string, - lockValues map[string]string, -) (map[string]string, map[string]error) { - if len(values) == 0 { - return make(map[string]string), make(map[string]error) - } - - succeeded := make(map[string]string) - failed := make(map[string]error) - - // Group by slot for efficient batching - stmtsBySlot := pca.groupSetValuesBySlot(values, lockValues, ttl) - - // Execute all slots in parallel and collect results - for slot, stmts := range stmtsBySlot { - // Execute all Lua scripts for this slot in a single batch - setResps := setWithLockScript.ExecMulti(ctx, pca.client, stmts.execStmts...) 
- - // Process responses in order - for i, resp := range setResps { - key := stmts.keyOrder[i] - value := values[key] - - setSuccess, err := resp.AsInt64() - if err != nil { - pca.logger.Debug("set CAS failed for key", "key", key, "slot", slot, "error", err) - failed[key] = fmt.Errorf("failed to set value: %w", err) - continue - } - - if setSuccess == 0 { - pca.logger.Debug("set CAS lock lost for key", "key", key, "slot", slot) - failed[key] = fmt.Errorf("%w", ErrLockLost) - continue - } - - succeeded[key] = value - } - } - - // Populate client-side cache for all successfully set values - // This ensures other clients can see the values via CSC - // Critical for empty strings and all cached values - if len(succeeded) > 0 { - cacheCommands := make([]rueidis.CacheableTTL, 0, len(succeeded)) - for key := range succeeded { - cacheCommands = append(cacheCommands, rueidis.CacheableTTL{ - Cmd: pca.client.B().Get().Key(key).Cache(), - TTL: ttl, - }) - } - _ = pca.client.DoMultiCache(ctx, cacheCommands...) - } - - return succeeded, failed -} - -// slotSetStatements holds Lua execution statements grouped by slot. -type slotSetStatements struct { - keyOrder []string - execStmts []rueidis.LuaExec -} - -// groupSetValuesBySlot groups set operations by Redis cluster slot for Lua script execution. -// This is necessary for CAS (compare-and-swap) Lua scripts which must execute on a single node. -// Unlike regular SET commands which rueidis.DoMulti routes automatically, Lua scripts -// (LuaExec) require manual slot grouping to ensure atomic operations on the correct node. -func (pca *PrimeableCacheAside) groupSetValuesBySlot( - values map[string]string, - lockValues map[string]string, - ttl time.Duration, -) map[uint16]slotSetStatements { - if len(values) == 0 { - return nil - } - - // Pre-allocate with estimated capacity - estimatedSlots, estimatedPerSlot := estimateSlotDistribution(len(values)) - stmtsBySlot := make(map[uint16]slotSetStatements, estimatedSlots) - - // Pre-calculate TTL string once - ttlStr := strconv.FormatInt(ttl.Milliseconds(), 10) - - for key, value := range values { - lockVal, hasLock := lockValues[key] - if !hasLock { - // Skip keys without locks (shouldn't happen, but be defensive) - pca.logger.Error("no lock value for key in groupSetValuesBySlot", "key", key) - continue - } - - slot := cmdx.Slot(key) - stmts := stmtsBySlot[slot] - - // Pre-allocate slices on first access to this slot - if stmts.keyOrder == nil { - stmts.keyOrder = make([]string, 0, estimatedPerSlot) - stmts.execStmts = make([]rueidis.LuaExec, 0, estimatedPerSlot) - } - - stmts.keyOrder = append(stmts.keyOrder, key) - stmts.execStmts = append(stmts.execStmts, rueidis.LuaExec{ - Keys: []string{key}, - Args: []string{value, ttlStr, lockVal}, - }) - stmtsBySlot[slot] = stmts - } - - return stmtsBySlot -} - // NewPrimeableCacheAside creates a new PrimeableCacheAside instance with the specified // Redis client options and cache-aside configuration. 
// @@ -930,10 +186,20 @@ func NewPrimeableCacheAside(clientOption rueidis.ClientOption, caOption CacheAsi return nil, err } - return &PrimeableCacheAside{ - CacheAside: ca, - lockChecker: &lockutil.PrefixLockChecker{Prefix: caOption.LockPrefix}, - }, nil + pca := &PrimeableCacheAside{ + CacheAside: ca, + } + + // Create write lock manager for Set operations + // IMPORTANT: Pass LockManager to ensure cohesive lock value generation and prefix checking + pca.writeLockManager = writelock.NewCASWriteLockManager(writelock.Config{ + Client: ca.client, + LockTTL: ca.lockTTL, + LockManager: ca.lockManager, // Ensures consistent lock generation and checking + Logger: ca.logger, + }) + + return pca, nil } // Set performs a coordinated cache update operation with distributed locking. @@ -1099,17 +365,18 @@ func (pca *PrimeableCacheAside) SetMulti( // Acquire cache locks for all keys // Cache locks provide both write-write coordination (prevents concurrent Sets) // and CAS protection (verifies we still hold locks during setMultiValuesWithCAS) - lockValues, acquireErr := pca.acquireMultiCacheLocks(ctx, keys) + // Use sequential acquisition strategy from WriteLockManager for deadlock prevention + lockValues, savedValues, acquireErr := pca.writeLockManager.AcquireMultiWriteLocksSequential(ctx, keys) if acquireErr != nil { return nil, acquireErr } defer func() { // Release cache locks only for keys that weren't successfully set. // Successfully set keys no longer contain locks (they contain the real values), - // so unlockKey would fail the comparison check and do nothing, but it's wasteful + // so ReleaseWriteLock would fail the comparison check and do nothing, but it's wasteful // to attempt the unlock at all. for key, lockVal := range lockValues { - pca.unlockKey(ctx, key, lockVal) + _ = pca.writeLockManager.ReleaseWriteLock(ctx, key, lockVal) } }() @@ -1120,7 +387,15 @@ func (pca *PrimeableCacheAside) SetMulti( } // Set all values using CAS to verify we still hold the locks - succeeded, failed := pca.setMultiValuesWithCAS(ctx, ttl, vals, lockValues) + // CommitWriteLocks handles CSC population internally + var succeeded map[string]string + var failed map[string]error + if len(vals) == 0 { + succeeded = make(map[string]string) + failed = make(map[string]error) + } else { + succeeded, failed = pca.writeLockManager.CommitWriteLocks(ctx, ttl, lockValues, vals) + } // Remove successfully set keys from lockValues so defer won't try to unlock them. // This is an optimization: unlockKey would be safe (lock check would fail) but wasteful. @@ -1129,8 +404,8 @@ func (pca *PrimeableCacheAside) SetMulti( } if len(failed) > 0 { - // Partial failure - some keys lost their locks - pca.logger.Debug("SetMulti partial failure", "succeeded", len(succeeded), "failed", len(failed)) + // Partial failure - restore previous values for failed keys + pca.restoreFailedKeys(ctx, failed, savedValues, lockValues) return succeeded, NewBatchError(failed, mapsx.Keys(succeeded)) } @@ -1138,29 +413,59 @@ func (pca *PrimeableCacheAside) SetMulti( return vals, nil } +// restoreFailedKeys restores previous values for keys that failed CAS during SetMulti. +// This prevents leaving lock placeholders in Redis when the final commit fails. 
+func (pca *PrimeableCacheAside) restoreFailedKeys( + ctx context.Context, + failed map[string]error, + savedValues map[string]string, + lockValues map[string]string, +) { + pca.logger.Debug("SetMulti partial failure, restoring previous values for failed keys", + "failed", len(failed)) + + failedKeys := mapsx.Keys(failed) + restoreMap := make(map[string]string) + failedLockVals := make(map[string]string) + for _, key := range failedKeys { + if savedVal, ok := savedValues[key]; ok { + restoreMap[key] = savedVal + if lockVal, hasLock := lockValues[key]; hasLock { + failedLockVals[key] = lockVal + } + } + } + if len(restoreMap) > 0 { + pca.writeLockManager.RestoreValues(ctx, restoreMap, failedLockVals) + } +} + // waitForReadLocks checks for read locks on the given keys and waits for them to complete. -// Follows the same pattern as CacheAside's registerAll and WaitForAll usage. +// Uses lock manager's WaitHandle abstraction for consistent invalidation handling. func (pca *PrimeableCacheAside) waitForReadLocks(ctx context.Context, keys []string) { - // Use shared utility to batch check locks - // BatchCheckLocks now uses DoMultiCache, so we're already subscribed to invalidations - lockedKeys := lockutil.BatchCheckLocks(ctx, pca.client, keys, pca.lockChecker) + // Register wait handles FIRST for all keys (before checking locks) + // This prevents race where locks are released between check and registration + waitHandleMap := make(map[string]lockmanager.WaitHandle, len(keys)) + for _, key := range keys { + waitHandleMap[key] = pca.lockManager.WaitForKey(key) + } + + // Now check which keys have locks (subscribes to Redis via DoMultiCache) + lockedKeys := pca.lockManager.CheckMultiKeysLocked(ctx, keys) if len(lockedKeys) == 0 { return // No read locks to wait for } pca.logger.Debug("waiting for read locks to complete", "count", len(lockedKeys)) - // No need to subscribe again - BatchCheckLocks already used DoMultiCache - - // Register all keys and get their wait channels - waitChannels := make(map[string]<-chan struct{}, len(lockedKeys)) + // Wait for all locked keys (returns immediately if already released) + waitHandles := make([]lockmanager.WaitHandle, 0, len(lockedKeys)) for _, key := range lockedKeys { - waitChannels[key] = pca.register(key) + waitHandles = append(waitHandles, waitHandleMap[key]) } - // Use syncx.WaitForAll like CacheAside does - channels := mapsx.Values(waitChannels) - if err := syncx.WaitForAll(ctx, channels); err != nil { + // Wait for all handles concurrently using lock manager's helper + if err := lockmanager.WaitForAll(ctx, waitHandles); err != nil { pca.logger.Debug("context cancelled while waiting for read locks", "error", err) return } @@ -1194,16 +499,8 @@ func (pca *PrimeableCacheAside) setMultiValues(ctx context.Context, ttl time.Dur } } - // Populate CSC for all set values - cacheCommands := make([]rueidis.CacheableTTL, 0, len(values)) - for key := range values { - cacheCommands = append(cacheCommands, rueidis.CacheableTTL{ - Cmd: pca.client.B().Get().Key(key).Cache(), - TTL: ttl, - }) - } - _ = pca.client.DoMultiCache(ctx, cacheCommands...) - + // Note: No DoMultiCache needed. Redis automatically sends invalidation + // messages to all clients tracking these keys when SET executes. return nil } @@ -1288,18 +585,13 @@ func (pca *PrimeableCacheAside) ForceSetMulti(ctx context.Context, ttl time.Dura return pca.setMultiValues(ctx, ttl, values) } -// Close closes both the underlying CacheAside client and the rueidislock locker. 
-// This should be called when the PrimeableCacheAside is no longer needed. -// Close cleans up all resources used by PrimeableCacheAside. -// It cleans up parent CacheAside resources and closes the Redis client. +// Close cleans up resources used by the PrimeableCacheAside instance. +// It cancels all pending lock wait operations and cleans up internal state. +// Note: This does NOT close the underlying Redis client, as that's owned by the caller. +// The caller is responsible for closing the Redis client when done. func (pca *PrimeableCacheAside) Close() { // Clean up parent CacheAside resources (lock entries, etc) if pca.CacheAside != nil { pca.CacheAside.Close() } - - // Close Redis client - if pca.CacheAside != nil && pca.Client() != nil { - pca.Client().Close() - } } diff --git a/primeable_cacheaside_cluster_test.go b/primeable_cacheaside_cluster_test.go index 5ae1d17..dd0b12d 100644 --- a/primeable_cacheaside_cluster_test.go +++ b/primeable_cacheaside_cluster_test.go @@ -1,3 +1,5 @@ +//go:build cluster + package redcache_test import ( @@ -42,7 +44,7 @@ func makeClusterPrimeableCacheAside(t *testing.T) *redcache.PrimeableCacheAside } // Test cluster connectivity - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() innerClient := cacheAside.Client() if pingErr := innerClient.Do(ctx, innerClient.B().Ping().Build()).Error(); pingErr != nil { @@ -63,7 +65,7 @@ func TestPrimeableCacheAside_Cluster_BasicSetOperations(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:set:" + uuid.New().String() expectedValue := "value:" + uuid.New().String() @@ -96,7 +98,7 @@ func TestPrimeableCacheAside_Cluster_BasicSetOperations(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:force:" + uuid.New().String() // Manually set a lock @@ -126,7 +128,7 @@ func TestPrimeableCacheAside_Cluster_SetMultiOperations(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() // Use hash tags to ensure same slot keys := []string{ @@ -174,7 +176,7 @@ func TestPrimeableCacheAside_Cluster_SetMultiOperations(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() // Create keys in different slots keys := []string{ @@ -223,7 +225,7 @@ func TestPrimeableCacheAside_Cluster_SetMultiOperations(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() keys := []string{ "{shard:400}:key1", @@ -265,7 +267,7 @@ func TestPrimeableCacheAside_Cluster_LargeKeySet(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() // Create 50 keys across multiple slots numKeys := 50 @@ -328,7 +330,7 @@ func TestPrimeableCacheAside_Cluster_ConcurrentSetOperations(t *testing.T) { client2 := makeClusterPrimeableCacheAside(t) defer client2.Close() - ctx := context.Background() + ctx := t.Context() // Keys in different slots key1 := "{shard:700}:concurrent1" @@ -378,7 +380,7 @@ func TestPrimeableCacheAside_Cluster_ConcurrentSetOperations(t *testing.T) { client2 := makeClusterPrimeableCacheAside(t) defer client2.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:same:" + uuid.New().String() var callbackCount atomic.Int32 @@ -431,7 +433,7 @@ func TestPrimeableCacheAside_Cluster_SetAndGetIntegration(t *testing.T) { client2 := makeClusterPrimeableCacheAside(t) defer client2.Close() - ctx := 
context.Background() + ctx := t.Context() key := "pcluster:setget:" + uuid.New().String() setValue := "set-value:" + uuid.New().String() @@ -462,7 +464,7 @@ func TestPrimeableCacheAside_Cluster_SetAndGetIntegration(t *testing.T) { client2 := makeClusterPrimeableCacheAside(t) defer client2.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:getwait:" + uuid.New().String() // Client 1 starts slow Get @@ -509,7 +511,7 @@ func TestPrimeableCacheAside_Cluster_Invalidation(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:del:" + uuid.New().String() // Set a value @@ -535,7 +537,7 @@ func TestPrimeableCacheAside_Cluster_Invalidation(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() keys := []string{ "{shard:900}:del1", @@ -574,7 +576,7 @@ func TestPrimeableCacheAside_Cluster_Invalidation(t *testing.T) { client2 := makeClusterPrimeableCacheAside(t) defer client2.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:del-set:" + uuid.New().String() // Client 1 starts Set @@ -612,7 +614,7 @@ func TestPrimeableCacheAside_Cluster_ErrorHandling(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:error:" + uuid.New().String() callCount := 0 @@ -650,7 +652,7 @@ func TestPrimeableCacheAside_Cluster_ErrorHandling(t *testing.T) { err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Ex(time.Second*5).Build()).Error() require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) defer cancel() // Set should fail with timeout @@ -674,7 +676,7 @@ func TestPrimeableCacheAside_Cluster_SpecialValues(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:empty:" + uuid.New().String() err := client.Set(ctx, time.Second*10, key, func(_ context.Context, _ string) (string, error) { @@ -696,7 +698,7 @@ func TestPrimeableCacheAside_Cluster_SpecialValues(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() testCases := []struct { name string @@ -731,7 +733,7 @@ func TestPrimeableCacheAside_Cluster_SpecialValues(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:large:" + uuid.New().String() largeValue := strings.Repeat("x", 1024*1024) // 1MB @@ -756,7 +758,7 @@ func TestPrimeableCacheAside_Cluster_StressTest(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() // Create many keys across slots numKeys := 50 @@ -835,7 +837,7 @@ func TestPrimeableCacheAside_Cluster_TTLConsistency(t *testing.T) { } defer client.Close() - ctx := context.Background() + ctx := t.Context() key := "pcluster:ttl:" + uuid.New().String() // Set with short TTL diff --git a/primeable_cacheaside_distributed_test.go b/primeable_cacheaside_distributed_test.go index 4d8e592..0eb5ec7 100644 --- a/primeable_cacheaside_distributed_test.go +++ b/primeable_cacheaside_distributed_test.go @@ -1,3 +1,5 @@ +//go:build distributed + package redcache_test import ( @@ -21,7 +23,7 @@ import ( // coordinate correctly across multiple clients, especially with read operations func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { t.Run("write waits for read lock from different client", func(t *testing.T) { - ctx := context.Background() + 
ctx := t.Context() key := "pdist:write-read:" + uuid.New().String() // Create separate clients @@ -67,7 +69,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { writeDuration := time.Since(writeStart) require.NoError(t, err) - assert.Greater(t, writeDuration, 900*time.Millisecond, "Write should wait for read lock") + assert.Greater(t, writeDuration, 850*time.Millisecond, "Write should wait for read lock") assert.Less(t, writeDuration, 1500*time.Millisecond, "Write should not wait too long") <-readDone @@ -82,7 +84,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("multiple writes coordinate through rueidislock", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:multi-write:" + uuid.New().String() numClients := 5 @@ -132,7 +134,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("SetMulti waits for GetMulti from different client", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "pdist:batch:1:" + uuid.New().String() key2 := "pdist:batch:2:" + uuid.New().String() @@ -185,14 +187,14 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { require.NoError(t, err) assert.Len(t, result, 2) - assert.Greater(t, writeDuration, 900*time.Millisecond, "SetMulti should wait") + assert.Greater(t, writeDuration, 850*time.Millisecond, "SetMulti should wait") assert.Less(t, writeDuration, 1500*time.Millisecond, "SetMulti should not wait too long") <-readDone }) t.Run("ForceSet bypasses all locks including from other clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:force:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -232,7 +234,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("read lock expiration allows write to proceed", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:expire:" + uuid.New().String() // Use very short lock TTL @@ -277,7 +279,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("invalidation from write triggers waiting reads across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:invalidate:" + uuid.New().String() // Set up initial lock to block everyone @@ -345,7 +347,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("stress test - many concurrent read and write operations", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() numClients := 10 numOperations := 20 @@ -434,7 +436,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("context cancellation during distributed Set", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:ctx-cancel:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -476,7 +478,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("callback error in distributed Set does not cache", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:callback-error:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -519,7 +521,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("SetMulti callback error across clients does not cache", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1, err := 
redcache.NewPrimeableCacheAside( rueidis.ClientOption{InitAddress: addr}, @@ -559,7 +561,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("Del during distributed Set coordination", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:del-during-set:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -604,7 +606,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("DelMulti coordination across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1, err := redcache.NewPrimeableCacheAside( rueidis.ClientOption{InitAddress: addr}, @@ -663,7 +665,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("empty and special values across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1, err := redcache.NewPrimeableCacheAside( rueidis.ClientOption{InitAddress: addr}, @@ -712,7 +714,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("large value handling across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1, err := redcache.NewPrimeableCacheAside( rueidis.ClientOption{InitAddress: addr}, @@ -748,7 +750,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("TTL consistency across clients", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1, err := redcache.NewPrimeableCacheAside( rueidis.ClientOption{InitAddress: addr}, @@ -802,7 +804,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("concurrent SetMulti with partial overlap - coordinated completion", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) defer cancel() client1, err := redcache.NewPrimeableCacheAside( @@ -882,7 +884,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("distributed: Set from client A + ForceSet from client B steals lock", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "pdist:set-forceSet:" + uuid.New().String() clientA, err := redcache.NewPrimeableCacheAside( @@ -950,7 +952,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { }) t.Run("distributed: SetMulti from client A + ForceSetMulti from client B steals some locks", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "pdist:setmulti:1:" + uuid.New().String() key2 := "pdist:setmulti:2:" + uuid.New().String() key3 := "pdist:setmulti:3:" + uuid.New().String() @@ -1061,7 +1063,7 @@ func TestPrimeableCacheAside_DistributedCoordination(t *testing.T) { // (< 50ms) to prove operations complete via invalidation, not polling. 
func TestPrimeableCacheAside_DistributedInvalidationTiming(t *testing.T) { t.Run("Distributed Set waits for Set via invalidation, not ticker", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "dist-inv-set-set:" + uuid.New().String() // Create two separate clients (simulating distributed processes) @@ -1125,7 +1127,7 @@ func TestPrimeableCacheAside_DistributedInvalidationTiming(t *testing.T) { }) t.Run("Distributed SetMulti waits for SetMulti via invalidation", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "dist-inv-setm-setm-1:" + uuid.New().String() key2 := "dist-inv-setm-setm-2:" + uuid.New().String() keys := []string{key1, key2} @@ -1202,7 +1204,7 @@ func TestPrimeableCacheAside_DistributedInvalidationTiming(t *testing.T) { }) t.Run("Distributed Set waits for Get via invalidation", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "dist-inv-set-get:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -1264,7 +1266,7 @@ func TestPrimeableCacheAside_DistributedInvalidationTiming(t *testing.T) { }) t.Run("Distributed SetMulti waits for GetMulti via invalidation", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "dist-inv-setm-getm-1:" + uuid.New().String() key2 := "dist-inv-setm-getm-2:" + uuid.New().String() keys := []string{key1, key2} diff --git a/primeable_cacheaside_test.go b/primeable_cacheaside_test.go index 16c0a57..611a373 100644 --- a/primeable_cacheaside_test.go +++ b/primeable_cacheaside_test.go @@ -1,3 +1,5 @@ +//go:build integration + package redcache_test import ( @@ -56,7 +58,7 @@ func forceSetMulti(client *redcache.PrimeableCacheAside, ctx context.Context, tt func TestPrimeableCacheAside_Set(t *testing.T) { t.Run("successful set acquires lock and sets value", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -81,7 +83,7 @@ func TestPrimeableCacheAside_Set(t *testing.T) { }) t.Run("waits and retries when lock cannot be acquired", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) defer cancel() client := makeClientWithSet(t, addr) defer client.Close() @@ -107,7 +109,7 @@ func TestPrimeableCacheAside_Set(t *testing.T) { }) t.Run("subsequent Get retrieves Set value", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -130,7 +132,7 @@ func TestPrimeableCacheAside_Set(t *testing.T) { func TestPrimeableCacheAside_ForceSet(t *testing.T) { t.Run("successful force set bypasses locks", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -148,7 +150,7 @@ func TestPrimeableCacheAside_ForceSet(t *testing.T) { }) t.Run("force set overrides existing lock", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -172,7 +174,7 @@ func TestPrimeableCacheAside_ForceSet(t *testing.T) { }) t.Run("force set overrides existing value", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -200,7 +202,7 @@ func TestPrimeableCacheAside_ForceSet(t *testing.T) { func TestPrimeableCacheAside_SetMulti(t *testing.T) { t.Run("successful set multi acquires 
locks and sets values", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -224,7 +226,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { }) t.Run("empty values returns empty result", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -234,7 +236,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { }) t.Run("waits for all locks to be released then sets all values", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -272,7 +274,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { }) t.Run("subsequent GetMulti retrieves SetMulti values", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -305,7 +307,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { }) t.Run("successful SetMulti doesn't delete values in cleanup", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -340,7 +342,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { }) t.Run("callback exceeding lock TTL results in CAS failure", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Create client with very short lock TTL (200ms) option := rueidis.ClientOption{InitAddress: addr} @@ -394,7 +396,7 @@ func TestPrimeableCacheAside_SetMulti(t *testing.T) { func TestPrimeableCacheAside_ForceSetMulti(t *testing.T) { t.Run("successful force set multi bypasses locks", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -417,7 +419,7 @@ func TestPrimeableCacheAside_ForceSetMulti(t *testing.T) { }) t.Run("empty values completes successfully", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -426,7 +428,7 @@ func TestPrimeableCacheAside_ForceSetMulti(t *testing.T) { }) t.Run("force set multi overrides existing locks", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -462,7 +464,7 @@ func TestPrimeableCacheAside_ForceSetMulti(t *testing.T) { func TestPrimeableCacheAside_Integration(t *testing.T) { t.Run("Set waits for concurrent Get to complete", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -516,7 +518,7 @@ func TestPrimeableCacheAside_Integration(t *testing.T) { }) t.Run("ForceSet overrides lock from Get", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -540,7 +542,7 @@ func TestPrimeableCacheAside_Integration(t *testing.T) { }) t.Run("concurrent Set operations wait and eventually succeed", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -608,7 +610,7 @@ func TestPrimeableCacheAside_EdgeCases_ContextCancellation(t *testing.T) { defer client.Close() key := "key:" + uuid.New().String() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancel() // Cancel immediately err := client.Set(ctx, time.Second, key, func(_ context.Context, _ string) 
(string, error) { @@ -631,7 +633,7 @@ func TestPrimeableCacheAside_EdgeCases_ContextCancellation(t *testing.T) { require.NoError(t, err) // Try to Set with short timeout - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) defer cancel() err = client.Set(ctx, time.Second, key, func(_ context.Context, _ string) (string, error) { @@ -653,7 +655,7 @@ func TestPrimeableCacheAside_EdgeCases_ContextCancellation(t *testing.T) { "key:2:" + uuid.New().String(): "value2", } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) cancel() // Cancel immediately _, err := setMultiValue(client, ctx, time.Second, values) @@ -682,7 +684,7 @@ func TestPrimeableCacheAside_EdgeCases_ContextCancellation(t *testing.T) { key2: "value2", } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) defer cancel() _, err = setMultiValue(client, ctx, time.Second, values) @@ -697,7 +699,7 @@ func TestPrimeableCacheAside_EdgeCases_ContextCancellation(t *testing.T) { func TestPrimeableCacheAside_EdgeCases_TTL(t *testing.T) { t.Run("Set with very short TTL", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -726,7 +728,7 @@ func TestPrimeableCacheAside_EdgeCases_TTL(t *testing.T) { }) t.Run("Set with 1 second TTL has correct expiration", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -747,7 +749,7 @@ func TestPrimeableCacheAside_EdgeCases_TTL(t *testing.T) { }) t.Run("SetMulti with very short TTL", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -776,7 +778,7 @@ func TestPrimeableCacheAside_EdgeCases_TTL(t *testing.T) { func TestPrimeableCacheAside_EdgeCases_DuplicateKeys(t *testing.T) { t.Run("SetMulti with duplicate keys in input", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -801,7 +803,7 @@ func TestPrimeableCacheAside_EdgeCases_DuplicateKeys(t *testing.T) { func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { t.Run("Set with empty string value", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -820,7 +822,7 @@ func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { }) t.Run("Set with value that starts with lock prefix", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -846,7 +848,7 @@ func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { }) t.Run("Set with unicode and special characters", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -866,7 +868,7 @@ func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { }) t.Run("SetMulti with empty string values", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -889,7 +891,7 @@ func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { }) t.Run("Set with very large value", func(t *testing.T) { - ctx := 
context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -912,7 +914,7 @@ func TestPrimeableCacheAside_EdgeCases_SpecialValues(t *testing.T) { func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { t.Run("Get racing with Set - Set completes first", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -939,7 +941,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("Get starts then Set completes - Get should see new value on retry", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -989,7 +991,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("GetMulti racing with SetMulti on overlapping keys", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1041,7 +1043,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("ForceSet triggers invalidation for waiting Get", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1082,7 +1084,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("ForceSet overrides lock while Get is holding it", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1161,7 +1163,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("ForceSetMulti overrides locks while GetMulti is holding them", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1258,7 +1260,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("Set in progress + ForceSet overwrites lock during callback", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1324,7 +1326,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { }) t.Run("SetMulti in progress + ForceSetMulti overwrites some locks during callback", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1436,7 +1438,7 @@ func TestPrimeableCacheAside_EdgeCases_GetRacingWithSet(t *testing.T) { func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { t.Run("Set overwrites existing non-lock value", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1469,7 +1471,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("Set immediately after Del triggers new write", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1504,7 +1506,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { // for comprehensive distributed coordination tests t.Run("Get with callback error does not cache", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1528,7 +1530,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("GetMulti with empty keys 
slice", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1542,7 +1544,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("GetMulti with callback error does not cache", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1569,7 +1571,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("Del on non-existent key succeeds", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1581,7 +1583,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("DelMulti on non-existent keys succeeds", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1596,7 +1598,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { }) t.Run("DelMulti with empty keys slice", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1611,7 +1613,7 @@ func TestPrimeableCacheAside_EdgeCases_Additional(t *testing.T) { // This would happen if a previous Get operation completed, left a lock in Redis, // and then Set is called. func TestPrimeableCacheAside_SetDoesNotBlockOnRedisLock(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1642,7 +1644,7 @@ func TestPrimeableCacheAside_SetDoesNotBlockOnRedisLock(t *testing.T) { require.NoError(t, err) // Should have waited approximately 1 second for lock to expire - assert.Greater(t, elapsed, time.Millisecond*900, "Should have waited for lock TTL") + assert.Greater(t, elapsed, time.Millisecond*850, "Should have waited for lock TTL") assert.Less(t, elapsed, time.Second*2, "Should not have blocked indefinitely") // Verify value was set @@ -1657,7 +1659,7 @@ func TestPrimeableCacheAside_SetDoesNotBlockOnRedisLock(t *testing.T) { /* func TestPrimeableCacheAside_SetWithCallback(t *testing.T) { t.Run("acquires lock, executes callback, caches result", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1685,7 +1687,7 @@ func TestPrimeableCacheAside_SetWithCallback(t *testing.T) { }) t.Run("concurrent Set operations coordinate via locking", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client1 := makeClientWithSet(t, addr) defer client1.Client().Close() client2 := makeClientWithSet(t, addr) @@ -1732,7 +1734,7 @@ func TestPrimeableCacheAside_SetWithCallback(t *testing.T) { }) t.Run("callback error prevents caching", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1759,7 +1761,7 @@ func TestPrimeableCacheAside_SetWithCallback(t *testing.T) { // TestPrimeableCacheAside_SetMultiWithCallback tests batch coordinated cache updates with a callback. 
func TestPrimeableCacheAside_SetMultiWithCallback(t *testing.T) { t.Run("acquires locks, executes callback, caches results", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1803,7 +1805,7 @@ func TestPrimeableCacheAside_SetMultiWithCallback(t *testing.T) { }) t.Run("empty keys returns empty result", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1818,7 +1820,7 @@ func TestPrimeableCacheAside_SetMultiWithCallback(t *testing.T) { }) t.Run("callback error prevents caching", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() client := makeClientWithSet(t, addr) defer client.Close() @@ -1850,7 +1852,7 @@ func TestPrimeableCacheAside_SetMultiWithCallback(t *testing.T) { // cache key locks and doesn't act like ForceSet func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { t.Run("Set respects new Get operations that start after read lock cleared", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "set-locking:" + uuid.New().String() client, err := redcache.NewPrimeableCacheAside( @@ -1908,7 +1910,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { }) t.Run("Set cannot overwrite active read lock from concurrent Get", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "set-no-overwrite:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -1956,7 +1958,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { require.NoError(t, err) // Set should have waited for Get - assert.Greater(t, setDuration, 900*time.Millisecond, "Set should wait for Get lock") + assert.Greater(t, setDuration, 850*time.Millisecond, "Set should wait for Get lock") <-getDone assert.NoError(t, getErr) @@ -1972,7 +1974,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { }) t.Run("ForceSet bypasses locks while Set respects them", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "force-vs-set:" + uuid.New().String() client, err := redcache.NewPrimeableCacheAside( @@ -2030,7 +2032,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { }) t.Run("Set properly uses compare-and-swap to ensure lock is held", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "set-cas:" + uuid.New().String() client, err := redcache.NewPrimeableCacheAside( @@ -2062,7 +2064,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { // This test file should focus on Set/SetMulti behavior and their interactions with Get/GetMulti t.Run("SetMulti with callback exceeding lock TTL", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "setmulti-exceed-1:" + uuid.New().String() key2 := "setmulti-exceed-2:" + uuid.New().String() @@ -2114,7 +2116,7 @@ func TestPrimeableCacheAside_SetLockingBehavior(t *testing.T) { // so tests verify operations complete in < 40ms to prove invalidation is working. 
func TestPrimeableCacheAside_SetInvalidationMechanism(t *testing.T) { t.Run("Set receives cache lock invalidation, not just ticker", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "set-invalidation:" + uuid.New().String() client, err := redcache.NewPrimeableCacheAside( @@ -2169,7 +2171,7 @@ func TestPrimeableCacheAside_SetInvalidationMechanism(t *testing.T) { }) t.Run("Set waits for another Set's cache lock via invalidation", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key := "set-wait-set:" + uuid.New().String() client1, err := redcache.NewPrimeableCacheAside( @@ -2232,7 +2234,7 @@ func TestPrimeableCacheAside_SetInvalidationMechanism(t *testing.T) { }) t.Run("SetMulti receives invalidations during sequential acquisition", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() key1 := "setmulti-inv-1:" + uuid.New().String() key2 := "setmulti-inv-2:" + uuid.New().String() key3 := "setmulti-inv-3:" + uuid.New().String() @@ -2323,7 +2325,7 @@ func TestPrimeableCacheAside_SetInvalidationMechanism(t *testing.T) { // achievable with current test infrastructure without risky architectural changes. func TestPrimeableCacheAside_SetMultiPartialFailure(t *testing.T) { t.Run("SetMulti with middle key locked - sequential acquisition", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Use a shared UUID prefix to ensure sort order prefix := uuid.New().String() @@ -2394,7 +2396,7 @@ func TestPrimeableCacheAside_SetMultiPartialFailure(t *testing.T) { }) t.Run("SetMulti with out-of-order success triggers restore", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Use a shared UUID prefix to ensure sort order prefix := uuid.New().String() @@ -2462,7 +2464,7 @@ func TestPrimeableCacheAside_SetMultiPartialFailure(t *testing.T) { }) t.Run("SetMulti TTL refresh during multi-retry waiting", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Use a shared UUID prefix to ensure sort order prefix := uuid.New().String() @@ -2537,7 +2539,7 @@ func TestPrimeableCacheAside_SetMultiPartialFailure(t *testing.T) { }) t.Run("SetMulti keysNotIn correctly filters remaining keys", func(t *testing.T) { - ctx := context.Background() + ctx := t.Context() // Use a shared UUID prefix to ensure sort order prefix := uuid.New().String() diff --git a/primeable_cacheaside_unit_test.go b/primeable_cacheaside_unit_test.go new file mode 100644 index 0000000..3965340 --- /dev/null +++ b/primeable_cacheaside_unit_test.go @@ -0,0 +1,609 @@ +package redcache + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/redis/rueidis" + "github.com/redis/rueidis/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/dcbickfo/redcache/internal/lockmanager" + mocklockmanager "github.com/dcbickfo/redcache/mocks/lockmanager" + mocklogger "github.com/dcbickfo/redcache/mocks/logger" + mockwritelock "github.com/dcbickfo/redcache/mocks/writelock" +) + +// TestPrimeableCacheAside_Set tests all Set() scenarios using subtests +func TestPrimeableCacheAside_Set(t *testing.T) { + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, keysAndValues ...interface{}) {}, + ErrorFunc: func(msg string, keysAndValues ...interface{}) {}, + } + + t.Run("SuccessfulSet_NoReadLock", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := "lock-123" 
+ computedValue := "computed-value" + + mockClient := mock.NewClient(ctrl) + + // WaitForKeyWithSubscription handles Redis internally, no DoCache expectation needed + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyWithSubscriptionFunc: func(ctx context.Context, key string, timeout time.Duration) (lockmanager.WaitHandle, string, error) { + // No lock value found + return nil, "", rueidis.Nil + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireWriteLockFunc: func(ctx context.Context, key string) (string, error) { + return lockValue, nil + }, + CommitWriteLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (map[string]string, map[string]error) { + // Return success for all keys + return actualValues, nil + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + err := pca.Set(ctx, time.Minute, key, func(ctx context.Context, key string) (string, error) { + return computedValue, nil + }) + + require.NoError(t, err) + + // Verify write lock was acquired + assert.Len(t, mockWriteLockManager.AcquireWriteLockCalls(), 1) + assert.Equal(t, key, mockWriteLockManager.AcquireWriteLockCalls()[0].Key) + + // Verify commit was called + assert.Len(t, mockWriteLockManager.CommitWriteLocksCalls(), 1) + commitCall := mockWriteLockManager.CommitWriteLocksCalls()[0] + assert.Equal(t, time.Minute, commitCall.TTL) + assert.Equal(t, lockValue, commitCall.LockValues[key]) + assert.Equal(t, computedValue, commitCall.ActualValues[key]) + }) + + t.Run("WaitForReadLock_ThenSuccessfulSet", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := "lock-123" + readLockValue := "\x01read-lock-value" + computedValue := "computed-value" + + mockClient := mock.NewClient(ctrl) + + mockWaitHandle := &mocklockmanager.MockWaitHandle{ + WaitFunc: func(ctx context.Context) error { + // Simulate successful wait for read lock to complete + return nil + }, + } + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyWithSubscriptionFunc: func(ctx context.Context, key string, timeout time.Duration) (lockmanager.WaitHandle, string, error) { + // Read lock exists + return mockWaitHandle, readLockValue, nil + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireWriteLockFunc: func(ctx context.Context, key string) (string, error) { + return lockValue, nil + }, + CommitWriteLocksFunc: func(ctx context.Context, ttl time.Duration, lockValues map[string]string, actualValues map[string]string) (map[string]string, map[string]error) { + return actualValues, nil + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + err := pca.Set(ctx, time.Minute, key, func(ctx context.Context, key string) (string, error) { + return computedValue, nil + }) 
+ + require.NoError(t, err) + + // Verify wait was called + assert.Len(t, mockWaitHandle.WaitCalls(), 1) + + // Verify write lock was acquired after wait + assert.Len(t, mockWriteLockManager.AcquireWriteLockCalls(), 1) + }) + + t.Run("CallbackError_ReleasesLock", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockValue := "lock-123" + callbackErr := errors.New("callback failed") + + mockClient := mock.NewClient(ctrl) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyWithSubscriptionFunc: func(ctx context.Context, key string, timeout time.Duration) (lockmanager.WaitHandle, string, error) { + return nil, "", rueidis.Nil + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireWriteLockFunc: func(ctx context.Context, key string) (string, error) { + return lockValue, nil + }, + ReleaseWriteLockFunc: func(ctx context.Context, key string, lockVal string) error { + return nil + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + err := pca.Set(ctx, time.Minute, key, func(ctx context.Context, key string) (string, error) { + return "", callbackErr + }) + + require.Error(t, err) + assert.Equal(t, callbackErr, err) + + // Verify lock was released (not committed) on error + assert.Len(t, mockWriteLockManager.ReleaseWriteLockCalls(), 1) + releaseCall := mockWriteLockManager.ReleaseWriteLockCalls()[0] + assert.Equal(t, key, releaseCall.Key) + assert.Equal(t, lockValue, releaseCall.LockValue) + + // Verify commit was NOT called + assert.Len(t, mockWriteLockManager.CommitWriteLocksCalls(), 0) + }) + + t.Run("AcquireLockFailure", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + lockErr := errors.New("lock acquisition failed") + + mockClient := mock.NewClient(ctrl) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyWithSubscriptionFunc: func(ctx context.Context, key string, timeout time.Duration) (lockmanager.WaitHandle, string, error) { + return nil, "", rueidis.Nil + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireWriteLockFunc: func(ctx context.Context, key string) (string, error) { + return "", lockErr + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + err := pca.Set(ctx, time.Minute, key, func(ctx context.Context, key string) (string, error) { + t.Error("Callback should not be called when lock acquisition fails") + return "", nil + }) + + require.Error(t, err) + assert.Equal(t, lockErr, err) + }) +} + +// TestPrimeableCacheAside_SetMulti tests all SetMulti() scenarios using subtests +func TestPrimeableCacheAside_SetMulti(t *testing.T) { + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, keysAndValues ...interface{}) {}, + ErrorFunc: func(msg string, 
keysAndValues ...interface{}) {}, + } + + t.Run("AllKeysSuccess", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"key1", "key2", "key3"} + lockValues := map[string]string{ + "key1": "lock-1", + "key2": "lock-2", + "key3": "lock-3", + } + computedValues := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + mockClient := mock.NewClient(ctrl) + + // SetMulti uses waitForReadLocks which calls WaitForKey and CheckMultiKeysLocked + mockWaitHandle := &mocklockmanager.MockWaitHandle{ + WaitFunc: func(ctx context.Context) error { + return nil + }, + } + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyFunc: func(key string) lockmanager.WaitHandle { + return mockWaitHandle + }, + CheckMultiKeysLockedFunc: func(ctx context.Context, keys []string) []string { + // No keys are locked + return []string{} + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireMultiWriteLocksSequentialFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { + // All locks acquired successfully, no saved values + return lockValues, map[string]string{}, nil + }, + ReleaseWriteLockFunc: func(ctx context.Context, key string, lockVal string) error { + // Release called for cleanup + return nil + }, + CommitWriteLocksFunc: func(ctx context.Context, ttl time.Duration, lockVals map[string]string, actualVals map[string]string) (map[string]string, map[string]error) { + // All commits successful + return actualVals, nil + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + result, err := pca.SetMulti(ctx, time.Minute, keys, func(ctx context.Context, keys []string) (map[string]string, error) { + return computedValues, nil + }) + + require.NoError(t, err) + assert.Equal(t, computedValues, result) + + // Verify acquire multi sequential was called + assert.Len(t, mockWriteLockManager.AcquireMultiWriteLocksSequentialCalls(), 1) + + // Verify commit was called with all values + assert.Len(t, mockWriteLockManager.CommitWriteLocksCalls(), 1) + commitCall := mockWriteLockManager.CommitWriteLocksCalls()[0] + assert.Equal(t, time.Minute, commitCall.TTL) + assert.Equal(t, lockValues, commitCall.LockValues) + assert.Equal(t, computedValues, commitCall.ActualValues) + }) + + t.Run("AcquireLockError", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"key1", "key2", "key3"} + lockErr := errors.New("lock acquisition failed") + + mockClient := mock.NewClient(ctrl) + + mockWaitHandle := &mocklockmanager.MockWaitHandle{ + WaitFunc: func(ctx context.Context) error { + return nil + }, + } + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyFunc: func(key string) lockmanager.WaitHandle { + return mockWaitHandle + }, + CheckMultiKeysLockedFunc: func(ctx context.Context, keys []string) []string { + return []string{} + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ 
+ AcquireMultiWriteLocksSequentialFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { + // Lock acquisition failed + return nil, nil, lockErr + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + result, err := pca.SetMulti(ctx, time.Minute, keys, func(ctx context.Context, keys []string) (map[string]string, error) { + t.Error("Callback should not be called when lock acquisition fails") + return nil, nil + }) + + require.Error(t, err) + assert.Equal(t, lockErr, err) + assert.Nil(t, result) + }) + + t.Run("CallbackError_ReleasesLocks", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + keys := []string{"key1", "key2"} + lockValues := map[string]string{ + "key1": "lock-1", + "key2": "lock-2", + } + callbackErr := errors.New("callback failed") + + mockClient := mock.NewClient(ctrl) + + mockWaitHandle := &mocklockmanager.MockWaitHandle{ + WaitFunc: func(ctx context.Context) error { + return nil + }, + } + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + IsLockValueFunc: func(value string) bool { + return len(value) > 0 && value[0] == 0x01 + }, + WaitForKeyFunc: func(key string) lockmanager.WaitHandle { + return mockWaitHandle + }, + CheckMultiKeysLockedFunc: func(ctx context.Context, keys []string) []string { + return []string{} + }, + } + + mockWriteLockManager := &mockwritelock.MockWriteLockManager{ + AcquireMultiWriteLocksSequentialFunc: func(ctx context.Context, keys []string) (map[string]string, map[string]string, error) { + return lockValues, map[string]string{}, nil + }, + ReleaseWriteLockFunc: func(ctx context.Context, key string, lockVal string) error { + // Release called for cleanup + return nil + }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + writeLockManager: mockWriteLockManager, + } + + result, err := pca.SetMulti(ctx, time.Minute, keys, func(ctx context.Context, keys []string) (map[string]string, error) { + return nil, callbackErr + }) + + require.Error(t, err) + assert.Equal(t, callbackErr, err) + assert.Nil(t, result) + + // Verify locks were released (not committed) + // ReleaseWriteLock is called once for each key (2 keys) + assert.Len(t, mockWriteLockManager.ReleaseWriteLockCalls(), 2) + + // Verify commit was NOT called + assert.Len(t, mockWriteLockManager.CommitWriteLocksCalls(), 0) + }) +} + +// TestPrimeableCacheAside_ForceSet tests ForceSet() scenarios using subtests +func TestPrimeableCacheAside_ForceSet(t *testing.T) { + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, keysAndValues ...interface{}) {}, + ErrorFunc: func(msg string, keysAndValues ...interface{}) {}, + } + + t.Run("SuccessfulForceSet", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + key := "test:key" + value := "force-value" + + mockClient := mock.NewClient(ctrl) + + // ForceSet uses Do to set the value directly + mockClient.EXPECT(). + Do(gomock.Any(), gomock.Any()). 
+ Return(mock.Result(mock.RedisString("OK"))) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + } + + err := pca.ForceSet(ctx, time.Minute, key, value) + + require.NoError(t, err) + }) +} + +// TestPrimeableCacheAside_ForceSetMulti tests ForceSetMulti() scenarios using subtests +func TestPrimeableCacheAside_ForceSetMulti(t *testing.T) { + mockLogger := &mocklogger.MockLogger{ + DebugFunc: func(msg string, keysAndValues ...interface{}) {}, + ErrorFunc: func(msg string, keysAndValues ...interface{}) {}, + } + + t.Run("SuccessfulForceSetMulti", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + values := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + mockClient := mock.NewClient(ctrl) + + // ForceSetMulti uses DoMulti for batch SET + mockClient.EXPECT(). + DoMulti(gomock.Any(), gomock.Any()). + Return([]rueidis.RedisResult{ + mock.Result(mock.RedisString("OK")), + mock.Result(mock.RedisString("OK")), + mock.Result(mock.RedisString("OK")), + }) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + } + + err := pca.ForceSetMulti(ctx, time.Minute, values) + + require.NoError(t, err) + }) + + t.Run("PartialFailure", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := t.Context() + values := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + mockClient := mock.NewClient(ctrl) + + // Some keys succeed, one fails + mockClient.EXPECT(). + DoMulti(gomock.Any(), gomock.Any()). + Return([]rueidis.RedisResult{ + mock.Result(mock.RedisString("OK")), + mock.ErrorResult(errors.New("redis error")), + mock.Result(mock.RedisString("OK")), + }) + + mockLockManager := &mocklockmanager.MockLockManager{ + LockPrefixFunc: func() string { return "__redcache:lock:" }, + } + + pca := &PrimeableCacheAside{ + CacheAside: &CacheAside{ + client: mockClient, + lockManager: mockLockManager, + lockTTL: 5 * time.Second, + logger: mockLogger, + maxRetries: 3, + }, + } + + err := pca.ForceSetMulti(ctx, time.Minute, values) + + // ForceSetMulti should return error if any key fails + require.Error(t, err) + }) +} diff --git a/test_helpers_test.go b/test_helpers_test.go index 6b8a62d..2a56aa4 100644 --- a/test_helpers_test.go +++ b/test_helpers_test.go @@ -10,6 +10,9 @@ import ( "github.com/dcbickfo/redcache/internal/cmdx" ) +// Shared test configuration variables +var addr = []string{"localhost:6379"} + // assertValueEquals checks if the actual value matches the expected value. // This helper reduces boilerplate in tests that need to verify returned values. func assertValueEquals(t *testing.T, expected, actual string, msgAndArgs ...interface{}) bool {
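A note on the large mechanical change in the integration-test hunks above: every `context.Background()` becomes `t.Context()`. The sketch below is not part of the patch; it only illustrates why that swap is useful, assuming a toolchain where `(*testing.T).Context` exists (Go 1.24+). The package and test names are made up for the illustration.

```go
// Minimal sketch: t.Context() vs context.Background() in tests.
package example

import (
	"context"
	"testing"
	"time"
)

func TestScopedContext(t *testing.T) {
	ctx := t.Context()

	// Derived timeouts behave exactly as they would with context.Background().
	shortCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
	defer cancel()
	<-shortCtx.Done() // fires after ~100ms, as usual

	// t.Context() is documented to be canceled just before Cleanup-registered
	// functions run, so goroutines or Redis calls still waiting on it are
	// released when the test ends instead of leaking into later tests.
	t.Cleanup(func() {
		if ctx.Err() == nil {
			t.Error("expected t.Context() to already be canceled during cleanup")
		}
	})
}
```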
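The new unit test file configures mock behaviour through `...Func` fields (e.g. `AcquireWriteLockFunc`) and then asserts on recorded invocations via `...Calls()` accessors (e.g. `AcquireWriteLockCalls()`, `ReleaseWriteLockCalls()`). For reviewers unfamiliar with that moq-style convention, here is a minimal hand-written sketch of the shape such a mock takes; the `KVStore` interface is invented purely for illustration and is not the generated mock code used in the tests.

```go
// Minimal hand-rolled sketch of a moq-style mock for a hypothetical interface.
package main

import (
	"context"
	"fmt"
	"sync"
)

// KVStore is a stand-in interface used only for this illustration.
type KVStore interface {
	Put(ctx context.Context, key, value string) error
}

// MockKVStore mirrors the moq convention: behaviour is injected via PutFunc,
// and every invocation is recorded so tests can inspect arguments and counts.
type MockKVStore struct {
	PutFunc func(ctx context.Context, key, value string) error

	mu       sync.Mutex
	putCalls []struct{ Key, Value string }
}

func (m *MockKVStore) Put(ctx context.Context, key, value string) error {
	m.mu.Lock()
	m.putCalls = append(m.putCalls, struct{ Key, Value string }{key, value})
	m.mu.Unlock()
	return m.PutFunc(ctx, key, value)
}

// PutCalls returns the recorded invocations, analogous to the
// AcquireWriteLockCalls() assertions in the unit tests above.
func (m *MockKVStore) PutCalls() []struct{ Key, Value string } {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.putCalls
}

func main() {
	mock := &MockKVStore{
		PutFunc: func(ctx context.Context, key, value string) error { return nil },
	}
	_ = mock.Put(context.Background(), "k", "v")
	fmt.Println(len(mock.PutCalls())) // 1
}
```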