From 0da482366012cd9d30632d49fd11dffc2116aa98 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 18:27:48 -0800 Subject: [PATCH 1/2] Expand redis rate limiter tests --- app/redis_test.go | 142 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 132 insertions(+), 10 deletions(-) diff --git a/app/redis_test.go b/app/redis_test.go index bb8ea50..cf72b8d 100644 --- a/app/redis_test.go +++ b/app/redis_test.go @@ -1,14 +1,45 @@ package main import ( - "bufio" - "fmt" - "io" - "net" - "testing" - "time" + "bufio" + "fmt" + "io" + "net" + "strconv" + "strings" + "testing" + "time" ) +func readRESPCommand(t *testing.T, br *bufio.Reader) []string { + t.Helper() + + line, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read command length: %v", err) + } + if !strings.HasPrefix(line, "*") { + t.Fatalf("unexpected command prefix %q", line) + } + n, err := strconv.Atoi(strings.TrimSpace(line[1:])) + if err != nil { + t.Fatalf("parse command length: %v", err) + } + + args := make([]string, n) + for i := 0; i < n; i++ { + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("read bulk len: %v", err) + } + arg, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read bulk data: %v", err) + } + args[i] = strings.TrimSpace(arg) + } + return args +} + func TestRedisCmdInt(t *testing.T) { srv, cli := net.Pipe() defer srv.Close() @@ -290,9 +321,9 @@ func TestRedisCmdStringSimple(t *testing.T) { } func TestRedisCmdStringInteger(t *testing.T) { - srv, cli := net.Pipe() - defer srv.Close() - defer cli.Close() + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() go func() { br := bufio.NewReader(srv) @@ -307,5 +338,96 @@ func TestRedisCmdStringInteger(t *testing.T) { } if val != "5" { t.Fatalf("expected 5, got %q", val) - } + } +} + +func TestAllowRedisTokenBucketEmpty(t *testing.T) { + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "token_bucket") + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + now := time.Now().UnixNano() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("0 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisTokenBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is empty") + } +} + +func TestAllowRedisLeakyBucketOverLimit(t *testing.T) { + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "leaky_bucket") + now := time.Now().UnixNano() + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("2 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisLeakyBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is over limit") + } +} + +func TestRetryAfterRedisTLSMissingCA(t *testing.T) { + oldAddr := *redisAddr + oldCA := *redisCA + *redisAddr = "rediss://example.com:6379" + *redisCA = "does-not-exist" + t.Cleanup(func() { + *redisAddr = oldAddr + *redisCA = oldCA + }) + + rl := NewRateLimiter(1, time.Second, "fixed_window") + if _, err := rl.retryAfterRedis("key"); err == nil { + t.Fatal("expected error when CA file cannot be read") + } } From a96e7d2c6f2785f5287bc0ff198e33e8fe845149 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Thu, 27 Nov 2025 23:16:08 -0800 Subject: [PATCH 2/2] Add redis error path tests --- app/main_test.go | 34 ++++++ app/redis_test.go | 257 ++++++++++++++++++++++++---------------------- 2 files changed, 171 insertions(+), 120 deletions(-) diff --git a/app/main_test.go b/app/main_test.go index edc573b..fea22b3 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -716,6 +716,23 @@ func TestAllowRedisTokenBucketReject(t *testing.T) { <-done } +func TestAllowRedisTokenBucketError(t *testing.T) { + rl := NewRateLimiter(1, time.Second, "token_bucket") + t.Cleanup(rl.Stop) + srv, cli := net.Pipe() + done := make(chan struct{}) + go func() { + defer func() { srv.Close(); close(done) }() + br := bufio.NewReader(srv) + parseRedisCommand(t, br) + srv.Write([]byte("-ERR nope\r\n")) + }() + if _, err := rl.allowRedisTokenBucket(cli, "k"); err == nil { + t.Fatal("expected redis error") + } + <-done +} + func TestAllowRedisLeakyBucketPoolFullClosesConnection(t *testing.T) { old := *redisAddr *redisAddr = "dummy" @@ -815,6 +832,23 @@ func TestAllowRedisLeakyBucketReject(t *testing.T) { <-done } +func TestAllowRedisLeakyBucketError(t *testing.T) { + rl := NewRateLimiter(1, time.Second, "leaky_bucket") + t.Cleanup(rl.Stop) + srv, cli := net.Pipe() + done := make(chan struct{}) + go func() { + defer func() { srv.Close(); close(done) }() + br := bufio.NewReader(srv) + parseRedisCommand(t, br) + srv.Write([]byte("-ERR boom\r\n")) + }() + if _, err := rl.allowRedisLeakyBucket(cli, "k"); err == nil { + t.Fatal("expected redis GET error") + } + <-done +} + func TestAllowRedisLeakyBucketAllow(t *testing.T) { rl := NewRateLimiter(1, time.Hour, "leaky_bucket") t.Cleanup(rl.Stop) diff --git a/app/redis_test.go b/app/redis_test.go index cf72b8d..2a653ad 100644 --- a/app/redis_test.go +++ b/app/redis_test.go @@ -1,43 +1,43 @@ package main import ( - "bufio" - "fmt" - "io" - "net" - "strconv" - "strings" - "testing" - "time" + "bufio" + "fmt" + "io" + "net" + "strconv" + "strings" + "testing" + "time" ) func readRESPCommand(t *testing.T, br *bufio.Reader) []string { - t.Helper() - - line, err := br.ReadString('\n') - if err != nil { - t.Fatalf("read command length: %v", err) - } - if !strings.HasPrefix(line, "*") { - t.Fatalf("unexpected command prefix %q", line) - } - n, err := strconv.Atoi(strings.TrimSpace(line[1:])) - if err != nil { - t.Fatalf("parse command length: %v", err) - } - - args := make([]string, n) - for i := 0; i < n; i++ { - if _, err := br.ReadString('\n'); err != nil { - t.Fatalf("read bulk len: %v", err) - } - arg, err := br.ReadString('\n') - if err != nil { - t.Fatalf("read bulk data: %v", err) - } - args[i] = strings.TrimSpace(arg) - } - return args + t.Helper() + + line, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read command length: %v", err) + } + if !strings.HasPrefix(line, "*") { + t.Fatalf("unexpected command prefix %q", line) + } + n, err := strconv.Atoi(strings.TrimSpace(line[1:])) + if err != nil { + t.Fatalf("parse command length: %v", err) + } + + args := make([]string, n) + for i := 0; i < n; i++ { + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("read bulk len: %v", err) + } + arg, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read bulk data: %v", err) + } + args[i] = strings.TrimSpace(arg) + } + return args } func TestRedisCmdInt(t *testing.T) { @@ -299,6 +299,23 @@ func TestRedisCmdStringReadError(t *testing.T) { } } +func TestRedisCmdStringBulkReadError(t *testing.T) { + srv, cli := net.Pipe() + defer cli.Close() + + go func() { + br := bufio.NewReader(srv) + br.ReadBytes('\n') + br.ReadBytes('\n') + srv.Write([]byte("$4\r\n")) + srv.Close() + }() + + if _, err := redisCmdString(cli, "GET", "key"); err == nil { + t.Fatal("expected bulk read error") + } +} + func TestRedisCmdStringSimple(t *testing.T) { srv, cli := net.Pipe() defer srv.Close() @@ -321,9 +338,9 @@ func TestRedisCmdStringSimple(t *testing.T) { } func TestRedisCmdStringInteger(t *testing.T) { - srv, cli := net.Pipe() - defer srv.Close() - defer cli.Close() + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() go func() { br := bufio.NewReader(srv) @@ -338,96 +355,96 @@ func TestRedisCmdStringInteger(t *testing.T) { } if val != "5" { t.Fatalf("expected 5, got %q", val) - } + } } func TestAllowRedisTokenBucketEmpty(t *testing.T) { - oldAddr := *redisAddr - *redisAddr = "redis://example:6379" - t.Cleanup(func() { *redisAddr = oldAddr }) - - rl := NewRateLimiter(1, time.Second, "token_bucket") - srv, cli := net.Pipe() - defer srv.Close() - defer cli.Close() - - now := time.Now().UnixNano() - - go func() { - br := bufio.NewReader(srv) - - // GET k - readRESPCommand(t, br) - payload := fmt.Sprintf("0 %d", now) - srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) - - // SET k - readRESPCommand(t, br) - srv.Write([]byte("+OK\r\n")) - - // PEXPIRE k - readRESPCommand(t, br) - srv.Write([]byte(":1\r\n")) - }() - - allowed, err := rl.allowRedisTokenBucket(cli, "k") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if allowed { - t.Fatal("expected request to be rate limited when bucket is empty") - } + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "token_bucket") + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + now := time.Now().UnixNano() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("0 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisTokenBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is empty") + } } func TestAllowRedisLeakyBucketOverLimit(t *testing.T) { - oldAddr := *redisAddr - *redisAddr = "redis://example:6379" - t.Cleanup(func() { *redisAddr = oldAddr }) - - rl := NewRateLimiter(1, time.Second, "leaky_bucket") - now := time.Now().UnixNano() - srv, cli := net.Pipe() - defer srv.Close() - defer cli.Close() - - go func() { - br := bufio.NewReader(srv) - - // GET k - readRESPCommand(t, br) - payload := fmt.Sprintf("2 %d", now) - srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) - - // SET k - readRESPCommand(t, br) - srv.Write([]byte("+OK\r\n")) - - // PEXPIRE k - readRESPCommand(t, br) - srv.Write([]byte(":1\r\n")) - }() - - allowed, err := rl.allowRedisLeakyBucket(cli, "k") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if allowed { - t.Fatal("expected request to be rate limited when bucket is over limit") - } + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "leaky_bucket") + now := time.Now().UnixNano() + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("2 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisLeakyBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is over limit") + } } func TestRetryAfterRedisTLSMissingCA(t *testing.T) { - oldAddr := *redisAddr - oldCA := *redisCA - *redisAddr = "rediss://example.com:6379" - *redisCA = "does-not-exist" - t.Cleanup(func() { - *redisAddr = oldAddr - *redisCA = oldCA - }) - - rl := NewRateLimiter(1, time.Second, "fixed_window") - if _, err := rl.retryAfterRedis("key"); err == nil { - t.Fatal("expected error when CA file cannot be read") - } + oldAddr := *redisAddr + oldCA := *redisCA + *redisAddr = "rediss://example.com:6379" + *redisCA = "does-not-exist" + t.Cleanup(func() { + *redisAddr = oldAddr + *redisCA = oldCA + }) + + rl := NewRateLimiter(1, time.Second, "fixed_window") + if _, err := rl.retryAfterRedis("key"); err == nil { + t.Fatal("expected error when CA file cannot be read") + } }