# tests/langchain/sample_code/workflow.py
  1  import json
  2  import os
  3  from typing import Any, Sequence
  4  
  5  from langchain_core.language_models import LanguageModelLike
  6  from langchain_core.messages import AIMessage, ToolCall
  7  from langchain_core.outputs import ChatGeneration, ChatResult
  8  from langchain_core.runnables import RunnableConfig, RunnableLambda
  9  from langchain_core.tools import BaseTool, tool
 10  from langchain_openai import ChatOpenAI
 11  from langgraph.graph import END, StateGraph
 12  from langgraph.graph.state import CompiledStateGraph
 13  from langgraph.prebuilt import ToolNode
 14  
 15  import mlflow
 16  from mlflow.langchain.chat_agent_langgraph import (
 17      ChatAgentState,
 18      ChatAgentToolNode,
 19  )
 20  
# Placeholder API key; FakeOpenAI below overrides _generate, so no real
# OpenAI requests are ever made.
os.environ["OPENAI_API_KEY"] = "test"
 22  
 23  
 24  class FakeOpenAI(ChatOpenAI, extra="allow"):
 25      def __init__(self, *args, **kwargs):
 26          super().__init__(*args, **kwargs)
 27  
 28          self._responses = iter([
 29              AIMessage(
 30                  content="",
 31                  tool_calls=[ToolCall(name="uc_tool_format", args={}, id="123")],
 32              ),
 33              AIMessage(
 34                  content="",
 35                  tool_calls=[ToolCall(name="lc_tool_format", args={}, id="456")],
 36              ),
 37              AIMessage(content="Successfully generated", id="789"),
 38          ])
 39  
 40      def _generate(self, *args, **kwargs):
 41          return ChatResult(generations=[ChatGeneration(message=next(self._responses))])
 42  
 43  
 44  @tool
 45  def uc_tool_format() -> str:
 46      """Returns uc tool format"""
 47      return json.dumps({
 48          "format": "SCALAR",
 49          "value": '{"content":"hi","attachments":{"a":"b"},"custom_outputs":{"c":"d"}}',
 50          "truncated": False,
 51      })
 52  
 53  
 54  @tool
 55  def lc_tool_format() -> dict[str, Any]:
 56      """Returns lc tool format"""
 57      nums = [1, 2]
 58      return {
 59          "content": f"Successfully generated array of 2 random ints: {nums}.",
 60          "attachments": {"key1": "attach1", "key2": "attach2"},
 61          "custom_outputs": {"random_nums": nums},
 62      }
 63  
 64  
# The two fake tools that get bound to the agent's model below.
tools = [uc_tool_format, lc_tool_format]
 66  
 67  
 68  def create_tool_calling_agent(
 69      model: LanguageModelLike,
 70      tools: ToolNode | Sequence[BaseTool],
 71      agent_prompt: str | None = None,
 72  ) -> CompiledStateGraph:
 73      model = model.bind_tools(tools)
 74  
 75      def should_continue(state: ChatAgentState):
 76          messages = state["messages"]
 77          last_message = messages[-1]
 78          # If there are function calls, continue. else, end
 79          if last_message.get("tool_calls"):
 80              return "continue"
 81          else:
 82              return "end"
 83  
 84      preprocessor = RunnableLambda(lambda state: state["messages"])
 85      model_runnable = preprocessor | model
 86  
 87      @mlflow.trace
 88      def call_model(
 89          state: ChatAgentState,
 90          config: RunnableConfig,
 91      ):
 92          response = model_runnable.invoke(state, config)
 93  
 94          return {"messages": [response]}
 95  
 96      workflow = StateGraph(ChatAgentState)
 97  
 98      workflow.add_node("agent", RunnableLambda(call_model))
 99      workflow.add_node("tools", ChatAgentToolNode(tools))
100  
101      workflow.set_entry_point("agent")
102      workflow.add_conditional_edges(
103          "agent",
104          should_continue,
105          {
106              "continue": "tools",
107              "end": END,
108          },
109      )
110      workflow.add_edge("tools", "agent")
111  
112      return workflow.compile()
113  
114  
# Enable automatic MLflow tracing for LangChain/LangGraph invocations.
mlflow.langchain.autolog()
llm = FakeOpenAI()
graph = create_tool_calling_agent(llm, tools)

# Register the compiled graph as the model this script serves.
mlflow.models.set_model(graph)