# langchain_chat_agent.py
"""Sample MLflow model that wraps a LangChain chain in a ``ChatAgent``.

The chain formats the incoming chat messages into a prompt, sends it to a
fake (canned-response) OpenAI chat model, and parses the result with
``ChatAgentOutputParser``. The wrapped agent is registered with
``mlflow.models.set_model`` at import time.
"""

from operator import itemgetter
from typing import Any, Generator

from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import Runnable
from langchain_openai import ChatOpenAI

import mlflow
from mlflow.langchain.output_parsers import ChatAgentOutputParser
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import ChatAgentChunk, ChatAgentMessage, ChatAgentResponse, ChatContext


class FakeOpenAI(ChatOpenAI, extra="allow"):
    """``ChatOpenAI`` stand-in that returns canned messages instead of calling the API.

    Both canned sequences are single-use iterators created in ``__init__``:
    a second ``_generate`` call raises ``StopIteration`` (from ``next``), and a
    second ``_stream`` call yields nothing because the iterator is exhausted.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Canned non-streaming reply consumed by ``_generate``.
        self._responses = iter([AIMessage(content="1")])
        # Canned streaming chunks consumed by ``_stream``.
        self._stream_responses = iter([
            AIMessageChunk(content="1"),
            AIMessageChunk(content="2"),
            AIMessageChunk(content="3"),
        ])

    def _generate(self, *args, **kwargs):
        """Return the next canned message as a ``ChatResult``; all inputs are ignored."""
        return ChatResult(generations=[ChatGeneration(message=next(self._responses))])

    def _stream(self, *args, **kwargs):
        """Yield the remaining canned chunks as ``ChatGenerationChunk``s; all inputs are ignored."""
        for r in self._stream_responses:
            yield ChatGenerationChunk(message=r)


mlflow.langchain.autolog()


# Helper functions
def extract_user_query_string(messages):
    """Return the ``"content"`` of the last message (the current user question)."""
    return messages[-1]["content"]


def extract_chat_history(messages):
    """Return every message except the last one (the prior conversation)."""
    return messages[:-1]


# Define components
prompt = ChatPromptTemplate.from_template(
    """Previous conversation:
{chat_history}

User's question:
{question}"""
)

model = FakeOpenAI()
output_parser = ChatAgentOutputParser()

# Chain definition: split the input's "messages" list into the prompt's
# ``question`` / ``chat_history`` variables, then prompt -> model -> parser.
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | output_parser
)


class LangChainChatAgent(ChatAgent):
    """
    Helper class to wrap a LangChain runnable as a :py:class:`ChatAgent <mlflow.pyfunc.ChatAgent>`.
    Use this class with
    :py:class:`ChatAgentOutputParser <mlflow.langchain.output_parsers.ChatAgentOutputParser>`.
    """

    def __init__(self, agent: Runnable):
        # The runnable is invoked with {"messages": [...]} by both predict paths.
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> ChatAgentResponse:
        """Run the wrapped runnable once and build a ``ChatAgentResponse`` from its output.

        ``context`` and ``custom_inputs`` are accepted for interface compatibility
        but are not forwarded to the runnable. The runnable's output is unpacked
        with ``**`` into ``ChatAgentResponse``, so it must be a mapping of that
        model's fields.
        """
        # ``_convert_messages_to_dict`` is inherited from ChatAgent (not defined here).
        response = self.agent.invoke({"messages": self._convert_messages_to_dict(messages)})
        return ChatAgentResponse(**response)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Stream the wrapped runnable, yielding one ``ChatAgentChunk`` per emitted event.

        ``context`` and ``custom_inputs`` are accepted but unused. Each streamed
        event is unpacked with ``**`` into ``ChatAgentChunk``.
        """
        for event in self.agent.stream({"messages": self._convert_messages_to_dict(messages)}):
            yield ChatAgentChunk(**event)


chat_agent = LangChainChatAgent(chain)

mlflow.models.set_model(chat_agent)