openai-compatible-tool-trigger-routing.test.ts
// Verifies that the OpenAI-compatible provider only routes requests through the
// local tool loop when the prompt (or recent context) actually implies tool use.
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'

import { openAiCompatibleProvider } from '@/server/providers/openai-compatible'
import type { ProviderChatInput, ProviderChatStreamChunk } from '@/server/providers/types'

// Stub out cost tracking and upload handling so the provider runs in isolation.
vi.mock('@/server/costs/ledger', () => ({
  recordApiCall: vi.fn(),
}))

vi.mock('@/server/storage/chat-store', () => ({
  getUploadById: vi.fn(() => null),
}))

vi.mock('@/server/uploads/files', () => ({
  readUploadFile: vi.fn(),
}))

const originalEnv = { ...process.env }
const originalFetch = global.fetch

// Minimal JSON Response helper for non-streaming provider replies.
function jsonResponse(body: unknown, status = 200): Response {
  return new Response(JSON.stringify(body), {
    status,
    headers: {
      'Content-Type': 'application/json',
    },
  })
}

// Builds a text/event-stream Response that emits the given SSE payloads and then closes.
function buildSseResponse(events: Array<Record<string, unknown> | '[DONE]'>): Response {
  const encoder = new TextEncoder()
  const stream = new ReadableStream<Uint8Array>({
    start(controller) {
      for (const payload of events) {
        const line =
          payload === '[DONE]'
            ? 'data: [DONE]\n\n'
            : `data: ${JSON.stringify(payload)}\n\n`
        controller.enqueue(encoder.encode(line))
      }
      controller.close()
    },
  })

  return new Response(stream, {
    status: 200,
    headers: {
      'Content-Type': 'text/event-stream',
    },
  })
}

// Concatenates the text of all delta chunks from a provider stream.
async function collectDeltaText(
  stream: AsyncIterable<ProviderChatStreamChunk>,
): Promise<string> {
  let text = ''
  for await (const chunk of stream) {
    if (chunk.type === 'delta') {
      text += chunk.delta
    }
  }
  return text
}

describe('openAiCompatibleProvider tool trigger routing', () => {
  beforeEach(() => {
    // Reset mocks, env, and fetch so each test starts from a clean slate.
    vi.clearAllMocks()
    process.env = {
      ...originalEnv,
      LLM_BASE_URL: 'https://example.test/v1',
      LLM_CHAT_MODEL: 'minimax-m2',
      LLM_TIMEOUT_MS: '30000',
    }
    global.fetch = vi.fn() as typeof fetch
  })

  afterEach(() => {
    process.env = originalEnv
    global.fetch = originalFetch
  })

  it('does not force local tool loop for non-tool prompts when bash is enabled', async () => {
    const fetchMock = vi.mocked(global.fetch)
    fetchMock.mockResolvedValueOnce(
      buildSseResponse([
        { type: 'response.output_text.delta', delta: 'hello' },
        '[DONE]',
      ]),
    )

    const input: ProviderChatInput = {
      systemPrompt: 'You are helpful.',
      compactedSummary: '',
      memories: [],
      messages: [
        {
          role: 'user',
          text: 'say hello politely',
          attachments: [],
        },
      ],
      providerOverride: {
        baseUrl: 'https://router.example.com/v1',
        apiKey: null,
        chatEndpointMode: 'auto',
      },
      modelOverride: 'minimax-m2',
      allowDangerousBashTool: true,
    }

    const result = await openAiCompatibleProvider.streamReply(input)
    const text = await collectDeltaText(result.stream)

    expect(text).toContain('hello')
    expect(fetchMock).toHaveBeenCalledTimes(1)
    const requestInit = fetchMock.mock.calls[0]?.[1] as RequestInit
    const payload = JSON.parse(String(requestInit.body)) as {
      stream?: boolean
    }
    // Non-tool prompts should stay on the streaming path.
    expect(payload.stream).toBe(true)
  })

  it('still routes explicit bash prompts through the local tool loop parser path', async () => {
    const fetchMock = vi.mocked(global.fetch)
    fetchMock.mockResolvedValueOnce(
      jsonResponse({
        model: 'minimax-m2',
        output_text: 'No command was run.',
      }),
    )

    const input: ProviderChatInput = {
      systemPrompt: 'You are helpful.',
      compactedSummary: '',
      memories: [],
      messages: [
        {
          role: 'user',
          text: 'run tool option 2 for me',
          attachments: [],
        },
      ],
      providerOverride: {
        baseUrl: 'https://router.example.com/v1',
        apiKey: null,
        chatEndpointMode: 'auto',
      },
      modelOverride: 'minimax-m2',
      allowDangerousBashTool: true,
    }

    const result = await openAiCompatibleProvider.streamReply(input)
    const text = await collectDeltaText(result.stream)

    expect(text).toContain('No command was run.')
    expect(fetchMock).toHaveBeenCalledTimes(1)
    const requestInit = fetchMock.mock.calls[0]?.[1] as RequestInit
    const payload = JSON.parse(String(requestInit.body)) as {
      stream?: boolean
      tools?: Array<{ name?: string }>
    }
    // Explicit tool prompts go through the non-streaming tool loop with bash exposed.
    expect(payload.stream).toBe(false)
    expect(payload.tools?.some((tool) => tool.name === 'bash')).toBe(true)
  })

  it('routes follow-up retry prompts through local tool loop when recent context has unresolved tool calls', async () => {
    const fetchMock = vi.mocked(global.fetch)
    fetchMock.mockResolvedValueOnce(
      jsonResponse({
        model: 'minimax-m2',
        output_text: 'Trying the tool path again now.',
      }),
    )

    const input: ProviderChatInput = {
      systemPrompt: 'You are helpful.',
      compactedSummary: '',
      memories: [],
      messages: [
        {
          // Unresolved tool call in recent context should keep the retry on the tool path.
          role: 'assistant',
          text: '<tool_call>\nweb_search({"query":"hn front page"})\n</tool_call>',
          attachments: [],
        },
        {
          role: 'user',
          text: 'try again',
          attachments: [],
        },
      ],
      providerOverride: {
        baseUrl: 'https://router.example.com/v1',
        apiKey: null,
        chatEndpointMode: 'auto',
      },
      modelOverride: 'minimax-m2',
      allowDangerousBashTool: true,
    }

    const result = await openAiCompatibleProvider.streamReply(input)
    const text = await collectDeltaText(result.stream)

    expect(text).toContain('Trying the tool path again now.')
    expect(fetchMock).toHaveBeenCalledTimes(1)
    const requestInit = fetchMock.mock.calls[0]?.[1] as RequestInit
    const payload = JSON.parse(String(requestInit.body)) as {
      stream?: boolean
    }
    expect(payload.stream).toBe(false)
  })
})