# test_langchain_databricks_dependency_extraction.py
"""Tests for extracting Databricks resource dependencies from LangChain models.

These tests verify that MLflow's dependency-detection helpers
(`_extract_databricks_dependencies_from_*`, `_extract_dependency_list_from_lc_model`,
`_detect_databricks_dependencies`) report the correct Databricks resources
(serving endpoints, vector search indexes, UC functions, SQL warehouses) for
LLMs, chat models, retrievers, and agents — using mocks so no real Databricks
workspace is contacted.
"""

from collections import Counter, defaultdict
from unittest import mock

import langchain
import pytest
from databricks.vector_search.client import VectorSearchIndex
from packaging.version import Version

from mlflow.langchain.databricks_dependencies import (
    _detect_databricks_dependencies,
    _extract_databricks_dependencies_from_chat_model,
    _extract_databricks_dependencies_from_llm,
    _extract_databricks_dependencies_from_retriever,
    _extract_dependency_list_from_lc_model,
)
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksVectorSearchIndex,
)

# TODO: Remove this once databricks-langchain supports v1
if Version(langchain.__version__).major >= 1:
    pytest.skip("databricks-langchain does not support v1 yet", allow_module_level=True)


class MockDatabricksServingEndpointClient:
    """Stub for the serving-endpoint client: just records its constructor args."""

    def __init__(
        self,
        host: str,
        api_token: str,
        endpoint_name: str,
        databricks_uri: str,
        task: str,
    ):
        self.host = host
        self.api_token = api_token
        self.endpoint_name = endpoint_name
        self.databricks_uri = databricks_uri
        self.task = task


def _is_partner_package_installed():
    """Return True if the ``databricks_langchain`` partner package is importable."""
    # NOTE(review): not referenced by the tests visible in this chunk — confirm
    # whether it is still needed before removing.
    try:
        import databricks_langchain  # noqa: F401

        return True
    except ImportError:
        return False


def remove_langchain_community(monkeypatch):
    """Make any ``langchain_community`` import raise ImportError for this test."""
    # Simulate the environment where langchain_community is not installed
    original_import = __import__

    def mock_import(name, *args, **kwargs):
        if name.startswith("langchain_community"):
            raise ImportError("No module named 'langchain_community'")
        return original_import(name, *args, **kwargs)

    monkeypatch.setattr("builtins.__import__", mock_import)


def test_parsing_dependency_from_databricks_llm(monkeypatch: pytest.MonkeyPatch):
    """A legacy ``Databricks`` LLM should yield its serving endpoint as a dependency."""
    from langchain_community.llms import Databricks

    from mlflow.langchain.utils.logging import IS_PICKLE_SERIALIZATION_RESTRICTED

    # Replace the real endpoint client so no network call happens on construction.
    monkeypatch.setattr(
        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
        MockDatabricksServingEndpointClient,
    )
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    llm_kwargs = {"endpoint_name": "databricks-mixtral-8x7b-instruct"}
    # Newer langchain versions require an explicit opt-in for pickle deserialization.
    if IS_PICKLE_SERIALIZATION_RESTRICTED:
        llm_kwargs["allow_dangerous_deserialization"] = True

    llm = Databricks(**llm_kwargs)
    resources = list(_extract_databricks_dependencies_from_llm(llm))
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
    ]


