/ src / tools / NotebookReadTool / NotebookReadTool.tsx
NotebookReadTool.tsx
  1  import type {
  2    ImageBlockParam,
  3    TextBlockParam,
  4  } from '@anthropic-ai/sdk/resources/index.mjs'
  5  
  6  import { existsSync, readFileSync } from 'fs'
  7  import { Text } from 'ink'
  8  import { extname, isAbsolute, relative, resolve } from 'path'
  9  import * as React from 'react'
 10  import { z } from 'zod'
 11  import { FallbackToolUseRejectedMessage } from '../../components/FallbackToolUseRejectedMessage.js'
 12  import { Tool } from '../../Tool.js'
 13  import {
 14    NotebookCellSource,
 15    NotebookContent,
 16    NotebookCell,
 17    NotebookOutputImage,
 18    NotebookCellSourceOutput,
 19    NotebookCellOutput,
 20    NotebookCellType,
 21  } from '../../types/notebook.js'
 22  import { formatOutput } from '../BashTool/utils.js'
 23  import { getCwd } from '../../utils/state.js'
 24  import { findSimilarFile } from '../../utils/file.js'
 25  import { DESCRIPTION, PROMPT } from './prompt.js'
 26  import { hasReadPermission } from '../../utils/permissions/filesystem.js'
 27  
 28  const inputSchema = z.strictObject({
 29    notebook_path: z
 30      .string()
 31      .describe(
 32        'The absolute path to the Jupyter notebook file to read (must be absolute, not relative)',
 33      ),
 34  })
 35  
 36  type In = typeof inputSchema
 37  type Out = NotebookCellSource[]
 38  
 39  function renderResultForAssistant(data: NotebookCellSource[]) {
 40    const allResults = data.flatMap(getToolResultFromCell)
 41  
 42    // Merge adjacent text blocks
 43    return allResults.reduce<(TextBlockParam | ImageBlockParam)[]>(
 44      (acc, curr) => {
 45        if (acc.length === 0) return [curr]
 46  
 47        const prev = acc[acc.length - 1]
 48        if (prev && prev.type === 'text' && curr.type === 'text') {
 49          // Merge the text blocks
 50          prev.text += '\n' + curr.text
 51          return acc
 52        }
 53  
 54        return [...acc, curr]
 55      },
 56      [],
 57    )
 58  }
 59  
 60  export const NotebookReadTool = {
 61    name: 'ReadNotebook',
 62    async description() {
 63      return DESCRIPTION
 64    },
 65    async prompt() {
 66      return PROMPT
 67    },
 68    isReadOnly() {
 69      return true
 70    },
 71    inputSchema,
 72    userFacingName() {
 73      return 'Read Notebook'
 74    },
 75    async isEnabled() {
 76      return true
 77    },
 78    needsPermissions({ notebook_path }) {
 79      return !hasReadPermission(notebook_path)
 80    },
 81    async validateInput({ notebook_path }) {
 82      const fullFilePath = isAbsolute(notebook_path)
 83        ? notebook_path
 84        : resolve(getCwd(), notebook_path)
 85  
 86      if (!existsSync(fullFilePath)) {
 87        // Try to find a similar file with a different extension
 88        const similarFilename = findSimilarFile(fullFilePath)
 89        let message = 'File does not exist.'
 90  
 91        // If we found a similar file, suggest it to the assistant
 92        if (similarFilename) {
 93          message += ` Did you mean ${similarFilename}?`
 94        }
 95  
 96        return {
 97          result: false,
 98          message,
 99        }
100      }
101  
102      if (extname(fullFilePath) !== '.ipynb') {
103        return {
104          result: false,
105          message: 'File must be a Jupyter notebook (.ipynb file).',
106        }
107      }
108  
109      return { result: true }
110    },
111    renderToolUseMessage(input, { verbose }) {
112      return `notebook_path: ${verbose ? input.notebook_path : relative(getCwd(), input.notebook_path)}`
113    },
114    renderToolUseRejectedMessage() {
115      return <FallbackToolUseRejectedMessage />
116    },
117  
118    renderToolResultMessage(content) {
119      if (!content) {
120        return <Text>No cells found in notebook</Text>
121      }
122      if (content.length < 1 || !content[0]) {
123        return <Text>No cells found in notebook</Text>
124      }
125      return <Text>Read {content.length} cells</Text>
126    },
127    async *call({ notebook_path }) {
128      const fullPath = isAbsolute(notebook_path)
129        ? notebook_path
130        : resolve(getCwd(), notebook_path)
131  
132      const content = readFileSync(fullPath, 'utf-8')
133      const notebook = JSON.parse(content) as NotebookContent
134      const language = notebook.metadata.language_info?.name ?? 'python'
135      const cells = notebook.cells.map((cell, index) =>
136        processCell(cell, index, language),
137      )
138  
139      yield {
140        type: 'result',
141        resultForAssistant: renderResultForAssistant(cells),
142        data: cells,
143      }
144    },
145    renderResultForAssistant,
146  } satisfies Tool<In, Out>
147  
148  function processOutputText(text: string | string[] | undefined): string {
149    if (!text) return ''
150    const rawText = Array.isArray(text) ? text.join('') : text
151    const { truncatedContent } = formatOutput(rawText)
152    return truncatedContent
153  }
154  
155  function extractImage(
156    data: Record<string, unknown>,
157  ): NotebookOutputImage | undefined {
158    if (typeof data['image/png'] === 'string') {
159      return {
160        image_data: data['image/png'] as string,
161        media_type: 'image/png',
162      }
163    }
164    if (typeof data['image/jpeg'] === 'string') {
165      return {
166        image_data: data['image/jpeg'] as string,
167        media_type: 'image/jpeg',
168      }
169    }
170    return undefined
171  }
172  
173  function processOutput(output: NotebookCellOutput) {
174    switch (output.output_type) {
175      case 'stream':
176        return {
177          output_type: output.output_type,
178          text: processOutputText(output.text),
179        }
180      case 'execute_result':
181      case 'display_data':
182        return {
183          output_type: output.output_type,
184          text: processOutputText(output.data?.['text/plain']),
185          image: output.data && extractImage(output.data),
186        }
187      case 'error':
188        return {
189          output_type: output.output_type,
190          text: processOutputText(
191            `${output.ename}: ${output.evalue}\n${output.traceback.join('\n')}`,
192          ),
193        }
194    }
195  }
196  
197  function processCell(
198    cell: NotebookCell,
199    index: number,
200    language: string,
201  ): NotebookCellSource {
202    const cellData: NotebookCellSource = {
203      cell: index,
204      cellType: cell.cell_type,
205      source: Array.isArray(cell.source) ? cell.source.join('') : cell.source,
206      language,
207      execution_count: cell.execution_count,
208    }
209  
210    if (cell.outputs?.length) {
211      cellData.outputs = cell.outputs.map(processOutput)
212    }
213  
214    return cellData
215  }
216  
217  function cellContentToToolResult(cell: NotebookCellSource): TextBlockParam {
218    const metadata = []
219    if (cell.cellType !== 'code') {
220      metadata.push(`<cell_type>${cell.cellType}</cell_type>`)
221    }
222    if (cell.language !== 'python' && cell.cellType === 'code') {
223      metadata.push(`<language>${cell.language}</language>`)
224    }
225    const cellContent = `<cell ${cell.cell}>${metadata.join('')}${cell.source}</cell ${cell.cell}>`
226    return {
227      text: cellContent,
228      type: 'text',
229    }
230  }
231  
232  function cellOutputToToolResult(output: NotebookCellSourceOutput) {
233    const outputs: (TextBlockParam | ImageBlockParam)[] = []
234    if (output.text) {
235      outputs.push({
236        text: `\n${output.text}`,
237        type: 'text',
238      })
239    }
240    if (output.image) {
241      outputs.push({
242        type: 'image',
243        source: {
244          data: output.image.image_data,
245          media_type: output.image.media_type,
246          type: 'base64',
247        },
248      })
249    }
250    return outputs
251  }
252  
253  function getToolResultFromCell(cell: NotebookCellSource) {
254    const contentResult = cellContentToToolResult(cell)
255    const outputResults = cell.outputs?.flatMap(cellOutputToToolResult)
256    return [contentResult, ...(outputResults ?? [])]
257  }
258  
259  export function isNotebookCellType(
260    value: string | null,
261  ): value is NotebookCellType {
262    return value === 'code' || value === 'markdown'
263  }