// NotebookReadTool.tsx
import type {
  ImageBlockParam,
  TextBlockParam,
} from '@anthropic-ai/sdk/resources/index.mjs'

import { existsSync, readFileSync } from 'fs'
import { Text } from 'ink'
import { extname, isAbsolute, relative, resolve } from 'path'
import * as React from 'react'
import { z } from 'zod'
import { FallbackToolUseRejectedMessage } from '../../components/FallbackToolUseRejectedMessage.js'
import { Tool } from '../../Tool.js'
import {
  NotebookCellSource,
  NotebookContent,
  NotebookCell,
  NotebookOutputImage,
  NotebookCellSourceOutput,
  NotebookCellOutput,
  NotebookCellType,
} from '../../types/notebook.js'
import { formatOutput } from '../BashTool/utils.js'
import { getCwd } from '../../utils/state.js'
import { findSimilarFile } from '../../utils/file.js'
import { DESCRIPTION, PROMPT } from './prompt.js'
import { hasReadPermission } from '../../utils/permissions/filesystem.js'

const inputSchema = z.strictObject({
  notebook_path: z
    .string()
    .describe(
      'The absolute path to the Jupyter notebook file to read (must be absolute, not relative)',
    ),
})

type In = typeof inputSchema
type Out = NotebookCellSource[]

/** A content block that can appear in the tool result sent to the assistant. */
type ResultBlock = TextBlockParam | ImageBlockParam

/**
 * Returns `notebookPath` unchanged when it is absolute, otherwise resolves it
 * against the current working directory reported by `getCwd()`.
 *
 * Shared by `validateInput` and `call` so both always look at the same file.
 */
function resolveNotebookPath(notebookPath: string): string {
  return isAbsolute(notebookPath)
    ? notebookPath
    : resolve(getCwd(), notebookPath)
}

/**
 * Converts processed cells into the content blocks sent to the assistant.
 *
 * Each cell contributes one text block for its source, followed by the blocks
 * for its outputs (see `getToolResultFromCell`). Runs of consecutive text
 * blocks are joined with '\n' into a single block; an image block ends a run.
 *
 * @param data - Cells produced by `processCell`.
 * @returns The merged blocks, or an empty array when `data` is empty.
 */
function renderResultForAssistant(data: NotebookCellSource[]): ResultBlock[] {
  const merged: ResultBlock[] = []
  for (const block of data.flatMap(getToolResultFromCell)) {
    const last = merged[merged.length - 1]
    if (last !== undefined && last.type === 'text' && block.type === 'text') {
      // Blocks from getToolResultFromCell are created fresh on every call,
      // so extending the previous block in place affects no caller's data.
      last.text += '\n' + block.text
    } else {
      // push (instead of re-spreading the accumulator) keeps the merge O(n).
      merged.push(block)
    }
  }
  return merged
}

export const NotebookReadTool = {
  name: 'ReadNotebook',
  async description() {
    return DESCRIPTION
  },
  async prompt() {
    return PROMPT
  },
  isReadOnly() {
    return true
  },
  inputSchema,
  userFacingName() {
    return 'Read Notebook'
  },
  async isEnabled() {
    return true
  },
  /**
   * Permission is needed whenever `hasReadPermission` reports the path as not
   * already readable.
   * NOTE(review): this passes the raw (possibly relative) path, while
   * `validateInput` and `call` resolve it against the cwd — confirm that
   * `hasReadPermission` handles relative paths.
   */
  needsPermissions({ notebook_path }) {
    return !hasReadPermission(notebook_path)
  },
  /**
   * Rejects paths that do not exist (suggesting a similarly named file when
   * `findSimilarFile` finds one) and paths without an `.ipynb` extension.
   */
  async validateInput({ notebook_path }) {
    const fullFilePath = resolveNotebookPath(notebook_path)

    if (!existsSync(fullFilePath)) {
      // Try to find a similar file with a different extension.
      const similarFilename = findSimilarFile(fullFilePath)
      let message = 'File does not exist.'

      // If we found a similar file, suggest it to the assistant.
      if (similarFilename) {
        message += ` Did you mean ${similarFilename}?`
      }

      return {
        result: false,
        message,
      }
    }

    if (extname(fullFilePath) !== '.ipynb') {
      return {
        result: false,
        message: 'File must be a Jupyter notebook (.ipynb file).',
      }
    }

    return { result: true }
  },
  /** Shows the full path in verbose mode, otherwise the path relative to the cwd. */
  renderToolUseMessage(input, { verbose }) {
    return `notebook_path: ${verbose ? input.notebook_path : relative(getCwd(), input.notebook_path)}`
  },
  renderToolUseRejectedMessage() {
    return <FallbackToolUseRejectedMessage />
  },
  /** Summarizes the result as a cell count, or a notice when there are no cells. */
  renderToolResultMessage(content) {
    if (!content || content.length < 1 || !content[0]) {
      return <Text>No cells found in notebook</Text>
    }
    return <Text>Read {content.length} cells</Text>
  },
  /**
   * Reads and parses the notebook synchronously and yields a single result
   * containing the processed cells. File-system and JSON.parse errors
   * propagate to the caller.
   */
  async *call({ notebook_path }) {
    const content = readFileSync(resolveNotebookPath(notebook_path), 'utf-8')
    const notebook = JSON.parse(content) as NotebookContent
    // nbformat requires top-level `metadata`, but tolerate hand-written files
    // that omit it rather than crashing; the language defaults to Python.
    const language = notebook.metadata?.language_info?.name ?? 'python'
    const cells = notebook.cells.map((cell, index) =>
      processCell(cell, index, language),
    )

    yield {
      type: 'result',
      resultForAssistant: renderResultForAssistant(cells),
      data: cells,
    }
  },
  renderResultForAssistant,
} satisfies Tool<In, Out>