class MockVectorSearchIndex(VectorSearchIndex):
    """Fake vector search index whose ``describe()`` returns canned metadata.

    ``has_embedding_endpoint`` selects between a delta-sync index backed by a
    managed embedding endpoint ("embedding-model") and a direct-access-style
    index with no embedding source columns.
    """

    def __init__(self, endpoint_name, index_name, has_embedding_endpoint=False) -> None:
        self.endpoint_name = endpoint_name
        self.name = index_name
        self.has_embedding_endpoint = has_embedding_endpoint

    def describe(self):
        """Return index metadata mimicking the Vector Search ``describe`` API."""
        if self.has_embedding_endpoint:
            # Index with a built-in embedding model endpoint in its sync spec.
            return {
                "name": self.name,
                "endpoint_name": self.endpoint_name,
                "primary_key": "id",
                "index_type": "DELTA_SYNC",
                "delta_sync_index_spec": {
                    "source_table": "ml.schema.databricks_documentation",
                    "embedding_source_columns": [
                        {"name": "content", "embedding_model_endpoint_name": "embedding-model"}
                    ],
                    "pipeline_type": "TRIGGERED",
                    "pipeline_id": "79a76fcc-67ad-4ac6-8d8e-20f7d485ffa6",
                },
                "status": {
                    "detailed_state": "OFFLINE_FAILED",
                    "message": "Index creation failed.",
                    "indexed_row_count": 0,
                    "failed_status": {"error_message": ""},
                    "ready": False,
                    "index_url": "e2-dogfood.staging.cloud.databricks.com/rest_of_url",
                },
                "creator": "first.last@databricks.com",
            }
        else:
            # Index without an embedding endpoint (caller supplies the embedding model).
            return {
                "name": self.name,
                "endpoint_name": self.endpoint_name,
                "primary_key": "id",
                "index_type": "DELTA_SYNC",
                "delta_sync_index_spec": {
                    "source_table": "ml.schema.databricks_documentation",
                    "embedding_vector_columns": [],
                    "pipeline_type": "TRIGGERED",
                    "pipeline_id": "fbbd5bf1-2b9b-4a7e-8c8d-c0f6cc1030de",
                },
                "status": {
                    "detailed_state": "ONLINE",
                    "message": "Index is currently online",
                    "indexed_row_count": 17183,
                    "ready": True,
                    "index_url": "e2-dogfood.staging.cloud.databricks.com/rest_of_url",
                },
                "creator": "first.last@databricks.com",
            }


def get_vector_search(
    endpoint_name: str,
    index_name: str,
    has_embedding_endpoint=False,
    **kwargs,
):
    """Build a ``DatabricksVectorSearch`` store backed by a mocked index client.

    Extra ``kwargs`` (e.g. ``text_column``, ``embedding``) are forwarded to the
    ``DatabricksVectorSearch`` constructor.
    """
    index = MockVectorSearchIndex(endpoint_name, index_name, has_embedding_endpoint)

    from databricks_langchain import DatabricksVectorSearch

    # Patch the client so DatabricksVectorSearch picks up the mock index.
    with mock.patch("databricks.vector_search.client.VectorSearchClient") as mock_client:
        mock_client().get_index.return_value = index
        return DatabricksVectorSearch(
            endpoint=endpoint_name,
            index_name=index_name,
            **kwargs,
        )


