// src/utils/hooks/postSamplingHooks.ts
 1  import type { QuerySource } from '../../constants/querySource.js'
 2  import type { ToolUseContext } from '../../Tool.js'
 3  import type { Message } from '../../types/message.js'
 4  import { toError } from '../errors.js'
 5  import { logError } from '../log.js'
 6  import type { SystemPrompt } from '../systemPromptType.js'
 7  
// Post-sampling hooks are not exposed in settings.json config (yet); they are only registered programmatically.
 9  
/**
 * Context object handed to REPL hooks. It is shared by both post-sampling
 * hooks and stop hooks.
 */
export type REPLHookContext = {
  /** Full message history, including assistant responses. */
  messages: Message[]
  systemPrompt: SystemPrompt
  userContext: Record<string, string>
  systemContext: Record<string, string>
  toolUseContext: ToolUseContext
  /** Optional; may be omitted by the caller. */
  querySource?: QuerySource
}
19  
/**
 * Callback run after model sampling completes. It receives the REPL hook
 * context and may be synchronous or return a promise.
 */
export type PostSamplingHook = (
  context: REPLHookContext,
) => void | Promise<void>
23  
/**
 * Module-level registry of post-sampling hooks, kept in registration order.
 * The two functions below mutate this array in place.
 */
const postSamplingHooks: Array<PostSamplingHook> = []

/**
 * Append a hook to the post-sampling registry.
 *
 * Hooks are meant to be called after model sampling completes, in the order
 * they were registered. Registering the same function more than once adds
 * one entry per registration.
 *
 * Internal API: hooks cannot be configured through settings.
 */
export function registerPostSamplingHook(hook: PostSamplingHook): void {
  postSamplingHooks[postSamplingHooks.length] = hook
}

/**
 * Remove every registered post-sampling hook. Intended for tests.
 * The registry array is emptied in place.
 */
export function clearPostSamplingHooks(): void {
  postSamplingHooks.splice(0, postSamplingHooks.length)
}
41  
/**
 * Run every registered post-sampling hook with a context object built from
 * the arguments. The same context object is passed to every hook.
 *
 * Hooks run sequentially in registration order. Each hook's result is
 * awaited before the next hook starts, whether it is sync or async.
 *
 * Errors thrown or rejected by a hook are logged and do not stop the
 * remaining hooks, so a failing hook does not make this function reject.
 *
 * The set of hooks that run is fixed when the call begins. A hook that
 * registers a new hook or clears the registry while this call is running
 * only affects later calls.
 */
export async function executePostSamplingHooks(
  messages: Message[],
  systemPrompt: SystemPrompt,
  userContext: { [k: string]: string },
  systemContext: { [k: string]: string },
  toolUseContext: ToolUseContext,
  querySource?: QuerySource,
): Promise<void> {
  const context: REPLHookContext = {
    messages,
    systemPrompt,
    userContext,
    systemContext,
    toolUseContext,
    querySource,
  }

  // Snapshot the registry. A for...of over the live array re-reads its length
  // on each step, so a hook pushing a new hook would run it in this same pass,
  // and a hook clearing the registry would skip every hook after it.
  const hooks = [...postSamplingHooks]

  for (const hook of hooks) {
    try {
      await hook(context)
    } catch (error) {
      // Best-effort: log and continue so one faulty hook can't block the rest.
      logError(toError(error))
    }
  }
}