/**
 * Joins a notebook multiline string (a string or an array of fragments) and
 * passes it through the Bash tool's `formatOutput`, returning its
 * `truncatedContent` (presumably a length-capped version of the text).
 *
 * @returns '' when `text` is missing or an empty string.
 */
function processOutputText(text: string | string[] | undefined): string {
  if (!text) return ''
  const rawText = Array.isArray(text) ? text.join('') : text
  const { truncatedContent } = formatOutput(rawText)
  return truncatedContent
}

/**
 * Picks the first supported image from an output's MIME bundle, preferring
 * PNG over JPEG. The value is passed through unchanged as `image_data`
 * (presumably base64, as stored in .ipynb files).
 *
 * @returns `undefined` when neither `image/png` nor `image/jpeg` is a string.
 */
function extractImage(
  data: Record<string, unknown>,
): NotebookOutputImage | undefined {
  const png = data['image/png']
  if (typeof png === 'string') {
    return {
      image_data: png,
      media_type: 'image/png',
    }
  }
  const jpeg = data['image/jpeg']
  if (typeof jpeg === 'string') {
    return {
      image_data: jpeg,
      media_type: 'image/jpeg',
    }
  }
  return undefined
}

/**
 * Normalizes a single cell output into text (and, for rich outputs, an
 * optional image).
 * - `stream`: the streamed text.
 * - `execute_result` / `display_data`: the `text/plain` representation plus
 *   any PNG/JPEG image.
 * - `error`: "ename: evalue" followed by the traceback lines.
 */
function processOutput(output: NotebookCellOutput) {
  switch (output.output_type) {
    case 'stream':
      return {
        output_type: output.output_type,
        text: processOutputText(output.text),
      }
    case 'execute_result':
    case 'display_data':
      return {
        output_type: output.output_type,
        text: processOutputText(output.data?.['text/plain']),
        image: output.data && extractImage(output.data),
      }
    case 'error':
      return {
        output_type: output.output_type,
        text: processOutputText(
          `${output.ename}: ${output.evalue}\n${output.traceback.join('\n')}`,
        ),
      }
  }
}

/**
 * Flattens one notebook cell into a `NotebookCellSource`.
 *
 * @param cell - The raw cell from the parsed notebook.
 * @param index - The cell's position in the notebook; used as its id.
 * @param language - The notebook's kernel language, applied to every cell.
 * @returns The cell with its source joined; `outputs` is set only when the
 *   cell has at least one output.
 */
function processCell(
  cell: NotebookCell,
  index: number,
  language: string,
): NotebookCellSource {
  const cellData: NotebookCellSource = {
    cell: index,
    cellType: cell.cell_type,
    source: Array.isArray(cell.source) ? cell.source.join('') : cell.source,
    language,
    execution_count: cell.execution_count,
  }

  if (cell.outputs?.length) {
    cellData.outputs = cell.outputs.map(processOutput)
  }

  return cellData
}

/**
 * Renders a cell's source as `<cell N>…</cell N>`. A `<cell_type>` tag is
 * added for non-code cells, and a `<language>` tag for code cells whose
 * language is not Python.
 */
function cellContentToToolResult(cell: NotebookCellSource): TextBlockParam {
  const metadata: string[] = []
  if (cell.cellType !== 'code') {
    metadata.push(`<cell_type>${cell.cellType}</cell_type>`)
  }
  if (cell.language !== 'python' && cell.cellType === 'code') {
    metadata.push(`<language>${cell.language}</language>`)
  }
  const cellContent = `<cell ${cell.cell}>${metadata.join('')}${cell.source}</cell ${cell.cell}>`
  return {
    text: cellContent,
    type: 'text',
  }
}

/**
 * Converts one processed output into content blocks: a newline-prefixed text
 * block when there is text, then a base64 image block when there is an image.
 */
function cellOutputToToolResult(
  output: NotebookCellSourceOutput,
): ResultBlock[] {
  const outputs: ResultBlock[] = []
  if (output.text) {
    outputs.push({
      text: `\n${output.text}`,
      type: 'text',
    })
  }
  if (output.image) {
    outputs.push({
      type: 'image',
      source: {
        data: output.image.image_data,
        media_type: output.image.media_type,
        type: 'base64',
      },
    })
  }
  return outputs
}

/** Returns a cell's source block followed by the blocks for all of its outputs. */
function getToolResultFromCell(cell: NotebookCellSource): ResultBlock[] {
  const contentResult = cellContentToToolResult(cell)
  const outputResults = cell.outputs?.flatMap(cellOutputToToolResult)
  return [contentResult, ...(outputResults ?? [])]
}

/**
 * Type guard accepting only 'code' and 'markdown'; any other value,
 * including `null`, returns false.
 */
export function isNotebookCellType(
  value: string | null,
): value is NotebookCellType {
  return value === 'code' || value === 'markdown'
}