/ mlflow / langchain / retriever_chain.py
retriever_chain.py
  1  """Chain for wrapping a retriever."""
  2  
  3  from __future__ import annotations
  4  
  5  import json
  6  from pathlib import Path
  7  from typing import Any
  8  
  9  import yaml
 10  from pydantic import ConfigDict, Field
 11  
 12  from mlflow.langchain._compat import (
 13      import_async_callback_manager_for_chain_run,
 14      import_base_retriever,
 15      import_callback_manager_for_chain_run,
 16      import_document,
 17      try_import_chain,
 18  )
 19  
 20  AsyncCallbackManagerForChainRun = import_async_callback_manager_for_chain_run()
 21  CallbackManagerForChainRun = import_callback_manager_for_chain_run()
 22  BaseRetriever = import_base_retriever()
 23  Document = import_document()
 24  Chain = try_import_chain()
 25  
 26  if Chain is None:
 27      raise ImportError(
 28          "Chain class not found. MLflow's retriever_chain functionality requires langchain<1.0.0. "
 29          "For langchain 1.0.0+, please use LangGraph instead."
 30      )
 31  
 32  
 33  class _RetrieverChain(Chain):
 34      """
 35      Chain that wraps a retriever for use with MLflow.
 36  
 37      The MLflow ``langchain`` flavor provides the functionality to log a retriever object and
 38      evaluate it individually. This is useful if you want to evaluate the quality of the
 39      relevant documents returned by a retriever object without directing these documents
 40      through a large language model (LLM) to yield a summarized response.
 41  
 42      In order to log the retriever object in the ``langchain`` flavor, the retriever object
 43      needs to be wrapped within a ``_RetrieverChain``.
 44  
 45      See ``examples/langchain/retriever_chain.py`` for how to log the ``_RetrieverChain``.
 46  
 47      Args:
 48          retriever: The retriever to wrap.
 49      """
 50  
 51      input_key: str = "query"
 52      output_key: str = "source_documents"
 53      retriever: BaseRetriever = Field(exclude=True)
 54  
 55      model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
 56  
 57      @property
 58      def input_keys(self) -> list[str]:
 59          """Return the input keys."""
 60          return [self.input_key]
 61  
 62      @property
 63      def output_keys(self) -> list[str]:
 64          """Return the output keys."""
 65          return [self.output_key]
 66  
 67      def _get_docs(self, question: str) -> list[Document]:
 68          """Get documents from the retriever."""
 69          return self.retriever.get_relevant_documents(question)
 70  
 71      def _call(
 72          self,
 73          inputs: dict[str, Any],
 74          run_manager: CallbackManagerForChainRun | None = None,
 75      ) -> dict[str, Any]:
 76          """Run _get_docs on input query.
 77          Returns the retrieved documents under the key 'source_documents'.
 78  
 79          Example:
 80  
 81          .. code-block:: python
 82  
 83              chain = _RetrieverChain(retriever=...)
 84              res = chain({"query": "This is my query"})
 85              docs = res["source_documents"]
 86          """
 87          question = inputs[self.input_key]
 88          docs = self._get_docs(question)
 89          list_of_str_page_content = [doc.page_content for doc in docs]
 90          return {self.output_key: json.dumps(list_of_str_page_content)}
 91  
 92      async def _aget_docs(self, question: str) -> list[Document]:
 93          """Get documents from the retriever."""
 94          return await self.retriever.aget_relevant_documents(question)
 95  
 96      async def _acall(
 97          self,
 98          inputs: dict[str, Any],
 99          run_manager: AsyncCallbackManagerForChainRun | None = None,
100      ) -> dict[str, Any]:
101          """Run _get_docs on input query.
102          Returns the retrieved documents under the key 'source_documents'.
103  
104          Example:
105  
106          .. code-block:: python
107  
108              chain = _RetrieverChain(retriever=...)
109              res = chain({"query": "This is my query"})
110              docs = res["source_documents"]
111          """
112          question = inputs[self.input_key]
113          docs = await self._aget_docs(question)
114          list_of_str_page_content = [doc.page_content for doc in docs]
115          return {self.output_key: json.dumps(list_of_str_page_content)}
116  
117      @property
118      def _chain_type(self) -> str:
119          """Return the chain type."""
120          return "retriever_chain"
121  
122      @classmethod
123      def load(cls, file: str | Path, **kwargs: Any) -> _RetrieverChain:
124          """Load a _RetrieverChain from a file."""
125          # Convert file to Path object.
126          file_path = Path(file) if isinstance(file, str) else file
127          # Load from either json or yaml.
128          if file_path.suffix == ".json":
129              with open(file_path) as f:
130                  config = json.load(f)
131          elif file_path.suffix in (".yaml", ".yml"):
132              with open(file_path) as f:
133                  # This is to ignore certain tags that are not supported
134                  # with pydantic >= 2.0
135                  yaml.add_multi_constructor(
136                      "tag:yaml.org,2002:python/object",
137                      lambda loader, suffix, node: None,
138                      Loader=yaml.SafeLoader,
139                  )
140                  config = yaml.load(f, yaml.SafeLoader)
141          else:
142              raise ValueError("File type must be json or yaml")
143  
144          # Override default 'verbose' and 'memory' for the chain
145          if verbose := kwargs.pop("verbose", None):
146              config["verbose"] = verbose
147          if memory := kwargs.pop("memory", None):
148              config["memory"] = memory
149  
150          if "_type" not in config:
151              raise ValueError("Must specify a chain Type in config")
152          config_type = config.pop("_type")
153  
154          if config_type != "retriever_chain":
155              raise ValueError(f"Loading {config_type} chain not supported")
156  
157          retriever = kwargs.pop("retriever", None)
158          if retriever is None:
159              raise ValueError("`retriever` must be present.")
160  
161          config.pop("retriever", None)
162  
163          return cls(
164              retriever=retriever,
165              **config,
166          )