/ chat_workflow / tools / __init__.py
__init__.py
 1  import chainlit as cl
 2  import json
 3  from typing import List, Dict, Optional
 4  from langchain_core.messages import ToolMessage
 5  from langchain_core.runnables import RunnableConfig, Runnable
 6  
 7  
 8  class BasicToolNode(Runnable):
 9      """A node that runs the tools requested in the last AIMessage."""
10  
11      def __init__(self, tools: List) -> None:
12          self.tools_by_name = {tool.__name__: tool for tool in tools}
13  
14      async def ainvoke(self, inputs: Dict, config: Optional[RunnableConfig] = None) -> Dict:
15          if messages := inputs.get("messages", []):
16              message = messages[-1]
17          else:
18              raise ValueError("No message found in input")
19          outputs = []
20          for tool_call in message.tool_calls:
21              async with cl.Step(f"tool [{tool_call['name']}]") as step:
22                  tool_result = await self.tools_by_name[tool_call["name"]](**tool_call["args"])
23                  outputs.append(
24                      ToolMessage(
25                          content=json.dumps(tool_result),
26                          name=tool_call["name"],
27                          tool_call_id=tool_call["id"],
28                      )
29                  )
30                  await step.remove()
31          return {"messages": outputs}
32  
33      def invoke(self, input: Dict, config: Optional[RunnableConfig] = None) -> Dict:
34          raise NotImplementedError(
35              "BasicToolNode only supports async invocation")