// provider_pool_test.go

1 package test 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 "time" 8 9 "github.com/TransformerOS/kamaji-go/internal/providers" 10 ) 11 12 // Mock provider for testing 13 type mockProvider struct { 14 name string 15 shouldFail bool 16 response string 17 delay time.Duration 18 } 19 20 func (m *mockProvider) Call(ctx context.Context, prompt string) (string, error) { 21 if m.delay > 0 { 22 time.Sleep(m.delay) 23 } 24 25 if m.shouldFail { 26 return "", errors.New("mock provider error") 27 } 28 29 return m.response, nil 30 } 31 32 func TestProviderPoolFailover(t *testing.T) { 33 pool := providers.NewProviderPool(providers.Failover) 34 35 // Add failing provider first 36 failingProvider := &mockProvider{ 37 name: "failing", 38 shouldFail: true, 39 } 40 pool.AddProvider(failingProvider) 41 42 // Add working provider second 43 workingProvider := &mockProvider{ 44 name: "working", 45 shouldFail: false, 46 response: "success", 47 } 48 pool.AddProvider(workingProvider) 49 50 ctx := context.Background() 51 result, err := pool.Call(ctx, "test prompt") 52 53 if err != nil { 54 t.Fatalf("Expected success with failover, got error: %v", err) 55 } 56 57 if result != "success" { 58 t.Errorf("Expected 'success', got '%s'", result) 59 } 60 } 61 62 func TestProviderPoolRoundRobin(t *testing.T) { 63 pool := providers.NewProviderPool(providers.RoundRobin) 64 65 // Add two working providers 66 provider1 := &mockProvider{ 67 name: "provider1", 68 response: "response1", 69 } 70 provider2 := &mockProvider{ 71 name: "provider2", 72 response: "response2", 73 } 74 75 pool.AddProvider(provider1) 76 pool.AddProvider(provider2) 77 78 ctx := context.Background() 79 80 // First call should use provider1 81 result1, err := pool.Call(ctx, "test") 82 if err != nil { 83 t.Fatalf("First call failed: %v", err) 84 } 85 86 // Second call should use provider2 87 result2, err := pool.Call(ctx, "test") 88 if err != nil { 89 t.Fatalf("Second call failed: %v", err) 90 } 91 92 // Results should be different 
(round robin) 93 if result1 == result2 { 94 t.Error("Expected different responses from round robin, got same") 95 } 96 } 97 98 func TestProviderPoolHealthTracking(t *testing.T) { 99 pool := providers.NewProviderPool(providers.Failover) 100 101 // Add provider that will fail 102 provider := &mockProvider{ 103 name: "test", 104 shouldFail: true, 105 } 106 pool.AddProvider(provider) 107 108 ctx := context.Background() 109 110 // Make failing calls - pool will stop calling after provider becomes unhealthy 111 var lastErr error 112 for i := 0; i < 10; i++ { 113 _, err := pool.Call(ctx, "test") 114 if err != nil { 115 lastErr = err 116 } 117 } 118 119 // Should eventually get "all providers failed" error 120 if lastErr == nil || lastErr.Error() != "all providers failed" { 121 t.Errorf("Expected 'all providers failed' error, got: %v", lastErr) 122 } 123 124 // Check health status 125 status := pool.GetStatus() 126 if len(status) != 1 { 127 t.Errorf("Expected 1 provider status, got %d", len(status)) 128 } 129 130 health := status[0] 131 if health.ErrorCount == 0 { 132 t.Error("Expected error count > 0") 133 } 134 135 // Provider should be marked unavailable after enough errors 136 if health.Available && health.ErrorCount >= 5 { 137 t.Error("Expected provider to be marked unavailable after 5+ errors") 138 } 139 } 140 141 func TestProviderPoolEmpty(t *testing.T) { 142 pool := providers.NewProviderPool(providers.Failover) 143 144 ctx := context.Background() 145 _, err := pool.Call(ctx, "test") 146 147 if err == nil { 148 t.Error("Expected error when calling empty provider pool") 149 } 150 151 if err.Error() != "no providers available" { 152 t.Errorf("Expected 'no providers available', got '%s'", err.Error()) 153 } 154 }