/ tests / langchain / sample_code / openai_agent.py
openai_agent.py
 1  import itertools
 2  
 3  from langchain.agents import create_agent
 4  from langchain.tools import tool
 5  from langchain_core.messages import AIMessageChunk, ToolCall
 6  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 7  from langchain_openai import ChatOpenAI
 8  
 9  import mlflow
10  
11  
class FakeOpenAI(ChatOpenAI, extra="allow"):
    # Why this class exists:
    # - Regular LangChain tests mock the OpenAI REST API with a fake server that simply
    #   returns the input payload as is. Agent tests, however, need responses in a
    #   specific format so that the agent can parse them correctly.
    # - Patching with mock.patch does not work for model serving tests, because the
    #   server runs in a separate process.
    # The OpenAI client is therefore mocked directly in this model definition.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # First scripted reply: an empty message that requests a `multiply` tool call.
        tool_call_message = AIMessageChunk(
            content="",
            tool_calls=[ToolCall(name="multiply", args={"a": 2, "b": 3}, id="123")],
        )
        # Second scripted reply: the final text answer.
        final_answer_message = AIMessageChunk(content="The result of 2 * 3 is 6.")
        # itertools.cycle repeats the two replies endlessly, so next() never runs dry.
        self._responses = itertools.cycle((tool_call_message, final_answer_message))

    def _generate(self, *args, **kwargs):
        """Return the next scripted message wrapped in a ChatResult (arguments ignored)."""
        scripted_message = next(self._responses)
        generation = ChatGeneration(message=scripted_message)
        return ChatResult(generations=[generation])

    def _stream(self, *args, **kwargs):
        """Yield the next scripted message as one ChatGenerationChunk (arguments ignored)."""
        scripted_message = next(self._responses)
        yield ChatGenerationChunk(message=scripted_message)
35  
36  
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # NOTE: keep the docstring text stable; the @tool decorator uses it as the
    # tool's description.
    total = a + b
    return total
41  
42  
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # NOTE: keep the docstring text stable; the @tool decorator uses it as the
    # tool's description.
    product = a * b
    return product
47  
48  
# Wire the fake chat model and the two tools into an agent.
llm = FakeOpenAI()
agent = create_agent(
    model=llm,
    tools=[add, multiply],
    system_prompt="You are a helpful assistant",
)
# Hand the agent to MLflow as the model object defined by this file.
mlflow.models.set_model(agent)