From 4274cffa9d1ef8f0b3c5bbe39eac856844205753 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 29 Nov 2025 00:46:06 +0200 Subject: [PATCH 1/5] fix: double freeturn bug --- config_comparison_test.go | 226 ++++++++++++++++++ internal/pool/double_freeturn_simple_test.go | 158 +++++++++++++ internal/pool/double_freeturn_test.go | 229 +++++++++++++++++++ internal/pool/pool.go | 38 ++- internal/pool/want_conn.go | 9 + redis.go | 9 +- 6 files changed, 657 insertions(+), 12 deletions(-) create mode 100644 config_comparison_test.go create mode 100644 internal/pool/double_freeturn_simple_test.go create mode 100644 internal/pool/double_freeturn_test.go diff --git a/config_comparison_test.go b/config_comparison_test.go new file mode 100644 index 000000000..b1cf9a76a --- /dev/null +++ b/config_comparison_test.go @@ -0,0 +1,226 @@ +package redis + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestBadConfigurationHighLoad demonstrates the problem with default configuration +// under high load with slow dials. +func TestBadConfigurationHighLoad(t *testing.T) { + var dialCount atomic.Int32 + var dialsFailed atomic.Int32 + var dialsSucceeded atomic.Int32 + + // Simulate slow network - 300ms per dial (e.g., network latency, TLS handshake) + slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + dialCount.Add(1) + select { + case <-time.After(300 * time.Millisecond): + dialsSucceeded.Add(1) + return &net.TCPConn{}, nil + case <-ctx.Done(): + dialsFailed.Add(1) + return nil, ctx.Err() + } + } + + // BAD CONFIGURATION: Default settings + // On an 8-CPU machine: + // - PoolSize = 10 * 8 = 80 + // - MaxConcurrentDials = 80 + // - MinIdleConns = 0 (no pre-warming) + opt := &Options{ + Addr: "localhost:6379", + Dialer: slowDialer, + PoolSize: 80, // Default: 10 * GOMAXPROCS + MaxConcurrentDials: 80, // Default: same as PoolSize + MinIdleConns: 0, // Default: no pre-warming + DialTimeout: 5 * time.Second, + } + + client := NewClient(opt) + defer client.Close() + + // Simulate high load: 200 concurrent requests with 200ms timeout + // This simulates a burst of traffic (e.g., after a deployment or cache miss) + const numRequests = 200 + const requestTimeout = 200 * time.Millisecond + + var wg sync.WaitGroup + var timeouts atomic.Int32 + var successes atomic.Int32 + var errors atomic.Int32 + + startTime := time.Now() + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + + _, err := client.Get(ctx, fmt.Sprintf("key-%d", id)).Result() + + if err != nil { + if ctx.Err() == context.DeadlineExceeded || err == context.DeadlineExceeded { + timeouts.Add(1) + } else { + errors.Add(1) + } + } else { + successes.Add(1) + } + }(i) + + // Stagger requests slightly to simulate real traffic + if i%20 == 0 { + time.Sleep(5 * time.Millisecond) + } + } + + wg.Wait() + totalTime := time.Since(startTime) + + timeoutRate := float64(timeouts.Load()) / float64(numRequests) * 100 + successRate := float64(successes.Load()) / float64(numRequests) * 100 + + t.Logf("\n=== BAD CONFIGURATION (Default Settings) ===") + t.Logf("Configuration:") + t.Logf(" PoolSize: %d", opt.PoolSize) + t.Logf(" MaxConcurrentDials: %d", opt.MaxConcurrentDials) + t.Logf(" MinIdleConns: %d", opt.MinIdleConns) + t.Logf("\nResults:") + t.Logf(" Total time: %v", totalTime) + t.Logf(" Successes: %d (%.1f%%)", successes.Load(), successRate) + t.Logf(" Timeouts: 
%d (%.1f%%)", timeouts.Load(), timeoutRate) + t.Logf(" Other errors: %d", errors.Load()) + t.Logf(" Total dials: %d (succeeded: %d, failed: %d)", + dialCount.Load(), dialsSucceeded.Load(), dialsFailed.Load()) + + // With bad configuration: + // - MaxConcurrentDials=80 means only 80 dials can run concurrently + // - Each dial takes 300ms, but request timeout is 200ms + // - Requests timeout waiting for dial slots + // - Expected: High timeout rate (>50%) + + if timeoutRate < 50 { + t.Logf("WARNING: Expected high timeout rate (>50%%), got %.1f%%. Test may not be stressing the system enough.", timeoutRate) + } +} + +// TestGoodConfigurationHighLoad demonstrates how proper configuration fixes the problem +func TestGoodConfigurationHighLoad(t *testing.T) { + var dialCount atomic.Int32 + var dialsFailed atomic.Int32 + var dialsSucceeded atomic.Int32 + + // Same slow dialer - 300ms per dial + slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + dialCount.Add(1) + select { + case <-time.After(300 * time.Millisecond): + dialsSucceeded.Add(1) + return &net.TCPConn{}, nil + case <-ctx.Done(): + dialsFailed.Add(1) + return nil, ctx.Err() + } + } + + // GOOD CONFIGURATION: Tuned for high load + opt := &Options{ + Addr: "localhost:6379", + Dialer: slowDialer, + PoolSize: 300, // Increased from 80 + MaxConcurrentDials: 300, // Increased from 80 + MinIdleConns: 50, // Pre-warm the pool + DialTimeout: 5 * time.Second, + } + + client := NewClient(opt) + defer client.Close() + + // Wait for pool to warm up + time.Sleep(100 * time.Millisecond) + + // Same load: 200 concurrent requests with 200ms timeout + const numRequests = 200 + const requestTimeout = 200 * time.Millisecond + + var wg sync.WaitGroup + var timeouts atomic.Int32 + var successes atomic.Int32 + var errors atomic.Int32 + + startTime := time.Now() + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + + _, err := client.Get(ctx, fmt.Sprintf("key-%d", id)).Result() + + if err != nil { + if ctx.Err() == context.DeadlineExceeded || err == context.DeadlineExceeded { + timeouts.Add(1) + } else { + errors.Add(1) + } + } else { + successes.Add(1) + } + }(i) + + // Stagger requests slightly + if i%20 == 0 { + time.Sleep(5 * time.Millisecond) + } + } + + wg.Wait() + totalTime := time.Since(startTime) + + timeoutRate := float64(timeouts.Load()) / float64(numRequests) * 100 + successRate := float64(successes.Load()) / float64(numRequests) * 100 + + t.Logf("\n=== GOOD CONFIGURATION (Tuned Settings) ===") + t.Logf("Configuration:") + t.Logf(" PoolSize: %d", opt.PoolSize) + t.Logf(" MaxConcurrentDials: %d", opt.MaxConcurrentDials) + t.Logf(" MinIdleConns: %d", opt.MinIdleConns) + t.Logf("\nResults:") + t.Logf(" Total time: %v", totalTime) + t.Logf(" Successes: %d (%.1f%%)", successes.Load(), successRate) + t.Logf(" Timeouts: %d (%.1f%%)", timeouts.Load(), timeoutRate) + t.Logf(" Other errors: %d", errors.Load()) + t.Logf(" Total dials: %d (succeeded: %d, failed: %d)", + dialCount.Load(), dialsSucceeded.Load(), dialsFailed.Load()) + + // With good configuration: + // - Higher MaxConcurrentDials allows more concurrent dials + // - MinIdleConns pre-warms the pool + // - Expected: Low timeout rate (<20%) + + if timeoutRate > 20 { + t.Errorf("Expected low timeout rate (<20%%), got %.1f%%", timeoutRate) + } +} + +// TestConfigurationComparison runs both tests and shows the difference +func 
TestConfigurationComparison(t *testing.T) { + t.Run("BadConfiguration", TestBadConfigurationHighLoad) + t.Run("GoodConfiguration", TestGoodConfigurationHighLoad) +} + diff --git a/internal/pool/double_freeturn_simple_test.go b/internal/pool/double_freeturn_simple_test.go new file mode 100644 index 000000000..3cfbff3e3 --- /dev/null +++ b/internal/pool/double_freeturn_simple_test.go @@ -0,0 +1,158 @@ +package pool_test + +import ( + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// TestDoubleFreeTurnSimple tests the double-free bug with a simple scenario. +// This test FAILS with the OLD code and PASSES with the NEW code. +// +// Scenario: +// 1. Request A times out, Dial A completes and delivers connection to Request B +// 2. Request B's own Dial B completes later +// 3. With the bug: Dial B frees Request B's turn (even though Request B is using connection A) +// 4. Then Request B calls Put() and frees the turn AGAIN (double-free) +// 5. This allows more concurrent operations than PoolSize permits +// +// Detection method: +// - Try to acquire PoolSize+1 connections after the double-free +// - With the bug: All succeed (pool size violated) +// - With the fix: Only PoolSize succeed +func TestDoubleFreeTurnSimple(t *testing.T) { + ctx := context.Background() + + var dialCount atomic.Int32 + dialBComplete := make(chan struct{}) + requestBGotConn := make(chan struct{}) + requestBCalledPut := make(chan struct{}) + + controlledDialer := func(ctx context.Context) (net.Conn, error) { + count := dialCount.Add(1) + + if count == 1 { + // Dial A: takes 150ms + time.Sleep(150 * time.Millisecond) + t.Logf("Dial A completed") + } else if count == 2 { + // Dial B: takes 300ms (longer than Dial A) + time.Sleep(300 * time.Millisecond) + t.Logf("Dial B completed") + close(dialBComplete) + } else { + // Other dials: fast + time.Sleep(10 * time.Millisecond) + } + + return newDummyConn(), nil + } + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: controlledDialer, + PoolSize: 2, // Only 2 concurrent operations allowed + MaxConcurrentDials: 5, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + }) + defer testPool.Close() + + // Request A: Short timeout (100ms), will timeout before dial completes (150ms) + go func() { + shortCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + _, err := testPool.Get(shortCtx) + if err != nil { + t.Logf("Request A: Timed out as expected: %v", err) + } + }() + + // Wait for Request A to start + time.Sleep(20 * time.Millisecond) + + // Request B: Long timeout, will receive connection from Request A's dial + requestBDone := make(chan struct{}) + go func() { + defer close(requestBDone) + + longCtx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + cn, err := testPool.Get(longCtx) + if err != nil { + t.Errorf("Request B: Should have received connection but got error: %v", err) + return + } + + t.Logf("Request B: Got connection from Request A's dial") + close(requestBGotConn) + + // Wait for dial B to complete + <-dialBComplete + + t.Logf("Request B: Dial B completed") + + // Wait a bit to allow Dial B goroutine to finish and call freeTurn() + time.Sleep(100 * time.Millisecond) + + // Signal that we're ready for the test to check semaphore state + close(requestBCalledPut) + + // Wait for the test to check QueueLen + time.Sleep(200 * time.Millisecond) + + t.Logf("Request B: Now calling Put()") + testPool.Put(ctx, cn) + t.Logf("Request B: Put() called") 
+ }() + + // Wait for Request B to get the connection + <-requestBGotConn + + // Wait for Dial B to complete and freeTurn() to be called + <-requestBCalledPut + + // NOW WE'RE IN THE CRITICAL WINDOW + // Request B is holding a connection (from Dial A) + // Dial B has completed and returned (freeTurn() has been called) + // With the bug: + // - Dial B freed Request B's turn (BUG!) + // - QueueLen should be 0 + // With the fix: + // - Dial B did NOT free Request B's turn + // - QueueLen should be 1 (Request B still holds the turn) + + t.Logf("\n=== CRITICAL CHECK: QueueLen ===") + t.Logf("Request B is holding a connection, Dial B has completed and returned") + queueLen := testPool.QueueLen() + t.Logf("QueueLen: %d", queueLen) + + // Wait for Request B to finish + select { + case <-requestBDone: + case <-time.After(1 * time.Second): + t.Logf("Request B timed out") + } + + t.Logf("\n=== Results ===") + t.Logf("QueueLen during critical window: %d", queueLen) + t.Logf("Expected with fix: 1 (Request B still holds the turn)") + t.Logf("Expected with bug: 0 (Dial B freed Request B's turn)") + + if queueLen == 0 { + t.Errorf("DOUBLE-FREE BUG DETECTED!") + t.Errorf("QueueLen is 0, meaning Dial B freed Request B's turn") + t.Errorf("But Request B is still holding a connection, so its turn should NOT be freed yet") + } else if queueLen == 1 { + t.Logf("✓ CORRECT: QueueLen is 1") + t.Logf("Request B is still holding the turn (will be freed when Request B calls Put())") + } else { + t.Logf("Unexpected QueueLen: %d (expected 1 with fix, 0 with bug)", queueLen) + } +} + diff --git a/internal/pool/double_freeturn_test.go b/internal/pool/double_freeturn_test.go new file mode 100644 index 000000000..7c8fca8e6 --- /dev/null +++ b/internal/pool/double_freeturn_test.go @@ -0,0 +1,229 @@ +package pool + +import ( + "context" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +// TestDoubleFreeTurnBug demonstrates the double freeTurn bug where: +// 1. Dial goroutine creates a connection +// 2. Original waiter times out +// 3. putIdleConn delivers connection to another waiter +// 4. Dial goroutine calls freeTurn() (FIRST FREE) +// 5. Second waiter uses connection and calls Put() +// 6. Put() calls freeTurn() (SECOND FREE - BUG!) +// +// This causes the semaphore to be released twice, allowing more concurrent +// operations than PoolSize allows. +func TestDoubleFreeTurnBug(t *testing.T) { + var dialCount atomic.Int32 + var putCount atomic.Int32 + + // Slow dialer - 150ms per dial + slowDialer := func(ctx context.Context) (net.Conn, error) { + dialCount.Add(1) + select { + case <-time.After(150 * time.Millisecond): + server, client := net.Pipe() + go func() { + defer server.Close() + buf := make([]byte, 1024) + for { + _, err := server.Read(buf) + if err != nil { + return + } + } + }() + return client, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + opt := &Options{ + Dialer: slowDialer, + PoolSize: 10, // Small pool to make bug easier to trigger + MaxConcurrentDials: 10, + MinIdleConns: 0, + PoolTimeout: 100 * time.Millisecond, + DialTimeout: 5 * time.Second, + } + + connPool := NewConnPool(opt) + defer connPool.Close() + + // Scenario: + // 1. Request A starts dial (100ms timeout - will timeout before dial completes) + // 2. Request B arrives (500ms timeout - will wait in queue) + // 3. Request A times out at 100ms + // 4. Dial completes at 150ms + // 5. putIdleConn delivers connection to Request B + // 6. Dial goroutine calls freeTurn() - FIRST FREE + // 7. 
Request B uses connection and calls Put() + // 8. Put() calls freeTurn() - SECOND FREE (BUG!) + + var wg sync.WaitGroup + + // Request A: Short timeout, will timeout before dial completes + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + cn, err := connPool.Get(ctx) + if err != nil { + // Expected to timeout + t.Logf("Request A timed out as expected: %v", err) + } else { + // Should not happen + t.Errorf("Request A should have timed out but got connection") + connPool.Put(ctx, cn) + putCount.Add(1) + } + }() + + // Wait a bit for Request A to start dialing + time.Sleep(10 * time.Millisecond) + + // Request B: Long timeout, will receive the connection from putIdleConn + wg.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + cn, err := connPool.Get(ctx) + if err != nil { + t.Errorf("Request B should have succeeded but got error: %v", err) + } else { + t.Logf("Request B got connection successfully") + // Use the connection briefly + time.Sleep(50 * time.Millisecond) + connPool.Put(ctx, cn) + putCount.Add(1) + } + }() + + wg.Wait() + + // Check results + t.Logf("\n=== Results ===") + t.Logf("Dials: %d", dialCount.Load()) + t.Logf("Puts: %d", putCount.Load()) + + // The bug is hard to detect directly without instrumenting freeTurn, + // but we can verify the scenario works correctly: + // - Request A should timeout + // - Request B should succeed and get the connection + // - 1-2 dials may occur (Request A starts one, Request B may start another) + // - 1 put should occur (Request B returning the connection) + + if putCount.Load() != 1 { + t.Errorf("Expected 1 put, got %d", putCount.Load()) + } + + t.Logf("✓ Scenario completed successfully") + t.Logf("Note: The double freeTurn bug would cause semaphore to be released twice,") + t.Logf("allowing more concurrent operations than PoolSize permits.") + t.Logf("With the fix, putIdleConn returns true when delivering to a waiter,") + t.Logf("preventing the dial goroutine from calling freeTurn (waiter will call it later).") +} + +// TestDoubleFreeTurnHighConcurrency tests the bug under high concurrency +func TestDoubleFreeTurnHighConcurrency(t *testing.T) { + var dialCount atomic.Int32 + var getSuccesses atomic.Int32 + var getFailures atomic.Int32 + + slowDialer := func(ctx context.Context) (net.Conn, error) { + dialCount.Add(1) + select { + case <-time.After(200 * time.Millisecond): + server, client := net.Pipe() + go func() { + defer server.Close() + buf := make([]byte, 1024) + for { + _, err := server.Read(buf) + if err != nil { + return + } + } + }() + return client, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + opt := &Options{ + Dialer: slowDialer, + PoolSize: 20, + MaxConcurrentDials: 20, + MinIdleConns: 0, + PoolTimeout: 100 * time.Millisecond, + DialTimeout: 5 * time.Second, + } + + connPool := NewConnPool(opt) + defer connPool.Close() + + // Create many requests with varying timeouts + // Some will timeout before dial completes, triggering the putIdleConn delivery path + const numRequests = 100 + var wg sync.WaitGroup + + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Vary timeout: some short (will timeout), some long (will succeed) + timeout := 100 * time.Millisecond + if id%3 == 0 { + timeout = 500 * time.Millisecond + } + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer 
cancel() + + cn, err := connPool.Get(ctx) + if err != nil { + getFailures.Add(1) + } else { + getSuccesses.Add(1) + time.Sleep(10 * time.Millisecond) + connPool.Put(ctx, cn) + } + }(i) + + // Stagger requests + if i%10 == 0 { + time.Sleep(5 * time.Millisecond) + } + } + + wg.Wait() + + t.Logf("\n=== High Concurrency Results ===") + t.Logf("Requests: %d", numRequests) + t.Logf("Successes: %d", getSuccesses.Load()) + t.Logf("Failures: %d", getFailures.Load()) + t.Logf("Dials: %d", dialCount.Load()) + + // Verify that some requests succeeded despite timeouts + // This exercises the putIdleConn delivery path + if getSuccesses.Load() == 0 { + t.Errorf("Expected some successful requests, got 0") + } + + t.Logf("✓ High concurrency test completed") + t.Logf("Note: This test exercises the putIdleConn delivery path where the bug occurs") +} + diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 184321c18..5ca6a29b3 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -558,6 +558,10 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { return nil, ctx.Err() } + // Use context.Background() as parent to allow orphaned connection reuse. + // If the requester times out, the dial continues and the connection can be + // delivered to other waiters via putIdleConn() -> dialsQueue.dequeue(). + // This is more efficient than canceling dials when individual requesters timeout. dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) w := &wantConn{ @@ -568,10 +572,13 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { var err error defer func() { if err != nil { + // Request failed/timed out + // If dial completed before timeout, try to deliver connection to other waiters if cn := w.cancel(); cn != nil { p.putIdleConn(ctx, cn) - p.freeTurn() + // freeTurn will be called by the dial goroutine or by the waiter who receives the connection } + // If dial hasn't completed yet, freeTurn will be called by the dial goroutine } }() @@ -595,14 +602,31 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { cn, cnErr := p.newConn(dialCtx, true) delivered := w.tryDeliver(cn, cnErr) if cnErr == nil && delivered { + // Connection delivered to original waiter - they will free the turn when done return } else if cnErr == nil && !delivered { + // Original waiter gave up or got another connection + // Try to deliver to other waiters or add to idle pool p.putIdleConn(dialCtx, cn) - p.freeTurn() - freeTurnCalled = true + + // Free the turn only if the original waiter did NOT get a connection + // If waiter got a connection (from another dial), they will free the turn when they call Put() + // If waiter did not get a connection (timed out), we should free the turn now + if !w.waiterGotConn() { + p.freeTurn() + freeTurnCalled = true + } else { + // Waiter got a connection from another dial - they will free the turn + freeTurnCalled = true // Mark as handled to prevent defer from freeing + } } else { - p.freeTurn() - freeTurnCalled = true + // Dial failed - free the turn only if waiter didn't get a connection + if !w.waiterGotConn() { + p.freeTurn() + freeTurnCalled = true + } else { + freeTurnCalled = true // Mark as handled + } } }(w) @@ -616,6 +640,8 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { } } +// putIdleConn tries to deliver the connection to waiting goroutines in the queue. +// If no waiters are available, adds the connection to the idle pool. 
func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) { for { w, ok := p.dialsQueue.dequeue() @@ -623,7 +649,7 @@ func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) { break } if w.tryDeliver(cn, nil) { - return + return // Connection delivered to waiter } } diff --git a/internal/pool/want_conn.go b/internal/pool/want_conn.go index 6f9e4bfa9..d5321fff1 100644 --- a/internal/pool/want_conn.go +++ b/internal/pool/want_conn.go @@ -3,6 +3,7 @@ package pool import ( "context" "sync" + "sync/atomic" ) type wantConn struct { @@ -10,6 +11,7 @@ type wantConn struct { ctx context.Context // context for dial, cleared after delivered or canceled cancelCtx context.CancelFunc done bool // true after delivered or canceled + gotConn atomic.Bool // true if waiter received a connection (not an error) result chan wantConnResult // channel to deliver connection or error } @@ -29,6 +31,7 @@ func (w *wantConn) tryDeliver(cn *Conn, err error) bool { } w.done = true + w.gotConn.Store(cn != nil && err == nil) w.ctx = nil w.result <- wantConnResult{cn: cn, err: err} @@ -57,6 +60,12 @@ func (w *wantConn) cancel() *Conn { return cn } +// waiterGotConn returns true if the waiter received a connection (not an error). +// This is used by the dial goroutine to determine if it should free the turn. +func (w *wantConn) waiterGotConn() bool { + return w.gotConn.Load() +} + type wantConnResult struct { cn *Conn err error diff --git a/redis.go b/redis.go index 73342e67b..bdf8e0fc6 100644 --- a/redis.go +++ b/redis.go @@ -399,13 +399,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if finalState == pool.StateInitializing { // Another goroutine is initializing - WAIT for it to complete - // Use AwaitAndTransition to wait for IDLE or IN_USE state - // use DialTimeout as the timeout for the wait - waitCtx, cancel := context.WithTimeout(ctx, c.opt.DialTimeout) - defer cancel() - + // Use the request context directly to respect the caller's timeout + // This prevents goroutines from waiting longer than their request timeout finalState, err := cn.GetStateMachine().AwaitAndTransition( - waitCtx, + ctx, []pool.ConnState{pool.StateIdle, pool.StateInUse}, pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op) ) From c80f50b528dc6169f1a74947d62a37b49087ff5a Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 29 Nov 2025 00:54:00 +0200 Subject: [PATCH 2/5] use min timeout to avoid waiting for too long --- config_comparison_test.go | 226 -------------------------------------- redis.go | 29 ++++- 2 files changed, 26 insertions(+), 229 deletions(-) delete mode 100644 config_comparison_test.go diff --git a/config_comparison_test.go b/config_comparison_test.go deleted file mode 100644 index b1cf9a76a..000000000 --- a/config_comparison_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package redis - -import ( - "context" - "fmt" - "net" - "sync" - "sync/atomic" - "testing" - "time" -) - -// TestBadConfigurationHighLoad demonstrates the problem with default configuration -// under high load with slow dials. 
-func TestBadConfigurationHighLoad(t *testing.T) { - var dialCount atomic.Int32 - var dialsFailed atomic.Int32 - var dialsSucceeded atomic.Int32 - - // Simulate slow network - 300ms per dial (e.g., network latency, TLS handshake) - slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - dialCount.Add(1) - select { - case <-time.After(300 * time.Millisecond): - dialsSucceeded.Add(1) - return &net.TCPConn{}, nil - case <-ctx.Done(): - dialsFailed.Add(1) - return nil, ctx.Err() - } - } - - // BAD CONFIGURATION: Default settings - // On an 8-CPU machine: - // - PoolSize = 10 * 8 = 80 - // - MaxConcurrentDials = 80 - // - MinIdleConns = 0 (no pre-warming) - opt := &Options{ - Addr: "localhost:6379", - Dialer: slowDialer, - PoolSize: 80, // Default: 10 * GOMAXPROCS - MaxConcurrentDials: 80, // Default: same as PoolSize - MinIdleConns: 0, // Default: no pre-warming - DialTimeout: 5 * time.Second, - } - - client := NewClient(opt) - defer client.Close() - - // Simulate high load: 200 concurrent requests with 200ms timeout - // This simulates a burst of traffic (e.g., after a deployment or cache miss) - const numRequests = 200 - const requestTimeout = 200 * time.Millisecond - - var wg sync.WaitGroup - var timeouts atomic.Int32 - var successes atomic.Int32 - var errors atomic.Int32 - - startTime := time.Now() - - for i := 0; i < numRequests; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - - ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) - defer cancel() - - _, err := client.Get(ctx, fmt.Sprintf("key-%d", id)).Result() - - if err != nil { - if ctx.Err() == context.DeadlineExceeded || err == context.DeadlineExceeded { - timeouts.Add(1) - } else { - errors.Add(1) - } - } else { - successes.Add(1) - } - }(i) - - // Stagger requests slightly to simulate real traffic - if i%20 == 0 { - time.Sleep(5 * time.Millisecond) - } - } - - wg.Wait() - totalTime := time.Since(startTime) - - timeoutRate := float64(timeouts.Load()) / float64(numRequests) * 100 - successRate := float64(successes.Load()) / float64(numRequests) * 100 - - t.Logf("\n=== BAD CONFIGURATION (Default Settings) ===") - t.Logf("Configuration:") - t.Logf(" PoolSize: %d", opt.PoolSize) - t.Logf(" MaxConcurrentDials: %d", opt.MaxConcurrentDials) - t.Logf(" MinIdleConns: %d", opt.MinIdleConns) - t.Logf("\nResults:") - t.Logf(" Total time: %v", totalTime) - t.Logf(" Successes: %d (%.1f%%)", successes.Load(), successRate) - t.Logf(" Timeouts: %d (%.1f%%)", timeouts.Load(), timeoutRate) - t.Logf(" Other errors: %d", errors.Load()) - t.Logf(" Total dials: %d (succeeded: %d, failed: %d)", - dialCount.Load(), dialsSucceeded.Load(), dialsFailed.Load()) - - // With bad configuration: - // - MaxConcurrentDials=80 means only 80 dials can run concurrently - // - Each dial takes 300ms, but request timeout is 200ms - // - Requests timeout waiting for dial slots - // - Expected: High timeout rate (>50%) - - if timeoutRate < 50 { - t.Logf("WARNING: Expected high timeout rate (>50%%), got %.1f%%. 
Test may not be stressing the system enough.", timeoutRate) - } -} - -// TestGoodConfigurationHighLoad demonstrates how proper configuration fixes the problem -func TestGoodConfigurationHighLoad(t *testing.T) { - var dialCount atomic.Int32 - var dialsFailed atomic.Int32 - var dialsSucceeded atomic.Int32 - - // Same slow dialer - 300ms per dial - slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { - dialCount.Add(1) - select { - case <-time.After(300 * time.Millisecond): - dialsSucceeded.Add(1) - return &net.TCPConn{}, nil - case <-ctx.Done(): - dialsFailed.Add(1) - return nil, ctx.Err() - } - } - - // GOOD CONFIGURATION: Tuned for high load - opt := &Options{ - Addr: "localhost:6379", - Dialer: slowDialer, - PoolSize: 300, // Increased from 80 - MaxConcurrentDials: 300, // Increased from 80 - MinIdleConns: 50, // Pre-warm the pool - DialTimeout: 5 * time.Second, - } - - client := NewClient(opt) - defer client.Close() - - // Wait for pool to warm up - time.Sleep(100 * time.Millisecond) - - // Same load: 200 concurrent requests with 200ms timeout - const numRequests = 200 - const requestTimeout = 200 * time.Millisecond - - var wg sync.WaitGroup - var timeouts atomic.Int32 - var successes atomic.Int32 - var errors atomic.Int32 - - startTime := time.Now() - - for i := 0; i < numRequests; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - - ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) - defer cancel() - - _, err := client.Get(ctx, fmt.Sprintf("key-%d", id)).Result() - - if err != nil { - if ctx.Err() == context.DeadlineExceeded || err == context.DeadlineExceeded { - timeouts.Add(1) - } else { - errors.Add(1) - } - } else { - successes.Add(1) - } - }(i) - - // Stagger requests slightly - if i%20 == 0 { - time.Sleep(5 * time.Millisecond) - } - } - - wg.Wait() - totalTime := time.Since(startTime) - - timeoutRate := float64(timeouts.Load()) / float64(numRequests) * 100 - successRate := float64(successes.Load()) / float64(numRequests) * 100 - - t.Logf("\n=== GOOD CONFIGURATION (Tuned Settings) ===") - t.Logf("Configuration:") - t.Logf(" PoolSize: %d", opt.PoolSize) - t.Logf(" MaxConcurrentDials: %d", opt.MaxConcurrentDials) - t.Logf(" MinIdleConns: %d", opt.MinIdleConns) - t.Logf("\nResults:") - t.Logf(" Total time: %v", totalTime) - t.Logf(" Successes: %d (%.1f%%)", successes.Load(), successRate) - t.Logf(" Timeouts: %d (%.1f%%)", timeouts.Load(), timeoutRate) - t.Logf(" Other errors: %d", errors.Load()) - t.Logf(" Total dials: %d (succeeded: %d, failed: %d)", - dialCount.Load(), dialsSucceeded.Load(), dialsFailed.Load()) - - // With good configuration: - // - Higher MaxConcurrentDials allows more concurrent dials - // - MinIdleConns pre-warms the pool - // - Expected: Low timeout rate (<20%) - - if timeoutRate > 20 { - t.Errorf("Expected low timeout rate (<20%%), got %.1f%%", timeoutRate) - } -} - -// TestConfigurationComparison runs both tests and shows the difference -func TestConfigurationComparison(t *testing.T) { - t.Run("BadConfiguration", TestBadConfigurationHighLoad) - t.Run("GoodConfiguration", TestGoodConfigurationHighLoad) -} - diff --git a/redis.go b/redis.go index bdf8e0fc6..fc158e8d1 100644 --- a/redis.go +++ b/redis.go @@ -399,10 +399,33 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { if finalState == pool.StateInitializing { // Another goroutine is initializing - WAIT for it to complete - // Use the request context directly to respect the caller's timeout - // This prevents goroutines from 
waiting longer than their request timeout + // Use a context with timeout = min(remaining command timeout, DialTimeout) + // This prevents waiting too long while respecting the caller's deadline + waitCtx := ctx + dialTimeout := c.opt.DialTimeout + + if cmdDeadline, hasCmdDeadline := ctx.Deadline(); hasCmdDeadline { + // Calculate remaining time until command deadline + remainingTime := time.Until(cmdDeadline) + // Use the minimum of remaining time and DialTimeout + if remainingTime < dialTimeout { + // Command deadline is sooner, use it + waitCtx = ctx + } else { + // DialTimeout is shorter, cap the wait at DialTimeout + var cancel context.CancelFunc + waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) + defer cancel() + } + } else { + // No command deadline, use DialTimeout to prevent waiting indefinitely + var cancel context.CancelFunc + waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) + defer cancel() + } + finalState, err := cn.GetStateMachine().AwaitAndTransition( - ctx, + waitCtx, []pool.ConnState{pool.StateIdle, pool.StateInUse}, pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op) ) From dc13aefe66c6b1eafb24e7aadfd72217b9ea293d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 29 Nov 2025 01:12:18 +0200 Subject: [PATCH 3/5] make linter happy --- redis.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/redis.go b/redis.go index fc158e8d1..f63933b0e 100644 --- a/redis.go +++ b/redis.go @@ -401,7 +401,8 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // Another goroutine is initializing - WAIT for it to complete // Use a context with timeout = min(remaining command timeout, DialTimeout) // This prevents waiting too long while respecting the caller's deadline - waitCtx := ctx + var waitCtx context.Context + var cancel context.CancelFunc dialTimeout := c.opt.DialTimeout if cmdDeadline, hasCmdDeadline := ctx.Deadline(); hasCmdDeadline { @@ -413,13 +414,11 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { waitCtx = ctx } else { // DialTimeout is shorter, cap the wait at DialTimeout - var cancel context.CancelFunc waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) defer cancel() } } else { // No command deadline, use DialTimeout to prevent waiting indefinitely - var cancel context.CancelFunc waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) defer cancel() } From 2890d5139ca33a22ef6118fc7e34391ab618f36f Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Sat, 29 Nov 2025 01:47:15 +0200 Subject: [PATCH 4/5] Update redis.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- redis.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redis.go b/redis.go index f63933b0e..a6a710677 100644 --- a/redis.go +++ b/redis.go @@ -415,11 +415,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } else { // DialTimeout is shorter, cap the wait at DialTimeout waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) - defer cancel() } } else { // No command deadline, use DialTimeout to prevent waiting indefinitely waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) + } + if cancel != nil { defer cancel() } From 9710b24e3e7d0045f88e886195ef21dd9c6023e2 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Sat, 29 Nov 2025 02:32:34 +0200 Subject: [PATCH 5/5] resolve semaphore leak --- internal/pool/pool.go | 6 +- internal/pool/race_freeturn_test.go | 146 
++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 internal/pool/race_freeturn_test.go diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 5ca6a29b3..3da15ccd0 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -576,7 +576,11 @@ func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { // If dial completed before timeout, try to deliver connection to other waiters if cn := w.cancel(); cn != nil { p.putIdleConn(ctx, cn) - // freeTurn will be called by the dial goroutine or by the waiter who receives the connection + // Free the turn since: + // - Dial goroutine returned thinking delivery succeeded (tryDeliver returned true) + // - Original waiter won't call Put() (they got an error, not a connection) + // - Another waiter will receive this connection but won't free this turn + p.freeTurn() } // If dial hasn't completed yet, freeTurn will be called by the dial goroutine } diff --git a/internal/pool/race_freeturn_test.go b/internal/pool/race_freeturn_test.go new file mode 100644 index 000000000..f998ddac6 --- /dev/null +++ b/internal/pool/race_freeturn_test.go @@ -0,0 +1,146 @@ +package pool_test + +import ( + "context" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// TestRaceConditionFreeTurn tests the race condition where: +// 1. Dial completes and tryDeliver succeeds +// 2. Waiter's context times out before receiving from result channel +// 3. Waiter's defer retrieves connection via cancel() and delivers to another waiter +// 4. Turn must be freed by the defer, not by dial goroutine or new waiter +func TestRaceConditionFreeTurn(t *testing.T) { + // Create a pool with PoolSize=2 to make the race easier to trigger + opt := &pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + // Slow dial to allow context timeout to race with delivery + time.Sleep(50 * time.Millisecond) + return dummyDialer(ctx) + }, + PoolSize: 2, + MaxConcurrentDials: 2, + DialTimeout: 1 * time.Second, + PoolTimeout: 1 * time.Second, + } + + connPool := pool.NewConnPool(opt) + defer connPool.Close() + + // Run multiple iterations to increase chance of hitting the race + for iteration := 0; iteration < 10; iteration++ { + // Request 1: Will timeout quickly + ctx1, cancel1 := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancel1() + + var wg sync.WaitGroup + wg.Add(2) + + // Goroutine 1: Request with short timeout + go func() { + defer wg.Done() + cn, err := connPool.Get(ctx1) + if err == nil { + // If we got a connection, put it back + connPool.Put(ctx1, cn) + } + // Expected: context deadline exceeded + }() + + // Goroutine 2: Request with longer timeout, should receive the orphaned connection + go func() { + defer wg.Done() + time.Sleep(20 * time.Millisecond) // Start slightly after first request + ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel2() + + cn, err := connPool.Get(ctx2) + if err != nil { + t.Logf("Request 2 error: %v", err) + return + } + // Got connection, put it back + connPool.Put(ctx2, cn) + }() + + wg.Wait() + + // Give some time for all operations to complete + time.Sleep(100 * time.Millisecond) + + // Check QueueLen - should be 0 (all turns freed) + queueLen := connPool.QueueLen() + if queueLen != 0 { + t.Errorf("Iteration %d: QueueLen = %d, expected 0 (turn leak detected!)", iteration, queueLen) + } + } +} + +// TestRaceConditionFreeTurnStress is a stress test for 
the race condition +func TestRaceConditionFreeTurnStress(t *testing.T) { + var dialIndex atomic.Int32 + opt := &pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + // Variable dial time to create more race opportunities + // Use atomic increment to avoid data race + idx := dialIndex.Add(1) + time.Sleep(time.Duration(10+idx%40) * time.Millisecond) + return dummyDialer(ctx) + }, + PoolSize: 10, + MaxConcurrentDials: 10, + DialTimeout: 1 * time.Second, + PoolTimeout: 500 * time.Millisecond, + } + + connPool := pool.NewConnPool(opt) + defer connPool.Close() + + const numRequests = 50 + + var wg sync.WaitGroup + wg.Add(numRequests) + + // Launch many concurrent requests with varying timeouts + for i := 0; i < numRequests; i++ { + go func(i int) { + defer wg.Done() + + // Varying timeouts to create race conditions + timeout := time.Duration(20+i%80) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cn, err := connPool.Get(ctx) + if err == nil { + // Simulate some work + time.Sleep(time.Duration(i%20) * time.Millisecond) + connPool.Put(ctx, cn) + } + }(i) + } + + wg.Wait() + + // Give time for all cleanup to complete + time.Sleep(200 * time.Millisecond) + + // Check for turn leaks + queueLen := connPool.QueueLen() + if queueLen != 0 { + t.Errorf("QueueLen = %d, expected 0 (turn leak detected!)", queueLen) + t.Errorf("This indicates that some turns were never freed") + } + + // Also check pool stats + stats := connPool.Stats() + t.Logf("Pool stats: Hits=%d, Misses=%d, Timeouts=%d, TotalConns=%d, IdleConns=%d", + stats.Hits, stats.Misses, stats.Timeouts, stats.TotalConns, stats.IdleConns) +} +
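
Editor's note (not part of the patch series): the fix above hinges on one invariant — every Get() acquires exactly one pool "turn" (a slot in the queue semaphore), and that turn must be released exactly once, whether the connection came from the waiter's own dial, from another dial via putIdleConn(), or from a timeout path. The minimal, self-contained Go sketch below illustrates that invariant with a hypothetical turnGuard type; none of these names exist in go-redis, and this is an illustration of the idea, not the library's implementation.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// turnGuard wraps a buffered channel used as a counting semaphore and makes
// release idempotent per acquisition. This mirrors, in spirit, what the
// freeTurnCalled / waiterGotConn() bookkeeping in the patch protects against:
// the same turn being freed by both the dial goroutine and the waiter's Put().
type turnGuard struct {
	sem      chan struct{}
	released atomic.Bool
}

func acquire(sem chan struct{}) *turnGuard {
	sem <- struct{}{} // take one of poolSize slots
	return &turnGuard{sem: sem}
}

// release frees the slot at most once; a duplicate call is a no-op instead of
// draining someone else's slot and over-admitting concurrent work.
func (g *turnGuard) release() {
	if g.released.CompareAndSwap(false, true) {
		<-g.sem
	}
}

func main() {
	const poolSize = 2
	sem := make(chan struct{}, poolSize)

	var inFlight, maxInFlight atomic.Int32
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			g := acquire(sem)

			// Track how many goroutines hold a turn at the same time.
			n := inFlight.Add(1)
			for {
				m := maxInFlight.Load()
				if n <= m || maxInFlight.CompareAndSwap(m, n) {
					break
				}
			}
			time.Sleep(10 * time.Millisecond) // simulate using the connection
			inFlight.Add(-1)

			g.release() // e.g. cleanup in the dial goroutine
			g.release() // e.g. the waiter's Put(); harmless duplicate
		}()
	}
	wg.Wait()

	// With an idempotent release, max in-flight never exceeds poolSize and no
	// slots leak; with a naive double release an extra slot can be consumed,
	// which is exactly the over-admission the tests above try to detect.
	fmt.Printf("max in-flight: %d (limit %d), leaked slots: %d\n",
		maxInFlight.Load(), poolSize, len(sem))
}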