def test_parsing_dependency_from_databricks_retriever(monkeypatch):
    """Retrievers (plain and wrapped) should surface their index + embedding endpoint."""
    from databricks_langchain import ChatDatabricks, DatabricksEmbeddings

    # Ensure the partner-package code path is used, not langchain_community.
    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.embeddings import DatabricksEmbeddings

    embedding_model = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

    # vs_index_1 is a direct access index
    vectorstore_1 = get_vector_search(
        endpoint_name="vs_endpoint",
        index_name="mlflow.rag.vs_index_1",
        text_column="content",
        embedding=embedding_model,
    )
    retriever_1 = vectorstore_1.as_retriever()

    # vs_index_2 has builtin embedding endpoint "embedding-model"
    vectorstore_2 = get_vector_search(
        endpoint_name="vs_endpoint",
        index_name="mlflow.rag.vs_index_2",
        has_embedding_endpoint=True,
    )
    retriever_2 = vectorstore_2.as_retriever()

    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)

    assert list(_extract_databricks_dependencies_from_retriever(retriever_1)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    assert list(_extract_databricks_dependencies_from_retriever(retriever_2)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_2"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]

    # Retriever wrappers moved to langchain_classic in newer langchain releases.
    try:
        from langchain.retrievers import (
            ContextualCompressionRetriever,
            EnsembleRetriever,
            TimeWeightedVectorStoreRetriever,
        )
        from langchain.retrievers.document_compressors import LLMChainExtractor
        from langchain.retrievers.multi_query import MultiQueryRetriever

    except ImportError:
        from langchain_classic.retrievers import (
            ContextualCompressionRetriever,
            EnsembleRetriever,
            TimeWeightedVectorStoreRetriever,
        )
        from langchain_classic.retrievers.document_compressors import LLMChainExtractor
        from langchain_classic.retrievers.multi_query import MultiQueryRetriever

    # Wrapper retrievers must be unwrapped down to the Databricks retriever inside.
    multi_query_retriever = MultiQueryRetriever.from_llm(retriever=retriever_1, llm=llm)
    assert list(_extract_databricks_dependencies_from_retriever(multi_query_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    compressor = LLMChainExtractor.from_llm(llm)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever_1
    )
    assert list(_extract_databricks_dependencies_from_retriever(compression_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]

    # An ensemble contributes the dependencies of every member retriever.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[retriever_1, retriever_2], weights=[0.5, 0.5]
    )
    assert list(_extract_databricks_dependencies_from_retriever(ensemble_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_2"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]

    time_weighted_retriever = TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore_1, decay_rate=0.0000000000000000000000001, k=1
    )
    assert list(_extract_databricks_dependencies_from_retriever(time_weighted_retriever)) == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index_1"),
        DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
    ]


def test_parsing_dependency_from_retriever_with_embedding_endpoint_in_index():
    """An index with a managed embedding endpoint reports that endpoint too."""
    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()
    resources = list(_extract_databricks_dependencies_from_retriever(retriever))
    assert resources == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
    ]


def test_parsing_dependency_from_agent(monkeypatch: pytest.MonkeyPatch):
    """An agent with UC function tools yields endpoint, function, and warehouse."""
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent

    # UCFunctionToolkit only exists in langchain_community; skip silently otherwise.
    try:
        from langchain_community.tools.databricks import UCFunctionToolkit
    except Exception:
        return

    # When get is called return a function
    def mock_function_get(self, function_name):
        components = function_name.split(".")
        # Initialize agent used below requires functions to take in exactly one parameter
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        # Add the catalog, schema and name to the function Info followed by the parameter
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)

    toolkit = UCFunctionToolkit(warehouse_id="testId1").include("rag.test.test_function")
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)
    agent = initialize_agent(
        toolkit.get_tools(),
        llm,
        verbose=True,
    )

    # Sort by resource name so the comparison is order-independent.
    resources = sorted(_extract_dependency_list_from_lc_model(agent), key=lambda x: x.name)
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
        DatabricksFunction(function_name="rag.test.test_function"),
        DatabricksSQLWarehouse(warehouse_id="testId1"),
    ]


def test_parsing_multiple_dependency_from_agent(monkeypatch):
    """Agent mixing UC function tools and a retriever tool yields every resource."""
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent
    from langchain.tools.retriever import create_retriever_tool

    remove_langchain_community(monkeypatch)

    def mock_function_get(self, function_name):
        components = function_name.split(".")
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    # In addition to above now handle the case where a '*' is passed in and list all the functions
    def mock_function_list(self, catalog_name, schema_name):
        assert catalog_name == "rag"
        assert schema_name == "test"
        return [
            FunctionInfo(full_name="rag.test.test_function"),
            FunctionInfo(full_name="rag.test.test_function_2"),
            FunctionInfo(full_name="rag.test.test_function_3"),
        ]

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.list", mock_function_list)

    # UC function tools are optional; expectations below adapt to availability.
    include_uc_function_tools = False
    try:
        from langchain_community.tools.databricks import UCFunctionToolkit

        include_uc_function_tools = True
    except Exception:
        include_uc_function_tools = False

    uc_function_tools = (
        (UCFunctionToolkit(warehouse_id="testId1").include("rag.test.*").get_tools())
        if include_uc_function_tools
        else []
    )
    chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()

    retriever_tool = create_retriever_tool(retriever, "vs_index_name", "vs_index_desc")

    agent = initialize_agent(
        uc_function_tools + [retriever_tool],
        chat_model,
        verbose=True,
    )
    resources = list(_extract_dependency_list_from_lc_model(agent))
    # Ensure all resources are added in
    expected = [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
    ]
    if include_uc_function_tools:
        expected = [
            DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
            DatabricksFunction(function_name="rag.test.test_function"),
            DatabricksFunction(function_name="rag.test.test_function_2"),
            DatabricksFunction(function_name="rag.test.test_function_3"),
            DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
            DatabricksServingEndpoint(endpoint_name="embedding-model"),
            DatabricksSQLWarehouse(warehouse_id="testId1"),
        ]

    def build_resource_map(resources):
        # Group resource names by resource type for an order-insensitive comparison.
        resource_map = defaultdict(list)

        for resource in resources:
            resource_type = resource.type.value
            resource_name = resource.to_dict()[resource_type][0]["name"]
            resource_map[resource_type].append(resource_name)

        return dict(resource_map)

    # Build maps for resources and expected resources
    resource_maps = build_resource_map(resources)
    expected_maps = build_resource_map(expected)

    assert len(resource_maps) == len(expected_maps)

    # Compare multisets per type so duplicates are counted but order is ignored.
    for resource_type in resource_maps:
        assert Counter(resource_maps[resource_type]) == Counter(
            expected_maps.get(resource_type, [])
        )


