chain.py
 1  from operator import itemgetter
 2  from typing import Any
 3  
 4  from langchain.agents import AgentExecutor, tool
 5  from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
 6  from langchain.callbacks.manager import CallbackManagerForLLMRun
 7  from langchain.chat_models.base import SimpleChatModel
 8  from langchain.prompts import PromptTemplate
 9  from langchain.schema.messages import BaseMessage
10  from langchain.schema.runnable import RunnableLambda
11  
12  from mlflow.models import ModelConfig, set_model
13  
# Development-time model configuration. The path is relative, so it presumably
# resolves against the current working directory (the repository root when the
# test suite runs) — NOTE(review): confirm.
base_config = ModelConfig(
    development_config="tests/langchain/agent_executor/config.yml",
)

# Prompt that receives the earlier conversation and the newest question. The
# template text is read from the ``prompt_with_history_str`` config entry.
prompt_with_history = PromptTemplate(
    template=base_config.get("prompt_with_history_str"),
    input_variables=["chat_history", "question"],
)
20  
21  
def extract_question(input):
    """Return the text of the newest message in a chat transcript.

    ``input`` is expected to be an indexable sequence of message mappings, each
    carrying a ``"content"`` key; the ``"content"`` of the final element is
    returned. For an empty list, indexing raises ``IndexError``.
    """
    latest_message = input[-1]
    return latest_message["content"]
24  
25  
def extract_history(input):
    """Return every message except the newest one.

    Counterpart to ``extract_question``: the final element is treated as the
    current question, and all earlier elements form the conversation history.
    The result has the same sequence type as ``input``. Zero or one messages
    yield an empty result.
    """
    all_but_last = slice(None, -1)
    return input[all_but_last]
28  
29  
# Mock tool: it is bound to the fake chat model and handed to the AgentExecutor
# later in this file. The docstring below is deliberately left untouched —
# ``@tool`` presumably uses it (and the signature) as the tool's description
# and argument schema, so editing it would change runtime behavior.
@tool
def custom_tool(query: str):
    """
    Mock a tool
    """
    # ``query`` is ignored; the tool always answers with the same fixed string.
    return "Databricks"
36  
37  
class FakeChatModel(SimpleChatModel):
    """Deterministic stand-in chat model for the agent-executor test fixture.

    Every invocation produces the same canned reply, regardless of the prompt
    messages, stop sequences, callbacks, or any extra keyword arguments.
    """

    # Declared model field with a placeholder default; the name suggests it
    # mimics an endpoint-backed model (unverified — nothing here reads it).
    endpoint_name: str = "fake-endpoint"

    @property
    def _llm_type(self) -> str:
        """Short label identifying this model type."""
        return "fake chat model"

    def _call(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        """Return the fixed reply ``"Databricks"``.

        All arguments, including extra keyword arguments, are accepted and
        ignored, so the output never depends on the input.
        """
        canned_reply = "Databricks"
        return canned_reply
55  
56  
# Wire the fake model into a tool-calling agent and expose it to MLflow.
fake_chat_model = FakeChatModel()
# Bind the mock tool to the model. The bound ``tools`` kwarg presumably reaches
# FakeChatModel._call via ``**kwargs``, which ignores it, so the reply is always
# the fixed string "Databricks".
llm_with_tools = fake_chat_model.bind(tools=[custom_tool])
agent = (
    # Split the incoming ``messages`` list into the two prompt variables: the
    # last entry's ``content`` becomes ``question``, and every earlier entry
    # becomes ``chat_history``.
    {
        "question": itemgetter("messages") | RunnableLambda(extract_question),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
    }
    | prompt_with_history
    | llm_with_tools
    # Presumably parses the chat reply into tool actions or a final answer for
    # the AgentExecutor.
    | ToolsAgentOutputParser()
)

# NOTE(review): the prompt only takes ``chat_history`` and ``question`` (no agent
# scratchpad). If the model ever requested a tool, the intermediate steps would
# not be fed back into the prompt. This is presumably acceptable because the
# fake model never emits tool calls — confirm if the model is swapped out.
model = AgentExecutor(agent=agent, tools=[custom_tool])
# Designate ``model`` as this file's model for MLflow (mlflow.models.set_model).
set_model(model)