# tests/langchain/test_langchain_databricks_dependency_extraction.py
from collections import Counter, defaultdict
from unittest import mock

import langchain
import pytest
from databricks.vector_search.client import VectorSearchIndex
from packaging.version import Version

from mlflow.langchain.databricks_dependencies import (
    _detect_databricks_dependencies,
    _extract_databricks_dependencies_from_chat_model,
    _extract_databricks_dependencies_from_llm,
    _extract_databricks_dependencies_from_retriever,
    _extract_dependency_list_from_lc_model,
)
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksVectorSearchIndex,
)

# TODO: Remove this once databricks-langchain supports v1
if Version(langchain.__version__).major >= 1:
    pytest.skip("databricks-langchain does not support v1 yet", allow_module_level=True)


class MockDatabricksServingEndpointClient:
    # Stand-in for langchain_community's internal _DatabricksServingEndpointClient so that
    # constructing the Databricks LLM does not contact a real workspace.
    def __init__(
        self,
        host: str,
        api_token: str,
        endpoint_name: str,
        databricks_uri: str,
        task: str,
    ):
        self.host = host
        self.api_token = api_token
        self.endpoint_name = endpoint_name
        self.databricks_uri = databricks_uri
        self.task = task


def _is_partner_package_installed():
    try:
        import databricks_langchain  # noqa: F401

        return True
    except ImportError:
        return False


def remove_langchain_community(monkeypatch):
    # Simulate the environment where langchain_community is not installed
    original_import = __import__

    def mock_import(name, *args, **kwargs):
        if name.startswith("langchain_community"):
            raise ImportError("No module named 'langchain_community'")
        return original_import(name, *args, **kwargs)

    monkeypatch.setattr("builtins.__import__", mock_import)


def test_parsing_dependency_from_databricks_llm(monkeypatch: pytest.MonkeyPatch):
    from langchain_community.llms import Databricks

    from mlflow.langchain.utils.logging import IS_PICKLE_SERIALIZATION_RESTRICTED

    monkeypatch.setattr(
        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
        MockDatabricksServingEndpointClient,
    )
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    llm_kwargs = {"endpoint_name": "databricks-mixtral-8x7b-instruct"}
    if IS_PICKLE_SERIALIZATION_RESTRICTED:
        llm_kwargs["allow_dangerous_deserialization"] = True

    llm = Databricks(**llm_kwargs)
    resources = list(_extract_databricks_dependencies_from_llm(llm))
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
    ]


class MockVectorSearchIndex(VectorSearchIndex):
    # Minimal VectorSearchIndex stub whose describe() returns canned metadata, either for
    # an index with a built-in embedding endpoint or for one with self-managed embeddings.
    def __init__(self, endpoint_name, index_name, has_embedding_endpoint=False) -> None:
        self.endpoint_name = endpoint_name
        self.name = index_name
        self.has_embedding_endpoint = has_embedding_endpoint

    def describe(self):
        if self.has_embedding_endpoint:
            return {
                "name": self.name,
                "endpoint_name": self.endpoint_name,
                "primary_key": "id",
                "index_type": "DELTA_SYNC",
                "delta_sync_index_spec": {
                    "source_table": "ml.schema.databricks_documentation",
                    "embedding_source_columns": [
                        {"name": "content", "embedding_model_endpoint_name": "embedding-model"}
                    ],
                    "pipeline_type": "TRIGGERED",
                    "pipeline_id": "79a76fcc-67ad-4ac6-8d8e-20f7d485ffa6",
                },
                "status": {
                    "detailed_state": "OFFLINE_FAILED",
                    "message": "Index creation failed.",
                    "indexed_row_count": 0,
                    "failed_status": {"error_message": ""},
                    "ready": False,
                    "index_url": "e2-dogfood.staging.cloud.databricks.com/rest_of_url",
                },
                "creator": "first.last@databricks.com",
            }
        else:
            return {
                "name": self.name,
                "endpoint_name": self.endpoint_name,
                "primary_key": "id",
                "index_type": "DELTA_SYNC",
                "delta_sync_index_spec": {
                    "source_table": "ml.schema.databricks_documentation",
                    "embedding_vector_columns": [],
                    "pipeline_type": "TRIGGERED",
                    "pipeline_id": "fbbd5bf1-2b9b-4a7e-8c8d-c0f6cc1030de",
                },
                "status": {
                    "detailed_state": "ONLINE",
                    "message": "Index is currently online",
                    "indexed_row_count": 17183,
                    "ready": True,
                    "index_url": "e2-dogfood.staging.cloud.databricks.com/rest_of_url",
                },
                "creator": "first.last@databricks.com",
            }


