// gateway_test.go
1 package client 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "io" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 "time" 13 ) 14 15 func TestCompleteUsesCompletionsEndpoint(t *testing.T) { 16 got := struct { 17 Messages []Message `json:"messages"` 18 Tools []Tool `json:"tools"` 19 }{} 20 21 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 if r.URL.Path != "/v1/completions" { 23 t.Errorf("unexpected path: %s", r.URL.Path) 24 } 25 if r.Method != http.MethodPost { 26 t.Errorf("expected POST, got %s", r.Method) 27 } 28 body, err := io.ReadAll(r.Body) 29 if err != nil { 30 t.Fatalf("read body: %v", err) 31 } 32 if err := json.Unmarshal(body, &got); err != nil { 33 t.Fatalf("decode request: %v", err) 34 } 35 w.Header().Set("Content-Type", "application/json") 36 json.NewEncoder(w).Encode(CompletionResponse{ 37 OutputText: "hello", 38 FinishReason: "end_turn", 39 Usage: Usage{ 40 InputTokens: 3, 41 OutputTokens: 4, 42 TotalTokens: 7, 43 }, 44 }) 45 })) 46 defer server.Close() 47 48 gw := NewGatewayClient(server.URL, "key") 49 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 50 defer cancel() 51 52 resp, err := gw.Complete(ctx, CompletionRequest{ 53 Messages: []Message{{Role: "user", Content: NewTextContent("ping")}}, 54 Tools: []Tool{{Type: "function"}}, 55 }) 56 if err != nil { 57 t.Fatalf("unexpected error: %v", err) 58 } 59 if resp.OutputText != "hello" { 60 t.Fatalf("expected output hello, got %s", resp.OutputText) 61 } 62 if len(got.Messages) != 1 || got.Messages[0].Content.Text() != "ping" { 63 t.Errorf("request body messages not preserved") 64 } 65 if len(got.Tools) != 1 || got.Tools[0].Type != "function" { 66 t.Errorf("expected tool payload to include tools") 67 } 68 } 69 70 func TestListTools(t *testing.T) { 71 tools := []ServerToolSchema{ 72 {Name: "web_search", Description: "Search the web", Parameters: map[string]any{"type": "object"}}, 73 {Name: "getStockBars", 
Description: "Get stock bars"}, 74 } 75 76 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 77 if r.URL.Path != "/api/v1/tools" { 78 t.Errorf("unexpected path: %s", r.URL.Path) 79 } 80 if r.Method != http.MethodGet { 81 t.Errorf("expected GET, got %s", r.Method) 82 } 83 if r.Header.Get("X-API-Key") != "test-key" { 84 t.Errorf("expected X-API-Key=test-key, got %s", r.Header.Get("X-API-Key")) 85 } 86 w.Header().Set("Content-Type", "application/json") 87 json.NewEncoder(w).Encode(tools) 88 })) 89 defer server.Close() 90 91 gw := NewGatewayClient(server.URL, "test-key") 92 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 93 defer cancel() 94 95 got, err := gw.ListTools(ctx) 96 if err != nil { 97 t.Fatalf("unexpected error: %v", err) 98 } 99 if len(got) != 2 { 100 t.Fatalf("expected 2 tools, got %d", len(got)) 101 } 102 if got[0].Name != "web_search" { 103 t.Errorf("expected web_search, got %s", got[0].Name) 104 } 105 if got[1].Name != "getStockBars" { 106 t.Errorf("expected getStockBars, got %s", got[1].Name) 107 } 108 } 109 110 func TestListTools_ServerError(t *testing.T) { 111 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 112 w.WriteHeader(http.StatusInternalServerError) 113 w.Write([]byte("internal error")) 114 })) 115 defer server.Close() 116 117 gw := NewGatewayClient(server.URL, "") 118 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 119 defer cancel() 120 121 _, err := gw.ListTools(ctx) 122 if err == nil { 123 t.Fatal("expected error") 124 } 125 } 126 127 func TestExecuteTool(t *testing.T) { 128 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 if r.URL.Path != "/api/v1/tools/web_search/execute" { 130 t.Errorf("unexpected path: %s", r.URL.Path) 131 } 132 if r.Method != http.MethodPost { 133 t.Errorf("expected POST, got %s", r.Method) 134 } 135 136 var req ToolExecuteRequest 137 
json.NewDecoder(r.Body).Decode(&req) 138 if req.Arguments["query"] != "golang testing" { 139 t.Errorf("expected query=golang testing, got %v", req.Arguments["query"]) 140 } 141 142 w.Header().Set("Content-Type", "application/json") 143 json.NewEncoder(w).Encode(ToolExecuteResponse{ 144 Success: true, 145 Output: json.RawMessage(`{"results":["found 10 results"]}`), 146 }) 147 })) 148 defer server.Close() 149 150 gw := NewGatewayClient(server.URL, "key") 151 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 152 defer cancel() 153 154 resp, err := gw.ExecuteTool(ctx, "web_search", map[string]any{"query": "golang testing"}, "") 155 if err != nil { 156 t.Fatalf("unexpected error: %v", err) 157 } 158 if !resp.Success { 159 t.Error("expected success=true") 160 } 161 if string(resp.Output) != `{"results":["found 10 results"]}` { 162 t.Errorf("unexpected output: %s", string(resp.Output)) 163 } 164 } 165 166 func TestExecuteTool_UrlEscapesName(t *testing.T) { 167 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 168 // r.URL.RawPath preserves the percent-encoding; r.URL.Path is decoded 169 want := "/api/v1/tools/my%2Ftool/execute" 170 if r.URL.RawPath != want { 171 t.Errorf("expected raw path %s, got %s", want, r.URL.RawPath) 172 } 173 174 w.Header().Set("Content-Type", "application/json") 175 json.NewEncoder(w).Encode(ToolExecuteResponse{Success: true, Output: json.RawMessage(`"ok"`)}) 176 })) 177 defer server.Close() 178 179 gw := NewGatewayClient(server.URL, "") 180 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 181 defer cancel() 182 183 _, err := gw.ExecuteTool(ctx, "my/tool", map[string]any{}, "") 184 if err != nil { 185 t.Fatalf("unexpected error: %v", err) 186 } 187 } 188 189 func TestExecuteTool_403(t *testing.T) { 190 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 191 w.WriteHeader(http.StatusForbidden) 192 w.Write([]byte("tool not 
allowed")) 193 })) 194 defer server.Close() 195 196 gw := NewGatewayClient(server.URL, "") 197 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 198 defer cancel() 199 200 _, err := gw.ExecuteTool(ctx, "dangerous_tool", map[string]any{}, "") 201 if err == nil { 202 t.Fatal("expected error for 403") 203 } 204 } 205 206 func TestCompletionRequest_MarshalsCacheSourceField(t *testing.T) { 207 req := CompletionRequest{ 208 Messages: []Message{{Role: "user", Content: NewTextContent("hi")}}, 209 CacheSource: "webhook", 210 } 211 b, err := json.Marshal(req) 212 if err != nil { 213 t.Fatalf("marshal failed: %v", err) 214 } 215 if !bytes.Contains(b, []byte(`"cache_source":"webhook"`)) { 216 t.Fatalf("cache_source missing on wire: %s", b) 217 } 218 } 219 220 func TestCompletionRequest_OmitsCacheSourceWhenEmpty(t *testing.T) { 221 // Unset CacheSource must not emit the field — Shannon interprets absence 222 // as "unknown" and falls back to 5m TTL. 223 req := CompletionRequest{ 224 Messages: []Message{{Role: "user", Content: NewTextContent("hi")}}, 225 } 226 b, err := json.Marshal(req) 227 if err != nil { 228 t.Fatalf("marshal failed: %v", err) 229 } 230 if bytes.Contains(b, []byte("cache_source")) { 231 t.Fatalf("expected cache_source omitted when empty, got: %s", b) 232 } 233 } 234 235 func TestMessageContent_MarshalString(t *testing.T) { 236 msg := Message{Role: "user", Content: NewTextContent("hello")} 237 data, err := json.Marshal(msg) 238 if err != nil { 239 t.Fatalf("marshal error: %v", err) 240 } 241 var raw map[string]json.RawMessage 242 json.Unmarshal(data, &raw) 243 var content string 244 if err := json.Unmarshal(raw["content"], &content); err != nil { 245 t.Fatalf("content should be a string, got: %s", string(raw["content"])) 246 } 247 if content != "hello" { 248 t.Errorf("expected 'hello', got %q", content) 249 } 250 } 251 252 func TestMessageContent_MarshalBlocks(t *testing.T) { 253 msg := Message{ 254 Role: "user", 255 Content: 
NewBlockContent([]ContentBlock{ 256 {Type: "text", Text: "describe this"}, 257 {Type: "image", Source: &ImageSource{Type: "base64", MediaType: "image/png", Data: "abc123"}}, 258 }), 259 } 260 data, err := json.Marshal(msg) 261 if err != nil { 262 t.Fatalf("marshal error: %v", err) 263 } 264 var raw map[string]json.RawMessage 265 json.Unmarshal(data, &raw) 266 var blocks []ContentBlock 267 if err := json.Unmarshal(raw["content"], &blocks); err != nil { 268 t.Fatalf("content should be an array, got: %s", string(raw["content"])) 269 } 270 if len(blocks) != 2 { 271 t.Fatalf("expected 2 blocks, got %d", len(blocks)) 272 } 273 } 274 275 func TestMessageContent_UnmarshalString(t *testing.T) { 276 raw := `{"role":"user","content":"hello"}` 277 var msg Message 278 if err := json.Unmarshal([]byte(raw), &msg); err != nil { 279 t.Fatalf("unmarshal error: %v", err) 280 } 281 if msg.Content.Text() != "hello" { 282 t.Errorf("expected 'hello', got %q", msg.Content.Text()) 283 } 284 } 285 286 func TestMessageContent_UnmarshalBlocks(t *testing.T) { 287 raw := `{"role":"user","content":[{"type":"text","text":"hi"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"xyz"}}]}` 288 var msg Message 289 if err := json.Unmarshal([]byte(raw), &msg); err != nil { 290 t.Fatalf("unmarshal error: %v", err) 291 } 292 if !msg.Content.HasBlocks() { 293 t.Fatal("expected blocks") 294 } 295 blocks := msg.Content.Blocks() 296 if len(blocks) != 2 { 297 t.Fatalf("expected 2 blocks, got %d", len(blocks)) 298 } 299 } 300 301 func TestContentBlock_ToolUse_MarshalJSON(t *testing.T) { 302 block := NewToolUseBlock("toolu_abc123", "bash", json.RawMessage(`{"command":"ls"}`)) 303 data, err := json.Marshal(block) 304 if err != nil { 305 t.Fatalf("marshal error: %v", err) 306 } 307 var m map[string]any 308 json.Unmarshal(data, &m) 309 if m["type"] != "tool_use" { 310 t.Errorf("expected type=tool_use, got %v", m["type"]) 311 } 312 if m["id"] != "toolu_abc123" { 313 t.Errorf("expected 
id=toolu_abc123, got %v", m["id"]) 314 } 315 if m["name"] != "bash" { 316 t.Errorf("expected name=bash, got %v", m["name"]) 317 } 318 } 319 320 func TestContentBlock_ToolResult_MarshalJSON_StringContent(t *testing.T) { 321 block := NewToolResultBlock("toolu_abc123", "file1.txt\nfile2.txt", false) 322 data, err := json.Marshal(block) 323 if err != nil { 324 t.Fatalf("marshal error: %v", err) 325 } 326 var m map[string]any 327 json.Unmarshal(data, &m) 328 if m["type"] != "tool_result" { 329 t.Errorf("expected type=tool_result, got %v", m["type"]) 330 } 331 if m["tool_use_id"] != "toolu_abc123" { 332 t.Errorf("expected tool_use_id=toolu_abc123, got %v", m["tool_use_id"]) 333 } 334 if m["content"] != "file1.txt\nfile2.txt" { 335 t.Errorf("unexpected content: %v", m["content"]) 336 } 337 if _, ok := m["is_error"]; ok { 338 t.Error("is_error should be omitted when false") 339 } 340 } 341 342 func TestContentBlock_ToolResult_MarshalJSON_ArrayContent(t *testing.T) { 343 block := NewToolResultBlockWithImages("toolu_xyz", "Screenshot captured", []ContentBlock{ 344 {Type: "image", Source: &ImageSource{Type: "base64", MediaType: "image/png", Data: "fakedata"}}, 345 }, false) 346 data, err := json.Marshal(block) 347 if err != nil { 348 t.Fatalf("marshal error: %v", err) 349 } 350 var m map[string]any 351 json.Unmarshal(data, &m) 352 contentArr, ok := m["content"].([]any) 353 if !ok { 354 t.Fatalf("expected content to be array, got %T: %v", m["content"], m["content"]) 355 } 356 if len(contentArr) != 2 { 357 t.Fatalf("expected 2 content blocks (text+image), got %d", len(contentArr)) 358 } 359 } 360 361 func TestContentBlock_ToolResult_RoundTrip(t *testing.T) { 362 // String content round-trip 363 original := NewToolResultBlock("toolu_abc", "result text", true) 364 data, err := json.Marshal(original) 365 if err != nil { 366 t.Fatalf("marshal error: %v", err) 367 } 368 var decoded ContentBlock 369 if err := json.Unmarshal(data, &decoded); err != nil { 370 t.Fatalf("unmarshal error: 
%v", err) 371 } 372 if decoded.Type != "tool_result" { 373 t.Errorf("type mismatch: %s", decoded.Type) 374 } 375 if decoded.ToolUseID != "toolu_abc" { 376 t.Errorf("tool_use_id mismatch: %s", decoded.ToolUseID) 377 } 378 if !decoded.IsError { 379 t.Error("is_error should be true") 380 } 381 text, ok := decoded.ToolContent.(string) 382 if !ok { 383 t.Fatalf("expected string content, got %T", decoded.ToolContent) 384 } 385 if text != "result text" { 386 t.Errorf("content mismatch: %s", text) 387 } 388 389 // Array content round-trip 390 original2 := NewToolResultBlockWithImages("toolu_xyz", "Screenshot", []ContentBlock{ 391 {Type: "image", Source: &ImageSource{Type: "base64", MediaType: "image/png", Data: "abc"}}, 392 }, false) 393 data2, _ := json.Marshal(original2) 394 var decoded2 ContentBlock 395 if err := json.Unmarshal(data2, &decoded2); err != nil { 396 t.Fatalf("unmarshal error: %v", err) 397 } 398 blocks, ok := decoded2.ToolContent.([]ContentBlock) 399 if !ok { 400 t.Fatalf("expected []ContentBlock, got %T", decoded2.ToolContent) 401 } 402 if len(blocks) != 2 { 403 t.Fatalf("expected 2 nested blocks, got %d", len(blocks)) 404 } 405 } 406 407 func TestFunctionCall_ID(t *testing.T) { 408 raw := `{"id":"toolu_abc","name":"bash","arguments":{"command":"ls"}}` 409 var fc FunctionCall 410 if err := json.Unmarshal([]byte(raw), &fc); err != nil { 411 t.Fatalf("unmarshal error: %v", err) 412 } 413 if fc.ID != "toolu_abc" { 414 t.Errorf("expected ID=toolu_abc, got %q", fc.ID) 415 } 416 if fc.Name != "bash" { 417 t.Errorf("expected Name=bash, got %q", fc.Name) 418 } 419 } 420 421 func TestToolResultText_Extraction(t *testing.T) { 422 // String content 423 b1 := NewToolResultBlock("id1", "hello world", false) 424 if got := ToolResultText(b1); got != "hello world" { 425 t.Errorf("expected 'hello world', got %q", got) 426 } 427 // Array content 428 b2 := NewToolResultBlockWithImages("id2", "screenshot taken", nil, false) 429 if got := ToolResultText(b2); got != 
"screenshot taken" { 430 t.Errorf("expected 'screenshot taken', got %q", got) 431 } 432 // Non-tool_result 433 b3 := ContentBlock{Type: "text", Text: "plain"} 434 if got := ToolResultText(b3); got != "" { 435 t.Errorf("expected empty, got %q", got) 436 } 437 } 438 439 // TestNormalizeToolInput verifies the shared helper that coerces null/empty 440 // tool_use input to an empty object. See issue #45. 441 func TestNormalizeToolInput(t *testing.T) { 442 cases := []struct { 443 name string 444 in json.RawMessage 445 want string 446 }{ 447 {"nil", nil, "{}"}, 448 {"empty bytes", json.RawMessage(""), "{}"}, 449 {"literal null", json.RawMessage("null"), "{}"}, 450 {"null with leading whitespace", json.RawMessage(" null"), "{}"}, 451 {"null with trailing whitespace", json.RawMessage("null "), "{}"}, 452 {"null with surrounding whitespace", json.RawMessage(" null "), "{}"}, 453 {"whitespace only", json.RawMessage(" "), "{}"}, 454 {"empty object preserved", json.RawMessage("{}"), "{}"}, 455 {"populated object preserved", json.RawMessage(`{"x":1}`), `{"x":1}`}, 456 {"nested object preserved", json.RawMessage(`{"a":{"b":2}}`), `{"a":{"b":2}}`}, 457 // Double-encoded string unwrap — OpenAI-shaped adapters sometimes 458 // return tool arguments as a JSON-encoded string wrapping an object. 459 // Anthropic's tool_use.input validator rejects these unless unwrapped. 460 {"double-encoded simple", json.RawMessage(`"{\"command\":\"ls\"}"`), `{"command":"ls"}`}, 461 {"double-encoded nested", json.RawMessage(`"{\"a\":{\"b\":2}}"`), `{"a":{"b":2}}`}, 462 {"double-encoded empty object", json.RawMessage(`"{}"`), `{}`}, 463 {"double-encoded with whitespace inside", json.RawMessage(`" {\"x\":1} "`), `{"x":1}`}, 464 // Non-object scalars / strings still pass through untouched so 465 // genuine provider bugs remain visible rather than silently masked. 466 // See TestContentBlock_MarshalJSON_PreservesNonObjectToolUseInput. 
467 {"plain string passthrough", json.RawMessage(`"hello"`), `"hello"`}, 468 {"empty string passthrough", json.RawMessage(`""`), `""`}, 469 {"quoted null passthrough", json.RawMessage(`"null"`), `"null"`}, 470 {"encoded array passthrough", json.RawMessage(`"[1,2,3]"`), `"[1,2,3]"`}, 471 } 472 for _, tc := range cases { 473 t.Run(tc.name, func(t *testing.T) { 474 got := normalizeToolInput(tc.in) 475 if string(got) != tc.want { 476 t.Errorf("normalizeToolInput(%q) = %q, want %q", string(tc.in), string(got), tc.want) 477 } 478 }) 479 } 480 } 481 482 // TestNormalizeToolInput_CanonicalizesKeyOrdering verifies multi-key objects 483 // produce identical bytes regardless of source key order. This closes the 484 // byte-drift class that caused session 2026-04-15-69f601dc1c98's 17 distinct 485 // system_h variants over 61 requests (same system_len) → −13pp CHR regression 486 // vs the session-peer median. 487 func TestNormalizeToolInput_CanonicalizesKeyOrdering(t *testing.T) { 488 // Same logical content, different source key orders → must marshal identically. 
489 cases := []struct { 490 name string 491 a, b json.RawMessage 492 }{ 493 { 494 "flat two keys", 495 json.RawMessage(`{"path":"/etc","line":5}`), 496 json.RawMessage(`{"line":5,"path":"/etc"}`), 497 }, 498 { 499 "nested map", 500 json.RawMessage(`{"x":{"b":1,"a":2},"y":{"d":3,"c":4}}`), 501 json.RawMessage(`{"y":{"c":4,"d":3},"x":{"a":2,"b":1}}`), 502 }, 503 { 504 "deeply nested", 505 json.RawMessage(`{"outer":{"mid":{"z":1,"a":2}}}`), 506 json.RawMessage(`{"outer":{"mid":{"a":2,"z":1}}}`), 507 }, 508 } 509 for _, tc := range cases { 510 t.Run(tc.name, func(t *testing.T) { 511 ga := normalizeToolInput(tc.a) 512 gb := normalizeToolInput(tc.b) 513 if string(ga) != string(gb) { 514 t.Fatalf("canonical output differs:\n a=%s\n b=%s\n got_a=%s\n got_b=%s", 515 tc.a, tc.b, ga, gb) 516 } 517 }) 518 } 519 } 520 521 // TestNormalizeToolInput_PreservesLargeIntegerPrecision guards against a 522 // regression where the canonical-ordering roundtrip decoded JSON numbers into 523 // float64 and silently truncated integers above 2^53. Real payloads that hit 524 // this: Unix nanosecond timestamps (19 digits), 64-bit row IDs, byte sizes. 525 // The fix uses json.Decoder.UseNumber() so digits round-trip verbatim. 
526 func TestNormalizeToolInput_PreservesLargeIntegerPrecision(t *testing.T) { 527 cases := []struct { 528 name string 529 in json.RawMessage 530 want string // substring that must appear in the normalized output 531 }{ 532 { 533 "unix nanoseconds", 534 json.RawMessage(`{"nanos":1716398400000000000}`), 535 `1716398400000000000`, 536 }, 537 { 538 "near max int64", 539 json.RawMessage(`{"id":9223372036854775807}`), 540 `9223372036854775807`, 541 }, 542 { 543 "nested large int", 544 json.RawMessage(`{"meta":{"row_id":1234567890123456789}}`), 545 `1234567890123456789`, 546 }, 547 } 548 for _, tc := range cases { 549 t.Run(tc.name, func(t *testing.T) { 550 got := string(normalizeToolInput(tc.in)) 551 if !strings.Contains(got, tc.want) { 552 t.Fatalf("precision lost:\n in=%s\n want substring=%s\n got=%s", 553 tc.in, tc.want, got) 554 } 555 }) 556 } 557 } 558 559 // TestNormalizeToolInput_DoubleEncodedCanonicalization verifies that 560 // double-encoded multi-key objects get both unwrapped AND canonicalized. 561 func TestNormalizeToolInput_DoubleEncodedCanonicalization(t *testing.T) { 562 // Keys in reverse-alpha order inside the double-encoded string. 563 in := json.RawMessage(`"{\"z_path\":\"/etc\",\"a_line\":5}"`) 564 got := string(normalizeToolInput(in)) 565 want := `{"a_line":5,"z_path":"/etc"}` 566 if got != want { 567 t.Fatalf("double-encoded multi-key not canonicalized:\n got =%s\n want=%s", got, want) 568 } 569 } 570 571 // TestNewToolUseBlock_NormalizesInput verifies that the constructor coerces 572 // null/empty input to {} so in-memory consumers (ollama.go, microcompact, etc.) 573 // never see a literal "null" when reading block.Input. 
574 func TestNewToolUseBlock_NormalizesInput(t *testing.T) { 575 cases := []struct { 576 name string 577 in json.RawMessage 578 want string 579 }{ 580 {"nil input", nil, "{}"}, 581 {"literal null", json.RawMessage("null"), "{}"}, 582 {"empty bytes", json.RawMessage(""), "{}"}, 583 {"valid object passthrough", json.RawMessage(`{"url":"x"}`), `{"url":"x"}`}, 584 {"double-encoded string unwraps to object", json.RawMessage(`"{\"url\":\"x\"}"`), `{"url":"x"}`}, 585 } 586 for _, tc := range cases { 587 t.Run(tc.name, func(t *testing.T) { 588 b := NewToolUseBlock("tu_1", "browser_snapshot", tc.in) 589 if b.Type != "tool_use" { 590 t.Errorf("Type = %q, want tool_use", b.Type) 591 } 592 if string(b.Input) != tc.want { 593 t.Errorf("Input = %q, want %q", string(b.Input), tc.want) 594 } 595 }) 596 } 597 } 598 599 // TestContentBlock_MarshalJSON_ForcesToolUseInput is the load-bearing test for 600 // issue #45. Even if a tool_use block was constructed with nil/null Input 601 // (e.g. via a code path that bypasses NewToolUseBlock), MarshalJSON must 602 // always emit a concrete JSON object for tool_use.input. The serialized bytes 603 // must contain "input":{} and must never contain "input":null, and must never 604 // omit the input field entirely. 
605 func TestContentBlock_MarshalJSON_ForcesToolUseInput(t *testing.T) { 606 cases := []struct { 607 name string 608 block ContentBlock 609 }{ 610 {"nil input", ContentBlock{Type: "tool_use", ID: "tu_1", Name: "browser_snapshot"}}, 611 {"literal null input", ContentBlock{Type: "tool_use", ID: "tu_2", Name: "browser_close", Input: json.RawMessage("null")}}, 612 {"whitespace null input", ContentBlock{Type: "tool_use", ID: "tu_3", Name: "browser_snapshot", Input: json.RawMessage(" null ")}}, 613 {"empty bytes input", ContentBlock{Type: "tool_use", ID: "tu_4", Name: "noop", Input: json.RawMessage("")}}, 614 } 615 for _, tc := range cases { 616 t.Run(tc.name, func(t *testing.T) { 617 data, err := json.Marshal(tc.block) 618 if err != nil { 619 t.Fatalf("marshal failed: %v", err) 620 } 621 s := string(data) 622 if !strings.Contains(s, `"input":{}`) { 623 t.Errorf("expected %q to contain %q", s, `"input":{}`) 624 } 625 if strings.Contains(s, `"input":null`) { 626 t.Errorf("serialized output must not contain \"input\":null, got %s", s) 627 } 628 // Verify by round-trip that "input" key exists and is a JSON object. 629 var m map[string]any 630 if err := json.Unmarshal(data, &m); err != nil { 631 t.Fatalf("round-trip unmarshal failed: %v", err) 632 } 633 input, ok := m["input"] 634 if !ok { 635 t.Errorf("input field is missing from serialized output: %s", s) 636 } 637 if _, isObj := input.(map[string]any); !isObj { 638 t.Errorf("input field is not a JSON object, got %T: %v", input, input) 639 } 640 }) 641 } 642 } 643 644 // TestContentBlock_MarshalJSON_PreservesValidToolUseInput ensures the 645 // normalization only kicks in for null/empty inputs and leaves populated 646 // inputs untouched. 
647 func TestContentBlock_MarshalJSON_PreservesValidToolUseInput(t *testing.T) { 648 b := ContentBlock{ 649 Type: "tool_use", 650 ID: "tu_valid", 651 Name: "browser_navigate", 652 Input: json.RawMessage(`{"url":"https://example.com"}`), 653 } 654 data, err := json.Marshal(b) 655 if err != nil { 656 t.Fatalf("marshal failed: %v", err) 657 } 658 s := string(data) 659 if !strings.Contains(s, `"input":{"url":"https://example.com"}`) { 660 t.Errorf("expected populated input to be preserved, got %s", s) 661 } 662 } 663 664 // TestContentBlock_MarshalJSON_PreservesNonObjectToolUseInput verifies that 665 // scalar/array/string/bool tool inputs are NOT silently coerced to {} even 666 // though they are not valid tool_use.input per Anthropic's schema. The 667 // normalization intentionally targets only null/empty — any other value is 668 // passed through so the provider bug stays visible instead of being masked. 669 func TestContentBlock_MarshalJSON_PreservesNonObjectToolUseInput(t *testing.T) { 670 cases := []struct { 671 name string 672 rawInput json.RawMessage 673 expect string 674 }{ 675 {"number", json.RawMessage(`42`), `"input":42`}, 676 {"string", json.RawMessage(`"hello"`), `"input":"hello"`}, 677 {"array", json.RawMessage(`[1,2,3]`), `"input":[1,2,3]`}, 678 {"bool true", json.RawMessage(`true`), `"input":true`}, 679 {"bool false", json.RawMessage(`false`), `"input":false`}, 680 } 681 for _, tc := range cases { 682 t.Run(tc.name, func(t *testing.T) { 683 b := ContentBlock{Type: "tool_use", ID: "tu_scalar", Name: "odd_tool", Input: tc.rawInput} 684 data, err := json.Marshal(b) 685 if err != nil { 686 t.Fatalf("marshal failed: %v", err) 687 } 688 s := string(data) 689 if !strings.Contains(s, tc.expect) { 690 t.Errorf("expected %q in output, got %s", tc.expect, s) 691 } 692 if strings.Contains(s, `"input":{}`) { 693 t.Errorf("non-null scalar must not be coerced to {}, got %s", s) 694 } 695 }) 696 } 697 } 698 699 // TestContentBlock_MarshalJSON_OtherBlocksNoInputField 
ensures the 700 // normalization only applies to tool_use blocks. Other block types 701 // (text, image, tool_result) must NOT have an "input" field injected. 702 func TestContentBlock_MarshalJSON_OtherBlocksNoInputField(t *testing.T) { 703 cases := []struct { 704 name string 705 block ContentBlock 706 }{ 707 {"text block", ContentBlock{Type: "text", Text: "hello"}}, 708 {"tool_result block", NewToolResultBlock("tu_1", "ok", false)}, 709 } 710 for _, tc := range cases { 711 t.Run(tc.name, func(t *testing.T) { 712 data, err := json.Marshal(tc.block) 713 if err != nil { 714 t.Fatalf("marshal failed: %v", err) 715 } 716 s := string(data) 717 if strings.Contains(s, `"input"`) { 718 t.Errorf("non-tool_use block must not serialize an input field, got %s", s) 719 } 720 }) 721 } 722 } 723 724 // TestCompletionRequest_Serialization_NoNullToolInput is the full-payload 725 // regression test for issue #45. It constructs a CompletionRequest containing 726 // a tool_use block with nil/null Input (simulating the poisoned history from 727 // a previous gateway response) and asserts the final serialized JSON bytes 728 // would not be rejected by Anthropic's schema validator. 
729 func TestCompletionRequest_Serialization_NoNullToolInput(t *testing.T) { 730 req := CompletionRequest{ 731 Messages: []Message{ 732 {Role: "user", Content: NewTextContent("take a snapshot")}, 733 {Role: "assistant", Content: NewBlockContent([]ContentBlock{ 734 {Type: "tool_use", ID: "tu_a", Name: "browser_snapshot"}, // nil Input 735 {Type: "tool_use", ID: "tu_b", Name: "browser_close", Input: json.RawMessage("null")}, // poisoned 736 })}, 737 {Role: "user", Content: NewBlockContent([]ContentBlock{ 738 NewToolResultBlock("tu_a", "ok", false), 739 NewToolResultBlock("tu_b", "ok", false), 740 })}, 741 }, 742 } 743 data, err := json.Marshal(req) 744 if err != nil { 745 t.Fatalf("marshal failed: %v", err) 746 } 747 s := string(data) 748 if strings.Contains(s, `"input":null`) { 749 t.Errorf("payload must not contain \"input\":null, got %s", s) 750 } 751 // Each tool_use block must have a concrete input object in the final JSON. 752 // Count occurrences of "input":{} — we expect exactly 2. 753 if strings.Count(s, `"input":{}`) != 2 { 754 t.Errorf("expected 2 occurrences of \"input\":{}, got %d: %s", strings.Count(s, `"input":{}`), s) 755 } 756 } 757 758 // TestArgumentsString_NullHandling verifies that FunctionCall.ArgumentsString 759 // also coerces literal "null" to "{}". This protects XML-fallback and any 760 // consumer reading argument strings for logging/audit. 
761 func TestArgumentsString_NullHandling(t *testing.T) { 762 cases := []struct { 763 name string 764 raw json.RawMessage 765 want string 766 }{ 767 {"nil", nil, "{}"}, 768 {"empty", json.RawMessage(""), "{}"}, 769 {"literal null", json.RawMessage("null"), "{}"}, 770 {"whitespace null", json.RawMessage(" null "), "{}"}, 771 {"empty object", json.RawMessage("{}"), "{}"}, 772 {"populated object", json.RawMessage(`{"url":"x"}`), `{"url":"x"}`}, 773 {"json-encoded string", json.RawMessage(`"already a string"`), "already a string"}, 774 } 775 for _, tc := range cases { 776 t.Run(tc.name, func(t *testing.T) { 777 fc := FunctionCall{Name: "noop", Arguments: tc.raw} 778 got := fc.ArgumentsString() 779 if got != tc.want { 780 t.Errorf("ArgumentsString() = %q, want %q", got, tc.want) 781 } 782 }) 783 } 784 } 785 786 func TestUsage_JSON_5m1hSplit(t *testing.T) { 787 raw := []byte(`{ 788 "input_tokens": 100, 789 "output_tokens": 50, 790 "cache_creation_tokens": 300, 791 "cache_creation_5m_tokens": 100, 792 "cache_creation_1h_tokens": 200 793 }`) 794 var u Usage 795 if err := json.Unmarshal(raw, &u); err != nil { 796 t.Fatalf("unmarshal failed: %v", err) 797 } 798 if u.CacheCreation5mTokens != 100 { 799 t.Errorf("expected CacheCreation5mTokens=100, got %d", u.CacheCreation5mTokens) 800 } 801 if u.CacheCreation1hTokens != 200 { 802 t.Errorf("expected CacheCreation1hTokens=200, got %d", u.CacheCreation1hTokens) 803 } 804 if u.CacheCreationTokens != 300 { 805 t.Errorf("expected legacy CacheCreationTokens=300, got %d", u.CacheCreationTokens) 806 } 807 } 808 809 func TestUsage_JSON_BackwardCompat_MissingSplit(t *testing.T) { 810 // Old gateway responses that don't include the split fields yet must parse cleanly. 
811 raw := []byte(`{ 812 "input_tokens": 100, 813 "cache_creation_tokens": 300 814 }`) 815 var u Usage 816 if err := json.Unmarshal(raw, &u); err != nil { 817 t.Fatalf("unmarshal failed: %v", err) 818 } 819 if u.CacheCreationTokens != 300 { 820 t.Errorf("legacy field broken, got %d", u.CacheCreationTokens) 821 } 822 if u.CacheCreation5mTokens != 0 || u.CacheCreation1hTokens != 0 { 823 t.Errorf("expected zero for absent split fields, got 5m=%d 1h=%d", 824 u.CacheCreation5mTokens, u.CacheCreation1hTokens) 825 } 826 }