retriever_chain.py
"""Chain for wrapping a retriever."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import yaml
from pydantic import ConfigDict, Field

from mlflow.langchain._compat import (
    import_async_callback_manager_for_chain_run,
    import_base_retriever,
    import_callback_manager_for_chain_run,
    import_document,
    try_import_chain,
)

AsyncCallbackManagerForChainRun = import_async_callback_manager_for_chain_run()
CallbackManagerForChainRun = import_callback_manager_for_chain_run()
BaseRetriever = import_base_retriever()
Document = import_document()
Chain = try_import_chain()

if Chain is None:
    raise ImportError(
        "Chain class not found. MLflow's retriever_chain functionality requires langchain<1.0.0. "
        "For langchain 1.0.0+, please use LangGraph instead."
    )


class _RetrieverChain(Chain):
    """
    Chain that wraps a retriever for use with MLflow.

    The MLflow ``langchain`` flavor provides the functionality to log a retriever object and
    evaluate it individually. This is useful if you want to evaluate the quality of the
    relevant documents returned by a retriever object without directing these documents
    through a large language model (LLM) to yield a summarized response.

    In order to log the retriever object in the ``langchain`` flavor, the retriever object
    needs to be wrapped within a ``_RetrieverChain``.

    See ``examples/langchain/retriever_chain.py`` for how to log the ``_RetrieverChain``.

    Args:
        retriever: The retriever to wrap.
    """

    # Key under which the query string is looked up in the chain inputs.
    input_key: str = "query"
    # Key under which the JSON-encoded page contents are returned.
    output_key: str = "source_documents"
    # Declared with ``exclude=True``; ``load`` requires the retriever to be
    # passed in as a keyword argument rather than reading it from the config.
    retriever: BaseRetriever = Field(exclude=True)

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @property
    def input_keys(self) -> list[str]:
        """Return the input keys."""
        return [self.input_key]

    @property
    def output_keys(self) -> list[str]:
        """Return the output keys."""
        return [self.output_key]

    def _build_output(self, docs: list[Document]) -> dict[str, Any]:
        """Serialize the ``page_content`` of ``docs`` into the chain output dict.

        Shared by ``_call`` and ``_acall`` so both paths produce the same shape:
        a single entry mapping ``output_key`` to a JSON-encoded list of strings.
        """
        page_contents = [doc.page_content for doc in docs]
        return {self.output_key: json.dumps(page_contents)}

    def _get_docs(self, question: str) -> list[Document]:
        """Get documents from the retriever."""
        return self.retriever.get_relevant_documents(question)

    def _call(
        self,
        inputs: dict[str, Any],
        run_manager: CallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Run _get_docs on input query.

        Returns a JSON-encoded list with the ``page_content`` of every retrieved
        document under the key 'source_documents'.

        Args:
            inputs: Chain inputs; the query is read from ``inputs[self.input_key]``.
            run_manager: Accepted for interface compatibility; not used here.

        Example:

        .. code-block:: python

            chain = _RetrieverChain(retriever=...)
            res = chain({"query": "This is my query"})
            page_contents = json.loads(res["source_documents"])
        """
        question = inputs[self.input_key]
        return self._build_output(self._get_docs(question))

    async def _aget_docs(self, question: str) -> list[Document]:
        """Asynchronously get documents from the retriever."""
        return await self.retriever.aget_relevant_documents(question)

    async def _acall(
        self,
        inputs: dict[str, Any],
        run_manager: AsyncCallbackManagerForChainRun | None = None,
    ) -> dict[str, Any]:
        """Run _aget_docs on input query.

        Returns a JSON-encoded list with the ``page_content`` of every retrieved
        document under the key 'source_documents'.

        Args:
            inputs: Chain inputs; the query is read from ``inputs[self.input_key]``.
            run_manager: Accepted for interface compatibility; not used here.

        Example:

        .. code-block:: python

            chain = _RetrieverChain(retriever=...)
            res = await chain.ainvoke({"query": "This is my query"})
            page_contents = json.loads(res["source_documents"])
        """
        question = inputs[self.input_key]
        return self._build_output(await self._aget_docs(question))

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retriever_chain"

    @classmethod
    def load(cls, file: str | Path, **kwargs: Any) -> _RetrieverChain:
        """Load a _RetrieverChain from a file.

        Args:
            file: Path to a ``.json``, ``.yaml`` or ``.yml`` chain config.
            kwargs: Must contain ``retriever``. May contain ``verbose`` and
                ``memory``, which override the values stored in the config when
                they are not ``None``. Other keyword arguments are ignored.

        Returns:
            A ``_RetrieverChain`` built from the config and the given retriever.

        Raises:
            ValueError: If the file extension is unsupported, the file does not
                contain a mapping, ``_type`` is missing or is not
                ``"retriever_chain"``, or ``retriever`` is not provided.
        """
        # Convert file to Path object.
        file_path = Path(file) if isinstance(file, str) else file
        # Load from either json or yaml.
        if file_path.suffix == ".json":
            with open(file_path, encoding="utf-8") as f:
                config = json.load(f)
        elif file_path.suffix in (".yaml", ".yml"):
            # This is to ignore certain tags that are not supported
            # with pydantic >= 2.0
            # NOTE(review): this registers the constructor on the shared
            # ``yaml.SafeLoader`` class, so it stays in effect for the whole
            # process after the first call.
            yaml.add_multi_constructor(
                "tag:yaml.org,2002:python/object",
                lambda loader, suffix, node: None,
                Loader=yaml.SafeLoader,
            )
            with open(file_path, encoding="utf-8") as f:
                config = yaml.load(f, yaml.SafeLoader)
        else:
            raise ValueError("File type must be json or yaml")

        # An empty file or a top-level list/scalar would otherwise fail below
        # with an unrelated TypeError.
        if not isinstance(config, dict):
            raise ValueError("Chain config must be a mapping")

        # Override default 'verbose' and 'memory' for the chain. Compare against
        # None (not truthiness) so an explicit ``verbose=False`` is honored.
        verbose = kwargs.pop("verbose", None)
        if verbose is not None:
            config["verbose"] = verbose
        memory = kwargs.pop("memory", None)
        if memory is not None:
            config["memory"] = memory

        if "_type" not in config:
            raise ValueError("Must specify a chain Type in config")
        config_type = config.pop("_type")

        if config_type != "retriever_chain":
            raise ValueError(f"Loading {config_type} chain not supported")

        retriever = kwargs.pop("retriever", None)
        if retriever is None:
            raise ValueError("`retriever` must be present.")

        # The retriever always comes from kwargs; drop any serialized stand-in.
        config.pop("retriever", None)

        return cls(
            retriever=retriever,
            **config,
        )