# tests/langchain/sample_code/langchain_chat_agent.py
from operator import itemgetter
from typing import Any, Generator

from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import Runnable
from langchain_openai import ChatOpenAI

import mlflow
from mlflow.langchain.output_parsers import ChatAgentOutputParser
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import ChatAgentChunk, ChatAgentMessage, ChatAgentResponse, ChatContext


class FakeOpenAI(ChatOpenAI, extra="allow"):
    """Stub ChatOpenAI that serves canned responses instead of calling the OpenAI API."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Canned outputs for the non-streaming and streaming code paths.
        self._responses = iter([AIMessage(content="1")])
        self._stream_responses = iter([
            AIMessageChunk(content="1"),
            AIMessageChunk(content="2"),
            AIMessageChunk(content="3"),
        ])

    def _generate(self, *args, **kwargs):
        return ChatResult(generations=[ChatGeneration(message=next(self._responses))])

    def _stream(self, *args, **kwargs):
        for r in self._stream_responses:
            yield ChatGenerationChunk(message=r)


# Enable MLflow autologging (tracing) for LangChain components.
mlflow.langchain.autolog()


# Helper functions
def extract_user_query_string(messages):
    # The last message in the list is the user's current question.
    return messages[-1]["content"]


def extract_chat_history(messages):
    # Everything before the last message is the prior conversation.
    return messages[:-1]


# Define components
prompt = ChatPromptTemplate.from_template(
    """Previous conversation:
{chat_history}

User's question:
{question}"""
)

model = FakeOpenAI()
output_parser = ChatAgentOutputParser()

# Chain definition
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | output_parser
)
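
# The chain takes a {"messages": [...]} dict of chat-format message dicts;
# ChatAgentOutputParser turns the model output into a dict that unpacks into
# ChatAgentResponse below. An illustrative direct call (hypothetical input):
#   chain.invoke({"messages": [{"role": "user", "content": "What is MLflow?"}]})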


class LangChainChatAgent(ChatAgent):
    """
    Helper class to wrap a LangChain runnable as a :py:class:`ChatAgent <mlflow.pyfunc.ChatAgent>`.
    Use this class with
    :py:class:`ChatAgentOutputParser <mlflow.langchain.output_parsers.ChatAgentOutputParser>`.
    """

    def __init__(self, agent: Runnable):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> ChatAgentResponse:
        response = self.agent.invoke({"messages": self._convert_messages_to_dict(messages)})
        return ChatAgentResponse(**response)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        for event in self.agent.stream({"messages": self._convert_messages_to_dict(messages)}):
            yield ChatAgentChunk(**event)


chat_agent = LangChainChatAgent(chain)

# Register the agent instance so MLflow's models-from-code logging picks it up.
mlflow.models.set_model(chat_agent)
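

if __name__ == "__main__":
    # Hypothetical smoke test, not part of the logged model: exercises both
    # prediction paths against the canned FakeOpenAI responses defined above.
    request = [ChatAgentMessage(role="user", content="What is MLflow?")]
    print(chat_agent.predict(request))
    for chunk in chat_agent.predict_stream(request):
        print(chunk)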