# openai_agent.py
import itertools

from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import AIMessageChunk, ToolCall
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI

import mlflow


class FakeOpenAI(ChatOpenAI, extra="allow"):
    """A ChatOpenAI stand-in that replays a fixed tool-call conversation.

    The usual fake-OpenAI test server simply echoes the request payload, which
    is not parseable by an agent, and ``mock.patch`` cannot reach a model
    serving process running separately. Mocking the client inside the model
    definition sidesteps both problems.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        canned = [
            # Turn 1: ask the agent to invoke the `multiply` tool.
            AIMessageChunk(
                content="",
                tool_calls=[ToolCall(name="multiply", args={"a": 2, "b": 3}, id="123")],
            ),
            # Turn 2: final answer once the tool result comes back.
            AIMessageChunk(content="The result of 2 * 3 is 6."),
        ]
        # Cycle forever so the model can be queried any number of times.
        self._responses = itertools.cycle(canned)

    def _generate(self, *args, **kwargs):
        # Non-streaming path: wrap the next canned message in a ChatResult.
        message = next(self._responses)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(self, *args, **kwargs):
        # Streaming path: emit the next canned message as a single chunk.
        message = next(self._responses)
        yield ChatGenerationChunk(message=message)


@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


# Wire the fake LLM into an agent and register it as the model under test.
llm = FakeOpenAI()
agent = create_agent(llm, [add, multiply], system_prompt="You are a helpful assistant")
mlflow.models.set_model(agent)