# workflow.py
import json
import os
from typing import Any, Sequence

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

import mlflow
from mlflow.langchain.chat_agent_langgraph import (
    ChatAgentState,
    ChatAgentToolNode,
)

# Placeholder credential, presumably so ChatOpenAI can be constructed without
# a real key (FakeOpenAI below replaces _generate with a fixed script).
# setdefault leaves an already-configured key untouched instead of clobbering it.
os.environ.setdefault("OPENAI_API_KEY", "test")


class FakeOpenAI(ChatOpenAI, extra="allow"):
    """ChatOpenAI stand-in that replays a fixed script of responses.

    Successive ``_generate`` calls return, in order:

    1. a tool call to ``uc_tool_format`` (tool-call id ``"123"``),
    2. a tool call to ``lc_tool_format`` (tool-call id ``"456"``),
    3. the plain-text answer ``"Successfully generated"`` (message id ``"789"``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # One-shot iterator: each model call consumes the next scripted reply.
        self._responses = iter([
            AIMessage(
                content="",
                tool_calls=[ToolCall(name="uc_tool_format", args={}, id="123")],
            ),
            AIMessage(
                content="",
                tool_calls=[ToolCall(name="lc_tool_format", args={}, id="456")],
            ),
            AIMessage(content="Successfully generated", id="789"),
        ])

    def _generate(self, *args, **kwargs):
        """Return the next scripted message instead of calling the API.

        Raises:
            RuntimeError: if called more times than there are scripted responses.
        """
        try:
            message = next(self._responses)
        except StopIteration:
            # A bare StopIteration escaping here turns into an opaque
            # "generator raised StopIteration" if it unwinds through a
            # generator frame (PEP 479); name the real problem instead.
            raise RuntimeError("FakeOpenAI has no more scripted responses") from None
        return ChatResult(generations=[ChatGeneration(message=message)])


# Presumably mimics a Unity Catalog ("uc") function result: a JSON envelope
# whose "value" field is itself a JSON-encoded string.
# NOTE: the docstring below is the tool description sent to the model.
@tool
def uc_tool_format() -> str:
    """Returns uc tool format"""
    return json.dumps({
        "format": "SCALAR",
        "value": '{"content":"hi","attachments":{"a":"b"},"custom_outputs":{"c":"d"}}',
        "truncated": False,
    })


# Returns the same content/attachments/custom_outputs shape as a plain dict
# rather than a JSON envelope.
# NOTE: the docstring below is the tool description sent to the model.
@tool
def lc_tool_format() -> dict[str, Any]:
    """Returns lc tool format"""
    nums = [1, 2]
    return {
        "content": f"Successfully generated array of 2 random ints: {nums}.",
        "attachments": {"key1": "attach1", "key2": "attach2"},
        "custom_outputs": {"random_nums": nums},
    }


tools = [uc_tool_format, lc_tool_format]


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: ToolNode | Sequence[BaseTool],
    agent_prompt: str | None = None,
) -> CompiledStateGraph:
    """Build a LangGraph agent that alternates between the model and its tools.

    The graph starts at the ``agent`` node. While the model's latest message
    carries tool calls, control moves to the ``tools`` node and back to
    ``agent``; otherwise the graph ends.

    Args:
        model: Chat model to drive the agent; must support ``bind_tools``.
        tools: Tools to expose, either as a sequence of tools or as an
            existing ``ToolNode`` whose tools are reused.
        agent_prompt: Optional system prompt, prepended to the conversation
            on every model call. ``None`` or empty means no system message.

    Returns:
        The compiled state graph.
    """
    if isinstance(tools, ToolNode):
        # bind_tools and ChatAgentToolNode need the tool objects themselves,
        # not the node that wraps them.
        tools = list(tools.tools_by_name.values())

    model = model.bind_tools(tools)

    def should_continue(state: ChatAgentState):
        # Keep looping through the tools node while the model requests tool
        # calls; stop as soon as it answers without any.
        last_message = state["messages"][-1]
        return "continue" if last_message.get("tool_calls") else "end"

    def to_model_input(state: ChatAgentState):
        # State messages are dict-like (see the .get() above), so the system
        # prompt is injected in the same role/content dict form.
        messages = state["messages"]
        if agent_prompt:
            return [{"role": "system", "content": agent_prompt}, *messages]
        return messages

    model_runnable = RunnableLambda(to_model_input) | model

    @mlflow.trace
    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


# Module-level wiring: MLflow's models-from-code loader picks up the object
# passed to set_model when this file is executed.
mlflow.langchain.autolog()
llm = FakeOpenAI()
graph = create_tool_calling_agent(llm, tools)

mlflow.models.set_model(graph)