// tests/chat-stream-golden-trace.test.ts
  1  import { randomUUID } from 'node:crypto'
  2  
  3  import { beforeEach, describe, expect, it, vi } from 'vitest'
  4  
  5  import type { ChatStreamEvent } from '@/lib/shared/chat'
  6  import { streamChatMessageEvents } from '@/server/chat/service'
  7  
  8  const streamReplyMock = vi.hoisted(() => vi.fn())
  9  
 10  vi.mock('@/server/chat/memory-retrieval', () => ({
 11    refreshMemoryEmbeddings: vi.fn(async () => {}),
 12    selectRelevantMemories: vi.fn(async () => []),
 13  }))
 14  
 15  vi.mock('@/server/providers', () => ({
 16    getChatProvider: vi.fn(() => ({
 17      streamReply: streamReplyMock,
 18      generateReply: vi.fn(async () => ({
 19        provider: 'mock-provider',
 20        mocked: false,
 21        text: 'unused',
 22      })),
 23    })),
 24  }))
 25  
 26  describe('chat stream golden trace', () => {
 27    beforeEach(() => {
 28      vi.clearAllMocks()
 29    })
 30  
 31    it('preserves event ordering and tool lifecycle invariants for a streamed turn', async () => {
 32      const sessionId = randomUUID()
 33  
 34      streamReplyMock.mockImplementationOnce(async () => ({
 35        provider: 'openai-compatible',
 36        mocked: false,
 37        stream: (async function* () {
 38          yield {
 39            type: 'tool_start' as const,
 40            tool: 'bash',
 41            toolCallId: 'call_1',
 42            source: 'function' as const,
 43            input: { command: 'pwd' },
 44          }
 45          yield {
 46            type: 'tool_progress' as const,
 47            tool: 'bash',
 48            toolCallId: 'call_1',
 49            source: 'function' as const,
 50            progress: 'Running command...',
 51          }
 52          yield {
 53            type: 'thinking' as const,
 54            thinking: 'Inspecting output',
 55          }
 56          yield {
 57            type: 'delta' as const,
 58            delta: 'Phase 1. ',
 59          }
 60          yield {
 61            type: 'tool_complete' as const,
 62            tool: 'bash',
 63            toolCallId: 'call_1',
 64            source: 'function' as const,
 65            ok: true,
 66            output: '/Users/justinedwards/git/helper',
 67          }
 68          yield {
 69            type: 'tool_start' as const,
 70            tool: 'web_search',
 71            toolCallId: 'call_2',
 72            source: 'function' as const,
 73            input: { query: 'vite config docs' },
 74          }
 75          yield {
 76            type: 'delta' as const,
 77            delta: 'Phase 2. ',
 78          }
 79          yield {
 80            type: 'tool_progress' as const,
 81            tool: 'web_search',
 82            toolCallId: 'call_2',
 83            source: 'function' as const,
 84            progress: 'Collecting sources...',
 85          }
 86          yield {
 87            type: 'tool_complete' as const,
 88            tool: 'web_search',
 89            toolCallId: 'call_2',
 90            source: 'function' as const,
 91            ok: true,
 92            output: '2 sources',
 93          }
 94          yield {
 95            type: 'delta' as const,
 96            delta: 'Final answer.',
 97          }
 98        })(),
 99      }))
100  
101      const events: ChatStreamEvent[] = []
102      for await (const event of streamChatMessageEvents({
103        sessionId,
104        text: 'run tools and summarize',
105        dangerousTools: { bash: true },
106      })) {
107        events.push(event)
108      }
109  
110      const types = events.map((event) => event.type)
111      expect(types).toContain('task_update')
112      expect(types).toContain('task_event')
113  
114      const nonTaskTypes = types.filter(
115        (type) => type !== 'task_update' && type !== 'task_event',
116      )
117      expect(nonTaskTypes).toEqual([
118        'session',
119        'tool_start',
120        'tool_progress',
121        'thinking',
122        'delta',
123        'tool_complete',
124        'tool_start',
125        'delta',
126        'tool_progress',
127        'tool_complete',
128        'delta',
129        'done',
130      ])
131  
132      expect(events.length).toBeGreaterThan(0)
133      const sessionEvent = events[0]
134      expect(sessionEvent.type).toBe('session')
135      if (sessionEvent.type !== 'session') {
136        throw new Error('Expected first event to be a session event.')
137      }
138      const runId = sessionEvent.runId
139      expect(runId).toBeTruthy()
140  
141      for (const event of events) {
142        if (event.type === 'session') continue
143        expect(event.runId).toBe(runId)
144      }
145  
146      const seenToolStarts = new Set<string>()
147      for (const event of events) {
148        if (event.type === 'tool_start') {
149          seenToolStarts.add(event.toolCallId ?? '')
150        }
151        if (event.type === 'tool_progress' || event.type === 'tool_complete') {
152          expect(seenToolStarts.has(event.toolCallId ?? '')).toBe(true)
153        }
154      }
155  
156      const finalEvent = events.at(-1)
157      expect(finalEvent?.type).toBe('done')
158      if (finalEvent?.type === 'done') {
159        expect(finalEvent.assistantMessage.text).toBe(
160          'Phase 1. Phase 2. Final answer.',
161        )
162      }
163    })
164  })