def test_parsing_dependency_from_databricks_chat(monkeypatch):
    """A ``ChatDatabricks`` model yields its serving endpoint as a dependency."""
    from databricks_langchain import ChatDatabricks

    # in databricks-langchain > 0.7.0, ChatDatabricks instantiates
    # workspace client in __init__ which requires Databricks creds
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.chat_models import ChatDatabricks

    chat_model = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
    resources = list(_extract_databricks_dependencies_from_chat_model(chat_model))
    assert resources == [DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat")]


def test_parsing_dependency_from_databricks(monkeypatch):
    """Chained runnables (retriever | llm | llm) yield a de-duplicated resource list."""
    from databricks_langchain import ChatDatabricks

    # in databricks-langchain > 0.7.0, ChatDatabricks instantiates
    # workspace client in __init__ which requires Databricks creds
    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")

    remove_langchain_community(monkeypatch)
    with pytest.raises(ImportError, match="No module named 'langchain_community"):
        from langchain_community.chat_models import ChatDatabricks

    vectorstore = get_vector_search(
        endpoint_name="dbdemos_vs_endpoint",
        index_name="mlflow.rag.vs_index",
        has_embedding_endpoint=True,
    )
    retriever = vectorstore.as_retriever()
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)
    llm2 = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", max_tokens=500)

    # Both llm and llm2 point at the same endpoint; only one entry is expected.
    model = retriever | llm | llm2
    resources = _detect_databricks_dependencies(model)
    assert resources == [
        DatabricksVectorSearchIndex(index_name="mlflow.rag.vs_index"),
        DatabricksServingEndpoint(endpoint_name="embedding-model"),
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
    ]


def test_parsing_unitycatalog_tool_as_dependency(monkeypatch: pytest.MonkeyPatch):
    """A unitycatalog-ai toolkit contributes the UC function as a dependency."""
    from databricks.sdk.service.catalog import FunctionInfo
    from databricks_langchain import ChatDatabricks
    from langchain.agents import initialize_agent
    from unitycatalog.ai.core.databricks import DatabricksFunctionClient
    from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit

    # When get is called return a function
    def mock_function_get(self, function_name):
        components = function_name.split(".")
        # Initialize agent used below requires functions to take in exactly one parameter
        param_dict = {
            "parameters": [
                {
                    "name": "param",
                    "parameter_type": "PARAM",
                    "position": 0,
                    "type_json": '{"name":"param","type":"string","nullable":true,"metadata":{}}',
                    "type_name": "STRING",
                    "type_precision": 0,
                    "type_scale": 0,
                    "type_text": "string",
                }
            ]
        }
        # Add the catalog, schema and name to the function Info followed by the parameter
        return FunctionInfo.from_dict({
            "catalog_name": components[0],
            "schema_name": components[1],
            "name": components[2],
            "input_params": param_dict,
        })

    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
    monkeypatch.setattr("databricks.sdk.service.catalog.FunctionsAPI.get", mock_function_get)

    # TODO: remove this mock after unitycatalog-ai release a new version to avoid setting
    # spark session during initialization
    with mock.patch("unitycatalog.ai.core.databricks.DatabricksFunctionClient.set_spark_session"):
        client = DatabricksFunctionClient()
    toolkit = UCFunctionToolkit(function_names=["rag.test.test_function"], client=client)
    llm = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0)
    agent = initialize_agent(
        toolkit.tools,
        llm,
        verbose=True,
    )

    resources = sorted(_extract_dependency_list_from_lc_model(agent), key=lambda x: x.name)
    assert resources == [
        DatabricksServingEndpoint(endpoint_name="databricks-llama-2-70b-chat"),
        DatabricksFunction(function_name="rag.test.test_function"),
    ]