/ go / test / provider_pool_test.go
provider_pool_test.go
  1  package test
  2  
  3  import (
  4  	"context"
  5  	"errors"
  6  	"testing"
  7  	"time"
  8  
  9  	"github.com/TransformerOS/kamaji-go/internal/providers"
 10  )
 11  
 12  // Mock provider for testing
 13  type mockProvider struct {
 14  	name      string
 15  	shouldFail bool
 16  	response  string
 17  	delay     time.Duration
 18  }
 19  
 20  func (m *mockProvider) Call(ctx context.Context, prompt string) (string, error) {
 21  	if m.delay > 0 {
 22  		time.Sleep(m.delay)
 23  	}
 24  	
 25  	if m.shouldFail {
 26  		return "", errors.New("mock provider error")
 27  	}
 28  	
 29  	return m.response, nil
 30  }
 31  
 32  func TestProviderPoolFailover(t *testing.T) {
 33  	pool := providers.NewProviderPool(providers.Failover)
 34  	
 35  	// Add failing provider first
 36  	failingProvider := &mockProvider{
 37  		name:       "failing",
 38  		shouldFail: true,
 39  	}
 40  	pool.AddProvider(failingProvider)
 41  	
 42  	// Add working provider second
 43  	workingProvider := &mockProvider{
 44  		name:       "working",
 45  		shouldFail: false,
 46  		response:   "success",
 47  	}
 48  	pool.AddProvider(workingProvider)
 49  	
 50  	ctx := context.Background()
 51  	result, err := pool.Call(ctx, "test prompt")
 52  	
 53  	if err != nil {
 54  		t.Fatalf("Expected success with failover, got error: %v", err)
 55  	}
 56  	
 57  	if result != "success" {
 58  		t.Errorf("Expected 'success', got '%s'", result)
 59  	}
 60  }
 61  
 62  func TestProviderPoolRoundRobin(t *testing.T) {
 63  	pool := providers.NewProviderPool(providers.RoundRobin)
 64  	
 65  	// Add two working providers
 66  	provider1 := &mockProvider{
 67  		name:     "provider1",
 68  		response: "response1",
 69  	}
 70  	provider2 := &mockProvider{
 71  		name:     "provider2", 
 72  		response: "response2",
 73  	}
 74  	
 75  	pool.AddProvider(provider1)
 76  	pool.AddProvider(provider2)
 77  	
 78  	ctx := context.Background()
 79  	
 80  	// First call should use provider1
 81  	result1, err := pool.Call(ctx, "test")
 82  	if err != nil {
 83  		t.Fatalf("First call failed: %v", err)
 84  	}
 85  	
 86  	// Second call should use provider2
 87  	result2, err := pool.Call(ctx, "test")
 88  	if err != nil {
 89  		t.Fatalf("Second call failed: %v", err)
 90  	}
 91  	
 92  	// Results should be different (round robin)
 93  	if result1 == result2 {
 94  		t.Error("Expected different responses from round robin, got same")
 95  	}
 96  }
 97  
 98  func TestProviderPoolHealthTracking(t *testing.T) {
 99  	pool := providers.NewProviderPool(providers.Failover)
100  	
101  	// Add provider that will fail
102  	provider := &mockProvider{
103  		name:       "test",
104  		shouldFail: true,
105  	}
106  	pool.AddProvider(provider)
107  	
108  	ctx := context.Background()
109  	
110  	// Make failing calls - pool will stop calling after provider becomes unhealthy
111  	var lastErr error
112  	for i := 0; i < 10; i++ {
113  		_, err := pool.Call(ctx, "test")
114  		if err != nil {
115  			lastErr = err
116  		}
117  	}
118  	
119  	// Should eventually get "all providers failed" error
120  	if lastErr == nil || lastErr.Error() != "all providers failed" {
121  		t.Errorf("Expected 'all providers failed' error, got: %v", lastErr)
122  	}
123  	
124  	// Check health status
125  	status := pool.GetStatus()
126  	if len(status) != 1 {
127  		t.Errorf("Expected 1 provider status, got %d", len(status))
128  	}
129  	
130  	health := status[0]
131  	if health.ErrorCount == 0 {
132  		t.Error("Expected error count > 0")
133  	}
134  	
135  	// Provider should be marked unavailable after enough errors
136  	if health.Available && health.ErrorCount >= 5 {
137  		t.Error("Expected provider to be marked unavailable after 5+ errors")
138  	}
139  }
140  
141  func TestProviderPoolEmpty(t *testing.T) {
142  	pool := providers.NewProviderPool(providers.Failover)
143  	
144  	ctx := context.Background()
145  	_, err := pool.Call(ctx, "test")
146  	
147  	if err == nil {
148  		t.Error("Expected error when calling empty provider pool")
149  	}
150  	
151  	if err.Error() != "no providers available" {
152  		t.Errorf("Expected 'no providers available', got '%s'", err.Error())
153  	}
154  }