/ test / vision_debug_test.go
vision_debug_test.go
  1  package test
  2  
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"github.com/Kocoro-lab/ShanClaw/internal/agent"
	"github.com/Kocoro-lab/ShanClaw/internal/client"
	"github.com/Kocoro-lab/ShanClaw/internal/tools"
)
 15  
 16  // TestVisionLoop_FullPipeline verifies that a real screenshot's base64 data
 17  // actually arrives in the API request payload as image content blocks.
 18  func TestVisionLoop_FullPipeline(t *testing.T) {
 19  	var capturedMessages []json.RawMessage
 20  
 21  	callCount := 0
 22  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 23  		callCount++
 24  
 25  		// Capture raw request body to inspect image blocks
 26  		var raw map[string]json.RawMessage
 27  		json.NewDecoder(r.Body).Decode(&raw)
 28  		if msgs, ok := raw["messages"]; ok {
 29  			capturedMessages = append(capturedMessages[:0], msgs) // store latest
 30  		}
 31  
 32  		if callCount == 1 {
 33  			// First call: tell the model to call screenshot
 34  			json.NewEncoder(w).Encode(client.CompletionResponse{
 35  				OutputText:   "",
 36  				FinishReason: "tool_use",
 37  				FunctionCall: &client.FunctionCall{
 38  					Name:      "screenshot",
 39  					Arguments: json.RawMessage(`{"target":"fullscreen"}`),
 40  				},
 41  				Usage: client.Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
 42  			})
 43  		} else {
 44  			// Second call: return response — but first, inspect what was sent
 45  			json.NewEncoder(w).Encode(client.CompletionResponse{
 46  				OutputText:   "I see a desktop",
 47  				FinishReason: "end_turn",
 48  				Usage:        client.Usage{InputTokens: 1000, OutputTokens: 50, TotalTokens: 1050},
 49  			})
 50  		}
 51  	}))
 52  	defer server.Close()
 53  
 54  	gw := client.NewGatewayClient(server.URL, "")
 55  	reg := agent.NewToolRegistry()
 56  	reg.Register(&tools.ScreenshotTool{})
 57  	loop := agent.NewAgentLoop(gw, reg, "medium", "", 10, 50000, 200, nil, nil, nil)
 58  	loop.SetBypassPermissions(true)
 59  
 60  	result, usage, err := loop.Run(context.Background(), "take a screenshot", nil, nil)
 61  	if err != nil {
 62  		t.Fatalf("agent loop error: %v", err)
 63  	}
 64  
 65  	t.Logf("Result: %s", result)
 66  	t.Logf("LLM calls: %d, tokens: %d", usage.LLMCalls, usage.TotalTokens)
 67  
 68  	if callCount < 2 {
 69  		t.Fatalf("expected at least 2 LLM calls, got %d", callCount)
 70  	}
 71  
 72  	// Parse the captured messages from the 2nd API call to verify image blocks
 73  	if len(capturedMessages) == 0 {
 74  		t.Fatal("no messages captured from API request")
 75  	}
 76  
 77  	var messages []json.RawMessage
 78  	json.Unmarshal(capturedMessages[0], &messages)
 79  
 80  	t.Logf("Messages in 2nd API call: %d", len(messages))
 81  
 82  	foundImage := false
 83  	imageBytes := 0
 84  	for i, msgRaw := range messages {
 85  		var msg struct {
 86  			Role    string          `json:"role"`
 87  			Content json.RawMessage `json:"content"`
 88  		}
 89  		json.Unmarshal(msgRaw, &msg)
 90  
 91  		// Check if content is an array (content blocks)
 92  		var blocks []struct {
 93  			Type   string `json:"type"`
 94  			Text   string `json:"text,omitempty"`
 95  			Source *struct {
 96  				Type      string `json:"type"`
 97  				MediaType string `json:"media_type"`
 98  				Data      string `json:"data"`
 99  			} `json:"source,omitempty"`
100  		}
101  		if err := json.Unmarshal(msg.Content, &blocks); err == nil && len(blocks) > 0 {
102  			for _, b := range blocks {
103  				if b.Type == "image" && b.Source != nil {
104  					foundImage = true
105  					imageBytes = len(b.Source.Data) * 3 / 4
106  					t.Logf("✅ msg[%d] role=%s: FOUND IMAGE — %s, %d KB base64 (%d KB raw)",
107  						i, msg.Role, b.Source.MediaType, len(b.Source.Data)/1024, imageBytes/1024)
108  				}
109  				if b.Type == "text" {
110  					t.Logf("   msg[%d] role=%s: text block: %.80s...", i, msg.Role, b.Text)
111  				}
112  			}
113  		} else {
114  			// String content
115  			var s string
116  			json.Unmarshal(msg.Content, &s)
117  			if len(s) > 100 {
118  				s = s[:100] + "..."
119  			}
120  			t.Logf("   msg[%d] role=%s: string: %.80s", i, msg.Role, s)
121  		}
122  	}
123  
124  	if !foundImage {
125  		t.Fatal("❌ NO IMAGE BLOCK found in API request — vision pipeline broken!")
126  	}
127  	if imageBytes < 10000 {
128  		t.Errorf("image seems too small (%d bytes) — may be a broken/empty screenshot", imageBytes)
129  	}
130  	fmt.Printf("\n✅ Vision pipeline verified: real screenshot (%d KB) delivered as image content block in API request\n", imageBytes/1024)
131  }