// internal/agent/retry_test.go
  1  package agent
  2  
  3  import (
  4  	"fmt"
  5  	"testing"
  6  
  7  	"github.com/Kocoro-lab/ShanClaw/internal/client"
  8  )
  9  
 10  func TestIsRetryableLLMError(t *testing.T) {
 11  	tests := []struct {
 12  		name      string
 13  		err       error
 14  		retryable bool
 15  	}{
 16  		{"nil", nil, false},
 17  		// Typed APIError (primary path)
 18  		{"typed 429", &client.APIError{StatusCode: 429, Body: "rate limit"}, true},
 19  		{"typed 500", &client.APIError{StatusCode: 500, Body: "internal"}, true},
 20  		{"typed 502", &client.APIError{StatusCode: 502, Body: "bad gateway"}, true},
 21  		{"typed 503", &client.APIError{StatusCode: 503}, true},
 22  		{"typed 529", &client.APIError{StatusCode: 529, Body: "overloaded"}, true},
 23  		{"typed 400", &client.APIError{StatusCode: 400, Body: "invalid"}, false},
 24  		{"typed 401", &client.APIError{StatusCode: 401, Body: "unauthorized"}, false},
 25  		{"typed 403", &client.APIError{StatusCode: 403, Body: "forbidden"}, false},
 26  		// Wrapped typed APIError (errors.As unwraps)
 27  		{"wrapped 429", fmt.Errorf("LLM call failed: %w", &client.APIError{StatusCode: 429}), true},
 28  		{"wrapped 400", fmt.Errorf("LLM call failed: %w", &client.APIError{StatusCode: 400}), false},
 29  		// Network/stream errors (string-matched)
 30  		{"network timeout", fmt.Errorf("request failed: context deadline exceeded"), true},
 31  		{"connection reset", fmt.Errorf("request failed: connection reset"), true},
 32  		{"stream read error", fmt.Errorf("stream read error: unexpected EOF"), true},
 33  		{"stream ended early", fmt.Errorf("stream ended without done event"), true},
 34  		// Non-retryable
 35  		{"marshal error", fmt.Errorf("marshal request: json error"), false},
 36  		{"decode error", fmt.Errorf("decode response: unexpected EOF"), false},
 37  		{"generic error", fmt.Errorf("something unexpected"), false},
 38  	}
 39  	for _, tt := range tests {
 40  		t.Run(tt.name, func(t *testing.T) {
 41  			got := isRetryableLLMError(tt.err)
 42  			if got != tt.retryable {
 43  				t.Errorf("isRetryableLLMError(%v) = %v, want %v", tt.err, got, tt.retryable)
 44  			}
 45  		})
 46  	}
 47  }
 48  
 49  func TestClassifyLLMError(t *testing.T) {
 50  	tests := []struct {
 51  		name   string
 52  		err    error
 53  		expect string
 54  	}{
 55  		{"nil", nil, "unknown"},
 56  		{"rate limit", &client.APIError{StatusCode: 429}, "rate limited"},
 57  		{"overloaded", &client.APIError{StatusCode: 529}, "API overloaded"},
 58  		{"server 500", &client.APIError{StatusCode: 500}, "server error"},
 59  		{"server 502", &client.APIError{StatusCode: 502}, "server error"},
 60  		{"server 503", &client.APIError{StatusCode: 503}, "server error"},
 61  		{"bad request", &client.APIError{StatusCode: 400}, "HTTP 400"},
 62  		{"timeout", fmt.Errorf("request failed: context deadline exceeded"), "request timeout"},
 63  		{"connection reset", fmt.Errorf("request failed: connection reset"), "connection error"},
 64  		{"stream error", fmt.Errorf("stream read error: unexpected EOF"), "stream interrupted"},
 65  		{"generic", fmt.Errorf("something else"), "transient error"},
 66  	}
 67  	for _, tt := range tests {
 68  		t.Run(tt.name, func(t *testing.T) {
 69  			got := classifyLLMError(tt.err)
 70  			if got != tt.expect {
 71  				t.Errorf("classifyLLMError(%v) = %q, want %q", tt.err, got, tt.expect)
 72  			}
 73  		})
 74  	}
 75  }
 76  
 77  func TestIsContextLengthError(t *testing.T) {
 78  	tests := []struct {
 79  		name   string
 80  		err    error
 81  		expect bool
 82  	}{
 83  		{"nil", nil, false},
 84  		{"prompt too long", &client.APIError{StatusCode: 400, Body: `{"error":"prompt is too long"}`}, true},
 85  		{"context_length_exceeded", &client.APIError{StatusCode: 400, Body: `{"error":"context_length_exceeded"}`}, true},
 86  		{"case insensitive", &client.APIError{StatusCode: 400, Body: `Prompt Is Too Long`}, true},
 87  		{"wrapped", fmt.Errorf("call failed: %w", &client.APIError{StatusCode: 400, Body: "prompt is too long"}), true},
 88  		// Must NOT match
 89  		{"max_tokens", &client.APIError{StatusCode: 400, Body: `{"error":"max_tokens exceeded"}`}, false},
 90  		{"unrelated 400", &client.APIError{StatusCode: 400, Body: `{"error":"invalid request"}`}, false},
 91  		{"server error", &client.APIError{StatusCode: 500, Body: "prompt is too long"}, false},
 92  		{"non-api error", fmt.Errorf("prompt is too long"), false},
 93  		{"rate limit", &client.APIError{StatusCode: 429, Body: "rate limited"}, false},
 94  	}
 95  	for _, tt := range tests {
 96  		t.Run(tt.name, func(t *testing.T) {
 97  			got := isContextLengthError(tt.err)
 98  			if got != tt.expect {
 99  				t.Errorf("isContextLengthError(%v) = %v, want %v", tt.err, got, tt.expect)
100  			}
101  		})
102  	}
103  }