def get_vector_search(
    endpoint_name: str,
    index_name: str,
    has_embedding_endpoint=False,
    **kwargs,
):
    # Build a DatabricksVectorSearch store backed by the mock index so that no real
    # Vector Search endpoint is contacted.
    index = MockVectorSearchIndex(endpoint_name, index_name, has_embedding_endpoint)

    from databricks_langchain import DatabricksVectorSearch

    with mock.patch("databricks.vector_search.client.VectorSearchClient") as mock_client:
        mock_client().get_index.return_value = index
        return DatabricksVectorSearch(
            endpoint=endpoint_name,
            index_name=index_name,
            **kwargs,
        )


def test_parsing_dependency_from_databricks_retriever(monkeypatch):
    from databricks_langchain import ChatDatabricks, DatabricksEmbeddings

    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.embeddings import DatabricksEmbeddings

    embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

    # vs_index_1 is a direct access index
    vectorstore_1 = get_vector_search(
        endpoint_name="vs_endpoint",
        index_name="mlflow.rag.vs_index_1",
        text_column="content",
        embedding=embedding_model,
    )
    retriever_1 = vectorstore_1.as_retriever()

    # vs_index_2 has a built-in embedding endpoint, "embedding-model"
    vectorstore_2 = get_vector_search(
        endpoint_name="vs_endpoint",
        index_name="mlflow.rag.vs_index_2",
        has_embedding_endpoint=True,
    )
    retriever_2 = vectorstore_2.as_retriever()

    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)

    assert list(_extract_databricks_dependencies_from_retriever(retriever_1)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    assert list(_extract_databricks_dependencies_from_retriever(retriever_2)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_2"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]

    # Fall back to langchain_classic when these retrievers are not exposed by langchain
    try:
        from langchain.retrievers import (
            ContextualCompressionRetriever,
            EnsembleRetriever,
            TimeWeightedVectorStoreRetriever,
        )
        from langchain.retrievers.document_compressors import LLMChainExtractor
        from langchain.retrievers.multi_query import MultiQueryRetriever
    except ImportError:
        from langchain_classic.retrievers import (
            ContextualCompressionRetriever,
            EnsembleRetriever,
            TimeWeightedVectorStoreRetriever,
        )
        from langchain_classic.retrievers.document_compressors import LLMChainExtractor
        from langchain_classic.retrievers.multi_query import MultiQueryRetriever

    multi_query_retriever = MultiQueryRetriever.from_llm(retriever=retriever_1, llm=llm)
    assert list(_extract_databricks_dependencies_from_retriever(multi_query_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    compressor = LLMChainExtractor.from_llm(llm)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever_1
    )
    assert list(_extract_databricks_dependencies_from_retriever(compression_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    ensemble_retriever = EnsembleRetriever(
        retrievers=[retriever_1, retriever_2], weights=[0.5, 0.5]
    )
    assert list(_extract_databricks_dependencies_from_retriever(ensemble_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_2"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]

    time_weighted_retriever = TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore_1, decay_rate=0.0000000000000000000000001, k=1
    )
    assert list(_extract_databricks_dependencies_from_retriever(time_weighted_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]


def test_parsing_dependency_from_retriever_with_embedding_endpoint_in_index():
    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()
    resources = list(_extract_databricks_dependencies_from_retriever(retriever))
    assert resources == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]


def test_parsing_dependency_from_agent(monkeypatch: pytest.MonkeyPatch):
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent

    try:
        from langchain_community.tools.databricks import UCFunctionToolkit
    except Exception:
        # Nothing to test if UCFunctionToolkit is not available in langchain_community
        return

    # When FunctionsAPI.get is called, return a FunctionInfo for the requested function
    def mock_function_get(self, function_name):
        components = function_name.split(".")
        # initialize_agent (used below) requires functions that take exactly one parameter
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        # Build the FunctionInfo from the catalog, schema, and function name plus the parameter
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)

    toolkit = UCFunctionToolkit(warehouse_id="testId1").include("rag.test.test_function")
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)
    agent = initialize_agent(
        toolkit.get_tools(),
        llm,
        verbose=True,
    )

    resources = sorted(_extract_dependency_list_from_lc_model(agent), key=lambda x: x.name)
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
        DatabricksFunction(function_name="rag.test.test_function"),
        DatabricksSQLWarehouse(warehouse_id="testId1"),
    ]


def test_parsing_multiple_dependency_from_agent(monkeypatch):
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent
    from langchain.tools.retriever import create_retriever_tool

    remove_langchain_community(monkeypatch)

    def mock_function_get(self, function_name):
        components = function_name.split(".")
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    # In addition to the above, handle the case where "*" is passed and all functions
    # in the schema are listed
    def mock_function_list(self, catalog_name, schema_name):
        assert catalog_name == "rag"
        assert schema_name == "test"
        return [
            FunctionInfo(full_name="rag.test.test_function"),
            FunctionInfo(full_name="rag.test.test_function_2"),
            FunctionInfo(full_name="rag.test.test_function_3"),
        ]

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.list", mock_function_list)

    try:
        from langchain_community.tools.databricks import UCFunctionToolkit

        include_uc_function_tools = True
    except Exception:
        include_uc_function_tools = False

    uc_function_tools = (
        UCFunctionToolkit(warehouse_id="testId1").include("rag.test.*").get_tools()
        if include_uc_function_tools
        else []
    )
    chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()

    retriever_tool = create_retriever_tool(retriever, "vs_index_name", "vs_index_desc")

    agent = initialize_agent(
        uc_function_tools + [retriever_tool],
        chat_model,
        verbose=True,
    )
    resources = list(_extract_dependency_list_from_lc_model(agent))
    # Ensure all expected resources are detected
    expected = [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
    ]
    if include_uc_function_tools:
        expected = [
            DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
            DatabricksFunction(function_name="rag.test.test_function"),
            DatabricksFunction(function_name="rag.test.test_function_2"),
            DatabricksFunction(function_name="rag.test.test_function_3"),
            DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
            DatabricksServingEndpoint(endpoint_name="embedding-model"),
            DatabricksSQLWarehouse(warehouse_id="testId1"),
        ]

    def build_resource_map(resources):
        # Group resource names by resource type so the comparison is order-insensitive
        resource_map = defaultdict(list)

        for resource in resources:
            resource_type = resource.type.value
            resource_name = resource.to_dict()[resource_type][0]["name"]
            resource_map[resource_type].append(resource_name)

        return dict(resource_map)

    # Build maps for the detected and expected resources
    resource_maps = build_resource_map(resources)
    expected_maps = build_resource_map(expected)

    assert len(resource_maps) == len(expected_maps)

    for resource_type in resource_maps:
        assert Counter(resource_maps[resource_type]) == Counter(
            expected_maps.get(resource_type, [])
        )


def test_parsing_dependency_from_databricks_chat(monkeypatch):
    from databricks_langchain import ChatDatabricks

    # In databricks-langchain > 0.7.0, ChatDatabricks instantiates a workspace client
    # in __init__, which requires Databricks credentials.
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.chat_models import ChatDatabricks

    chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
    resources = list(_extract_databricks_dependencies_from_chat_model(chat_model))
    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]


def test_parsing_dependency_from_databricks(monkeypatch):
    from databricks_langchain import ChatDatabricks

    # In databricks-langchain > 0.7.0, ChatDatabricks instantiates a workspace client
    # in __init__, which requires Databricks credentials.
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.chat_models import ChatDatabricks

    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
    llm2 = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

    # llm and llm2 target the same endpoint, so it should appear only once in the result
    model = retriever | llm | llm2
    resources = _detect_databricks_dependencies(model)
    assert resources == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
    ]


def test_parsing_unitycatalog_tool_as_dependency(monkeypatch: pytest.MonkeyPatch):
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent
    from unitycatalog.ai.core.databricks import DatabricksFunctionClient
    from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit

    # When FunctionsAPI.get is called, return a FunctionInfo for the requested function
    def mock_function_get(self, function_name):
        components = function_name.split(".")
        # initialize_agent (used below) requires functions that take exactly one parameter
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        # Build the FunctionInfo from the catalog, schema, and function name plus the parameter
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)

    # TODO: Remove this mock once unitycatalog-ai releases a version that avoids setting up a
    # Spark session during initialization
    with mock.patch("unitycatalog.ai.core.databricks.DatabricksFunctionClient.set_spark_session"):
        client = DatabricksFunctionClient()
    toolkit = UCFunctionToolkit(function_names=["rag.test.test_function"], client=client)
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)
    agent = initialize_agent(
        toolkit.tools,
        llm,
        verbose=True,
    )

    resources = sorted(_extract_dependency_list_from_lc_model(agent), key=lambda x: x.name)
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
        DatabricksFunction(function_name="rag.test.test_function"),
    ]