retry_test.go
1 package agent 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/Kocoro-lab/ShanClaw/internal/client" 8 ) 9 10 func TestIsRetryableLLMError(t *testing.T) { 11 tests := []struct { 12 name string 13 err error 14 retryable bool 15 }{ 16 {"nil", nil, false}, 17 // Typed APIError (primary path) 18 {"typed 429", &client.APIError{StatusCode: 429, Body: "rate limit"}, true}, 19 {"typed 500", &client.APIError{StatusCode: 500, Body: "internal"}, true}, 20 {"typed 502", &client.APIError{StatusCode: 502, Body: "bad gateway"}, true}, 21 {"typed 503", &client.APIError{StatusCode: 503}, true}, 22 {"typed 529", &client.APIError{StatusCode: 529, Body: "overloaded"}, true}, 23 {"typed 400", &client.APIError{StatusCode: 400, Body: "invalid"}, false}, 24 {"typed 401", &client.APIError{StatusCode: 401, Body: "unauthorized"}, false}, 25 {"typed 403", &client.APIError{StatusCode: 403, Body: "forbidden"}, false}, 26 // Wrapped typed APIError (errors.As unwraps) 27 {"wrapped 429", fmt.Errorf("LLM call failed: %w", &client.APIError{StatusCode: 429}), true}, 28 {"wrapped 400", fmt.Errorf("LLM call failed: %w", &client.APIError{StatusCode: 400}), false}, 29 // Network/stream errors (string-matched) 30 {"network timeout", fmt.Errorf("request failed: context deadline exceeded"), true}, 31 {"connection reset", fmt.Errorf("request failed: connection reset"), true}, 32 {"stream read error", fmt.Errorf("stream read error: unexpected EOF"), true}, 33 {"stream ended early", fmt.Errorf("stream ended without done event"), true}, 34 // Non-retryable 35 {"marshal error", fmt.Errorf("marshal request: json error"), false}, 36 {"decode error", fmt.Errorf("decode response: unexpected EOF"), false}, 37 {"generic error", fmt.Errorf("something unexpected"), false}, 38 } 39 for _, tt := range tests { 40 t.Run(tt.name, func(t *testing.T) { 41 got := isRetryableLLMError(tt.err) 42 if got != tt.retryable { 43 t.Errorf("isRetryableLLMError(%v) = %v, want %v", tt.err, got, tt.retryable) 44 } 45 }) 46 } 47 } 48 49 func TestClassifyLLMError(t *testing.T) { 50 tests := []struct { 51 name string 52 err error 53 expect string 54 }{ 55 {"nil", nil, "unknown"}, 56 {"rate limit", &client.APIError{StatusCode: 429}, "rate limited"}, 57 {"overloaded", &client.APIError{StatusCode: 529}, "API overloaded"}, 58 {"server 500", &client.APIError{StatusCode: 500}, "server error"}, 59 {"server 502", &client.APIError{StatusCode: 502}, "server error"}, 60 {"server 503", &client.APIError{StatusCode: 503}, "server error"}, 61 {"bad request", &client.APIError{StatusCode: 400}, "HTTP 400"}, 62 {"timeout", fmt.Errorf("request failed: context deadline exceeded"), "request timeout"}, 63 {"connection reset", fmt.Errorf("request failed: connection reset"), "connection error"}, 64 {"stream error", fmt.Errorf("stream read error: unexpected EOF"), "stream interrupted"}, 65 {"generic", fmt.Errorf("something else"), "transient error"}, 66 } 67 for _, tt := range tests { 68 t.Run(tt.name, func(t *testing.T) { 69 got := classifyLLMError(tt.err) 70 if got != tt.expect { 71 t.Errorf("classifyLLMError(%v) = %q, want %q", tt.err, got, tt.expect) 72 } 73 }) 74 } 75 } 76 77 func TestIsContextLengthError(t *testing.T) { 78 tests := []struct { 79 name string 80 err error 81 expect bool 82 }{ 83 {"nil", nil, false}, 84 {"prompt too long", &client.APIError{StatusCode: 400, Body: `{"error":"prompt is too long"}`}, true}, 85 {"context_length_exceeded", &client.APIError{StatusCode: 400, Body: `{"error":"context_length_exceeded"}`}, true}, 86 {"case insensitive", &client.APIError{StatusCode: 400, Body: `Prompt Is Too Long`}, true}, 87 {"wrapped", fmt.Errorf("call failed: %w", &client.APIError{StatusCode: 400, Body: "prompt is too long"}), true}, 88 // Must NOT match 89 {"max_tokens", &client.APIError{StatusCode: 400, Body: `{"error":"max_tokens exceeded"}`}, false}, 90 {"unrelated 400", &client.APIError{StatusCode: 400, Body: `{"error":"invalid request"}`}, false}, 91 {"server error", &client.APIError{StatusCode: 500, Body: "prompt is too long"}, false}, 92 {"non-api error", fmt.Errorf("prompt is too long"), false}, 93 {"rate limit", &client.APIError{StatusCode: 429, Body: "rate limited"}, false}, 94 } 95 for _, tt := range tests { 96 t.Run(tt.name, func(t *testing.T) { 97 got := isContextLengthError(tt.err) 98 if got != tt.expect { 99 t.Errorf("isContextLengthError(%v) = %v, want %v", tt.err, got, tt.expect) 100 } 101 }) 102 } 103 }