chain.py
from operator import itemgetter
from typing import Any

from langchain.agents import AgentExecutor, tool
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.prompts import PromptTemplate
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import RunnableLambda

from mlflow.models import ModelConfig, set_model

base_config = ModelConfig(development_config="tests/langchain/agent_executor/config.yml")

prompt_with_history = PromptTemplate(
    input_variables=["chat_history", "question"],
    template=base_config.get("prompt_with_history_str"),
)


def extract_question(input):
    # The last message in the list is the current user question.
    return input[-1]["content"]


def extract_history(input):
    # Everything before the last message is prior chat history.
    return input[:-1]


@tool
def custom_tool(query: str):
    """Mock a tool."""
    return "Databricks"


class FakeChatModel(SimpleChatModel):
    """Fake chat model wrapper for testing purposes."""

    endpoint_name: str = "fake-endpoint"

    def _call(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        # Always answer with a fixed string; no real LLM call is made.
        return "Databricks"

    @property
    def _llm_type(self) -> str:
        return "fake chat model"


fake_chat_model = FakeChatModel()
llm_with_tools = fake_chat_model.bind(tools=[custom_tool])

# LCEL pipeline: split the incoming "messages" list into question and history,
# format the prompt, call the (fake) tool-bound model, and parse any tool calls.
agent = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_question),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
    }
    | prompt_with_history
    | llm_with_tools
    | ToolsAgentOutputParser()
)

model = AgentExecutor(agent=agent, tools=[custom_tool])
set_model(model)
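

# Hypothetical local smoke test: a minimal sketch of invoking the executor,
# assuming the chat-style input schema implied by extract_question and
# extract_history (a "messages" list of {"role", "content"} dicts). Since
# FakeChatModel never emits tool calls, ToolsAgentOutputParser yields an
# AgentFinish and the executor returns in a single step.
if __name__ == "__main__":
    result = model.invoke(
        {
            "messages": [
                {"role": "user", "content": "Which platform is this?"},
            ]
        }
    )
    print(result["output"])  # prints "Databricks"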