test_langchain_model_export.py
import inspect
import json
import os
import shutil
from operator import itemgetter
from typing import Any, Iterator
from unittest import mock

import langchain
import pytest
import yaml
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.fake import FakeEmbeddings
from langchain_community.llms import OpenAI
from langchain_community.utilities import TextRequestsWrapper
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import SimpleChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import (
    RunnableBinding,
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnableSequence,
)
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI
from langchain_text_splitters.character import CharacterTextSplitter
from packaging import version
from pydantic import BaseModel
from pyspark.sql import SparkSession

import mlflow
import mlflow.models.model
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow.deployments import PredictionsResponse
from mlflow.environment_variables import MLFLOW_CONVERT_MESSAGES_DICT_FOR_LANGCHAIN
from mlflow.exceptions import MlflowException
from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
from mlflow.langchain.utils.chat import (
    try_transform_response_to_chat_format,
)
from mlflow.langchain.utils.logging import (
    IS_PICKLE_SERIALIZATION_RESTRICTED,
    lc_runnables_types,
)
from mlflow.models import Model
from mlflow.models.dependencies_schemas import DependenciesSchemasType
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksVectorSearchIndex,
)
from mlflow.models.signature import Schema, infer_signature
from mlflow.models.utils import load_serving_example
from mlflow.pyfunc.context import Context
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracing.export.inference_table import pop_trace
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types.schema import Array, ColSpec, DataType, Object, Property

from tests.helper_functions import _compare_logged_code_paths, pyfunc_serve_and_score_model
from tests.langchain.conftest import DeterministicDummyEmbeddings
from tests.tracing.helper import get_traces

# this kwarg was added in langchain_community 0.0.27, and
# prevents the use of pickled objects if not provided.
VECTORSTORE_KWARGS = (
    {"allow_dangerous_deserialization": True} if IS_PICKLE_SERIALIZATION_RESTRICTED else {}
)

IS_LANGCHAIN_03 = version.parse(langchain.__version__) >= version.parse("0.3.0")
IS_LANGCHAIN_v1 = version.parse(langchain.__version__).major >= 1
LANGCHAIN_V1_SKIP_REASON = "Pickle serialization is not supported for LangChain v1"

# Reusable decorator for skipping tests on LangChain v1
skip_if_v1 = pytest.mark.skipif(IS_LANGCHAIN_v1, reason=LANGCHAIN_V1_SKIP_REASON)

# The mock OpenAI completion endpoint returns the request payload as-is
TEST_CONTENT = [{"role": "user", "content": "What is MLflow?"}]

SIMPLE_MODEL_CODE_PATH = "tests/langchain/sample_code/simple_runnable.py"


@pytest.fixture
def model_path(tmp_path):
    return tmp_path / "model"


@pytest.fixture(scope="module")
def spark():
    with SparkSession.builder.master("local[*]").getOrCreate() as s:
        yield s


def create_openai_runnable():
    from langchain_core.output_parsers import StrOutputParser

    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is {product}?",
    )
    return prompt | ChatOpenAI(temperature=0.9) | StrOutputParser()


@pytest.fixture
def fake_chat_model():
    class FakeChatModel(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        endpoint_name: str = "fake-endpoint"

        def _call(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> str:
            return "Databricks"

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    return FakeChatModel(endpoint_name="fake-endpoint")


@pytest.fixture
def fake_classifier_chat_model():
    class FakeMlflowClassifier(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        def _call(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> str:
            if "MLflow" in messages[0].content.split(":")[1]:
                return "yes"
            if "cat" in messages[0].content.split(":")[1]:
                return "no"
            return "unknown"

        @property
        def _llm_type(self) -> str:
            return "fake mlflow classifier"

    return FakeMlflowClassifier()
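
# The OpenAI-backed chains below run against a mocked completion endpoint that
# echoes the request messages back; that is why predictions compare equal to
# json.dumps(TEST_CONTENT) rather than to a real completion.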


@skip_if_v1
def test_langchain_native_log_and_load_model():
    model = create_openai_runnable()

    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            model, name="langchain_model", input_example={"product": "MLflow"}
        )

    loaded_model = mlflow.langchain.load_model(logged_model.model_uri)

    assert "langchain" in logged_model.flavors
    assert str(logged_model.signature.inputs) == "['product': string (required)]"
    assert str(logged_model.signature.outputs) == "[string (required)]"

    assert type(loaded_model) == RunnableSequence
    assert loaded_model.steps[0].template == "What is {product}?"
    assert type(loaded_model.steps[1]).__name__ == "ChatOpenAI"

    # Predict
    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
    result = loaded_model.predict([{"product": "MLflow"}])
    assert result == [json.dumps(TEST_CONTENT)]

    # Predict stream
    result = loaded_model.predict_stream([{"product": "MLflow"}])
    assert inspect.isgenerator(result)
    assert list(result) == ["Hello", " world"]


@skip_if_v1
def test_pyfunc_spark_udf_with_langchain_model(spark):
    model = create_openai_runnable()
    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            model, name="langchain_model", input_example={"product": "MLflow"}
        )
    loaded_model = mlflow.pyfunc.spark_udf(spark, logged_model.model_uri, result_type="string")
    df = spark.createDataFrame([("MLflow",), ("Spark",)], ["product"])
    df = df.withColumn("answer", loaded_model())
    pdf = df.toPandas()
    assert pdf["answer"].tolist() == [
        '[{"role": "user", "content": "What is MLflow?"}]',
        '[{"role": "user", "content": "What is Spark?"}]',
    ]


@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in LangChain v0")
def test_langchain_agent_model_predict(monkeypatch):
    input_example = {"input": "What is 2 * 3?"}

    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            # The OpenAI client (since 1.0) contains a thread lock object that cannot
            # be pickled. Therefore, the AgentExecutor cannot be saved with the legacy
            # object-based logging and we need to use Model-from-Code logging.
            "tests/langchain/sample_code/openai_agent.py",
            name="langchain_model",
            input_example=input_example,
        )

    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

    # Basic prediction
    response = loaded_model.predict([input_example])
    expected_output = "The result of 2 * 3 is 6."

    assert response[0]["messages"][-1]["content"] == expected_output

    # Stream prediction
    response = loaded_model.predict_stream([input_example])
    assert inspect.isgenerator(response)
    assert list(response) == [
        {"model": {"messages": [mock.ANY]}},
        {"tools": {"messages": [mock.ANY]}},
        {"model": {"messages": [mock.ANY]}},
    ]

    # Model serving
    inference_payload = load_serving_example(logged_model.model_uri)
    response = pyfunc_serve_and_score_model(
        logged_model.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    # TODO: The response is not wrapped by the "predictions" key. This is a bug in
    # output handling. Often the user input contains a key "input" because it is
    # used in popular agent prompts in the hub. However, this confuses the scoring
    # server into treating it as an llm/v1/completions request.
    response = json.loads(response.content.decode("utf-8"))
    assert response[0]["messages"][-1]["content"] == expected_output
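
# The retriever tests persist a FAISS index to local disk and pass loader_fn /
# persist_dir to log_model: the persisted index directory is logged as a model
# artifact, and loader_fn is invoked with its download path at load time to
# rebuild the vector store (hence the shutil.rmtree below). Deterministic dummy
# embeddings keep similarity search stable across the save/load round trip.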


def assert_equal_retrievers(retriever, expected_retriever):
    from langchain.schema.retriever import BaseRetriever

    assert isinstance(retriever, BaseRetriever)
    assert isinstance(retriever, type(expected_retriever))
    assert isinstance(retriever.vectorstore, type(expected_retriever.vectorstore))
    assert retriever.tags == expected_retriever.tags
    assert retriever.metadata == expected_retriever.metadata
    assert retriever.search_type == expected_retriever.search_type
    assert retriever.search_kwargs == expected_retriever.search_kwargs


@skip_if_v1
def test_log_and_load_retriever_chain(tmp_path):
    # Create the vector db, persist the db to a local fs folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=256, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = DeterministicDummyEmbeddings(size=5)
    db = FAISS.from_documents(docs, embeddings)
    persist_dir = str(tmp_path / "faiss_index")
    db.save_local(persist_dir)

    # Define the loader_fn
    def load_retriever(persist_directory):
        import numpy as np
        from langchain.embeddings.base import Embeddings

        class DeterministicDummyEmbeddings(Embeddings, BaseModel):
            size: int

            def _get_embedding(self, text: str) -> list[float]:
                if isinstance(text, np.ndarray):
                    text = text.item()
                seed = abs(hash(text)) % (10**8)
                np.random.seed(seed)
                return list(np.random.normal(size=self.size))

            def embed_documents(self, texts: list[str]) -> list[list[float]]:
                return [self._get_embedding(t) for t in texts]

            def embed_query(self, text: str) -> list[float]:
                return self._get_embedding(text)

        embeddings = DeterministicDummyEmbeddings(size=5)
        vectorstore = FAISS.load_local(
            persist_directory,
            embeddings,
            **VECTORSTORE_KWARGS,
        )
        return vectorstore.as_retriever()

    query = "What did the president say about Ketanji Brown Jackson"
    langchain_input = {"query": query}
    # Log the retriever
    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            db.as_retriever(),
            name="retriever",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
            input_example=langchain_input,
        )

    # Remove the persist_dir
    shutil.rmtree(persist_dir)

    # Load the retriever
    loaded_model = mlflow.langchain.load_model(logged_model.model_uri)
    assert_equal_retrievers(loaded_model, db.as_retriever())

    loaded_pyfunc_model = mlflow.pyfunc.load_model(logged_model.model_uri)
    result = loaded_pyfunc_model.predict([langchain_input])
    expected_result = [
        {
            "page_content": doc.page_content,
            "metadata": doc.metadata,
            "type": "Document",
            "id": mock.ANY,
        }
        for doc in db.as_retriever().get_relevant_documents(query)
    ]
    assert result == [expected_result]

    # Serve the retriever
    inference_payload = load_serving_example(logged_model.model_uri)
    response = pyfunc_serve_and_score_model(
        logged_model.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    pred = PredictionsResponse.from_json(response.content.decode("utf-8"))["predictions"]
    assert type(pred) == list
    assert len(pred) == 1
    docs_list = pred[0]
    assert type(docs_list) == list
    assert len(docs_list) == 4
    # The returned docs are non-deterministic when used with dummy embeddings,
    # so we cannot assert pred == {"predictions": [expected_result]}


def load_requests_wrapper(_):
    return TextRequestsWrapper(headers=None, aiosession=None)


@skip_if_v1
def test_agent_with_unpicklable_tools(tmp_path):
    from langchain.agents import AgentType, initialize_agent

    tmp_file = tmp_path / "temp_file.txt"
    with open(tmp_file, mode="w") as temp_file:
        # files that aren't opened for reading cannot be pickled
        tools = [
            Tool.from_function(
                func=lambda: temp_file,
                name="Write 0",
                description="If you need to write 0 to a file",
            )
        ]
        agent = initialize_agent(
            llm=OpenAI(temperature=0),
            tools=tools,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        )

        with pytest.raises(
            MlflowException,
            match=(
                "Error when attempting to pickle the AgentExecutor tools. "
                "This model likely does not support serialization."
            ),
        ):
            with mlflow.start_run():
                mlflow.langchain.log_model(agent, name="unpicklable_tools")
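
# Many of the tests below exercise pickle/cloudpickle-based save/load of the
# individual runnable types (passthrough, lambda, sequence, parallel, branch,
# binding). Each follows the same round trip (a sketch of the pattern, not a
# contract):
#
#     model_info = mlflow.langchain.log_model(runnable, name=..., input_example=...)
#     native = mlflow.langchain.load_model(model_info.model_uri)  # same runnable type back
#     pyfunc = mlflow.pyfunc.load_model(model_info.model_uri)     # generic predict() wrapper
#
# usually followed by scoring through a local serving endpoint via
# pyfunc_serve_and_score_model.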


@skip_if_v1
def test_save_load_runnable_passthrough():
    runnable = RunnablePassthrough()
    assert runnable.invoke("hello") == "hello"

    input_example = "hello"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            runnable, name="model_path", input_example=input_example
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(input_example) == "hello"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict(["hello"]) == ["hello"]


@skip_if_v1
def test_save_load_runnable_lambda(spark):
    def add_one(x: int) -> int:
        return x + 1

    runnable = RunnableLambda(add_one)

    assert runnable.invoke(1) == 2
    assert runnable.batch([1, 2, 3]) == [2, 3, 4]

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            runnable, name="runnable_lambda", input_example=[1, 2, 3]
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(1) == 2
    assert loaded_model.batch([1, 2, 3]) == [2, 3, 4]

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert loaded_model.predict(1) == [2]
    assert loaded_model.predict([1, 2, 3]) == [2, 3, 4]

    udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="long")
    df = spark.createDataFrame([(1,), (2,), (3,)], ["data"])
    df = df.withColumn("answer", udf("data"))
    pdf = df.toPandas()
    assert pdf["answer"].tolist() == [2, 3, 4]


@skip_if_v1
def test_save_load_runnable_lambda_in_sequence():
    def add_one(x):
        return x + 1

    def mul_two(x):
        return x * 2

    runnable_1 = RunnableLambda(add_one)
    runnable_2 = RunnableLambda(mul_two)
    sequence = runnable_1 | runnable_2
    assert sequence.invoke(1) == 4

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            sequence, name="model_path", input_example=[1, 2, 3]
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(1) == 4
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict(1) == [4]
    assert pyfunc_loaded_model.predict([1, 2, 3]) == [4, 6, 8]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": [4, 6, 8]
    }


@skip_if_v1
def test_predict_with_callbacks(fake_chat_model):
    class TestCallbackHandler(BaseCallbackHandler):
        def __init__(self):
            super().__init__()
            self.num_llm_start_calls = 0

        def on_llm_start(
            self,
            serialized: dict[str, Any],
            prompts: list[str],
            **kwargs: Any,
        ) -> Any:
            self.num_llm_start_calls += 1

    prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
    chain = prompt | fake_chat_model | StrOutputParser()
    # Test the basic functionality of the chain
    assert chain.invoke({"industry": "tech"}) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example={"industry": "tech"}
        )

    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    callback_handler1 = TestCallbackHandler()
    callback_handler2 = TestCallbackHandler()

    # Ensure handlers have not been called yet
    assert callback_handler1.num_llm_start_calls == 0
    assert callback_handler2.num_llm_start_calls == 0

    assert (
        pyfunc_loaded_model._model_impl._predict_with_callbacks(
            {"industry": "tech"},
            callback_handlers=[callback_handler1, callback_handler2],
        )
        == "Databricks"
    )

    # Test that the callback handlers were called
    assert callback_handler1.num_llm_start_calls == 1
    assert callback_handler2.num_llm_start_calls == 1

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }


@skip_if_v1
def test_predict_with_callbacks_supports_chat_response_conversion(fake_chat_model):
    prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
    chain = prompt | fake_chat_model | StrOutputParser()
    # Test the basic functionality of the chain
    assert chain.invoke({"industry": "tech"}) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example={"industry": "tech"}
        )

    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    expected_chat_response = {
        "id": None,
        "object": "chat.completion",
        "created": 1677858242,
        "model": "",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Databricks",
                },
                "finish_reason": None,
            }
        ],
        "usage": {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    }
    with mock.patch("time.time", return_value=1677858242):
        assert (
            pyfunc_loaded_model._model_impl._predict_with_callbacks(
                {"industry": "tech"},
                convert_chat_responses=True,
            )
            == expected_chat_response
        )

    assert (
        pyfunc_loaded_model._model_impl._predict_with_callbacks(
            {"industry": "tech"},
            convert_chat_responses=False,
        )
        == "Databricks"
    )


@skip_if_v1
def test_save_load_runnable_parallel():
    runnable = RunnableParallel({"llm": create_openai_runnable()})
    expected_result = {"llm": json.dumps(TEST_CONTENT)}
    assert runnable.invoke({"product": "MLflow"}) == expected_result
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            runnable, name="model_path", input_example=[{"product": "MLflow"}]
        )
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke({"product": "MLflow"}) == expected_result
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict([{"product": "MLflow"}]) == [expected_result]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": [expected_result]
    }


@skip_if_v1
def test_save_load_chain_with_model_paths():
    prompt1 = PromptTemplate.from_template("what is the city {person} is from?")
    llm = ChatOpenAI(temperature=0.9)
    model = prompt1 | llm | StrOutputParser()

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(model, name="model_path")
    artifact_path = "model_path"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.langchain.model._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.langchain.log_model(model, name=artifact_path, code_paths=[__file__])
        mlflow.langchain.load_model(model_info.model_uri)
        model_uri = model_info.model_uri
        _compare_logged_code_paths(__file__, model_uri, mlflow.langchain.FLAVOR_NAME)
        add_mock.assert_called()


@skip_if_v1
def test_save_load_rag(tmp_path, spark, fake_chat_model):
    # TODO: Migrate to models-from-code
    # Create the vector db, persist the db to a local fs folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=10, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = DeterministicDummyEmbeddings(size=5)
    db = FAISS.from_documents(docs, embeddings)
    persist_dir = str(tmp_path / "faiss_index")
    db.save_local(persist_dir)
    retriever = db.as_retriever()

    def load_retriever(persist_directory):
        embeddings = FakeEmbeddings(size=5)
        vectorstore = FAISS.load_local(
            persist_directory,
            embeddings,
            **VECTORSTORE_KWARGS,
        )
        return vectorstore.as_retriever()

    prompt = ChatPromptTemplate.from_template(
        "Answer the following question based on the context: {context}\nQuestion: {question}"
    )
    retrieval_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | prompt
        | fake_chat_model
        | StrOutputParser()
    )
    question = "What is a good name for a company that makes MLflow?"
    answer = "Databricks"
    assert retrieval_chain.invoke(question) == answer
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            retrieval_chain,
            name="model_path",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
            input_example=question,
        )

    # Remove the persist_dir
    shutil.rmtree(persist_dir)

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(question) == answer
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict(question) == [answer]

    udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="string")
    df = spark.createDataFrame([(question,), (question,)], ["question"])
    df = df.withColumn("answer", udf("question"))
    pdf = df.toPandas()
    assert pdf["answer"].tolist() == [answer, answer]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": [answer]
    }


@skip_if_v1
def test_runnable_branch_save_load():
    branch = RunnableBranch(
        (lambda x: isinstance(x, str), lambda x: x.upper()),
        (lambda x: isinstance(x, int), lambda x: x + 1),
        (lambda x: isinstance(x, float), lambda x: x * 2),
        lambda x: "goodbye",
    )

    assert branch.invoke("hello") == "HELLO"
    assert branch.invoke({}) == "goodbye"

    with mlflow.start_run():
        # We only support a single input format for now, so we do not save a
        # signature for a RunnableBranch that accepts multiple input types
        model_info = mlflow.langchain.log_model(branch, name="model_path")

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke("hello") == "HELLO"
    assert loaded_model.invoke({}) == "goodbye"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict("hello") == "HELLO"
    assert pyfunc_loaded_model.predict({}) == "goodbye"

    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=json.dumps({"inputs": "hello"}),
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": "HELLO"
    }


@skip_if_v1
def test_complex_runnable_branch_save_load(fake_chat_model, fake_classifier_chat_model):
    prompt = ChatPromptTemplate.from_template("{question_is_relevant}\n{query}")
    # Need to add a prompt here as the chat model doesn't accept dict input
    answer_model = prompt | fake_chat_model

    decline_to_answer = RunnableLambda(
        lambda x: "I cannot answer questions that are not about MLflow."
    )
    something_went_wrong = RunnableLambda(lambda x: "Something went wrong.")

    is_question_about_mlflow_prompt = ChatPromptTemplate.from_template(
        "You are classifying documents to know if this question "
        "is related with MLflow. Only answer with yes or no. The question is: {query}"
    )

    branch_node = RunnableBranch(
        (lambda x: x["question_is_relevant"].lower() == "yes", answer_model),
        (lambda x: x["question_is_relevant"].lower() == "no", decline_to_answer),
        something_went_wrong,
    )

    chain = (
        {
            "question_is_relevant": is_question_about_mlflow_prompt
            | fake_classifier_chat_model
            | StrOutputParser(),
            "query": itemgetter("query"),
        }
        | branch_node
        | StrOutputParser()
    )

    assert chain.invoke({"query": "Who owns MLflow?"}) == "Databricks"
    assert (
        chain.invoke({"query": "Do you like cat?"})
        == "I cannot answer questions that are not about MLflow."
    )
    assert chain.invoke({"query": "Are you happy today?"}) == "Something went wrong."

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example={"query": "Who owns MLflow?"}
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke({"query": "Who owns MLflow?"}) == "Databricks"
    assert (
        loaded_model.invoke({"query": "Do you like cat?"})
        == "I cannot answer questions that are not about MLflow."
    )
    assert loaded_model.invoke({"query": "Are you happy today?"}) == "Something went wrong."
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict({"query": "Who owns MLflow?"}) == ["Databricks"]
    assert pyfunc_loaded_model.predict({"query": "Do you like cat?"}) == [
        "I cannot answer questions that are not about MLflow."
    ]
    assert pyfunc_loaded_model.predict({"query": "Are you happy today?"}) == [
        "Something went wrong."
    ]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }


@skip_if_v1
def test_chat_with_history(spark, fake_chat_model):
    prompt_with_history_str = """
    Here is a history between you and a human: {chat_history}

    Now, please answer this question: {question}
    """

    prompt_with_history = PromptTemplate(
        input_variables=["chat_history", "question"], template=prompt_with_history_str
    )

    def extract_question(input):
        return input[-1]["content"]

    def extract_history(input):
        return input[:-1]

    chain_with_history = (
        {
            "question": itemgetter("messages") | RunnableLambda(extract_question),
            "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
        }
        | prompt_with_history
        | fake_chat_model
        | StrOutputParser()
    )

    input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
    assert chain_with_history.invoke(input_example) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_with_history, name="model_path", input_example=input_example
        )
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(input_example) == "Databricks"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    input_schema = pyfunc_loaded_model.metadata.get_input_schema()
    assert input_schema == Schema([
        ColSpec(
            Array(
                Object([
                    Property("role", DataType.string),
                    Property("content", DataType.string),
                ])
            ),
            "messages",
        )
    ])
    assert pyfunc_loaded_model.predict(input_example) == ["Databricks"]

    udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="string")
    df = spark.createDataFrame([(input_example["messages"],)], ["messages"])
    df = df.withColumn("answer", udf("messages"))
    pdf = df.toPandas()
    assert pdf["answer"].tolist() == ["Databricks"]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert json.loads(response.content.decode("utf-8")) == ["Databricks"]
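
# When the logged input example looks like an OpenAI chat payload
# ({"messages": [{"role": ..., "content": ...}]}), the pyfunc wrapper converts
# those dicts into LangChain messages on the way in and wraps string /
# AIMessage outputs into chat.completion dicts on the way out. The tests below
# pin that behavior, including the fallback cases where conversion is skipped.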


class ChatModel(SimpleChatModel):
    def _call(self, messages, stop, run_manager, **kwargs):
        return "\n".join([f"{message.type}: {message.content}" for message in messages])

    @property
    def _llm_type(self) -> str:
        return "chat model"


@skip_if_v1
def test_predict_with_builtin_pyfunc_chat_conversion(spark):
    # TODO: Migrate to models-from-code
    input_example = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "assistant", "content": "What would you like to ask?"},
            {"role": "user", "content": "Who owns MLflow?"},
        ]
    }
    content = (
        "system: You are a helpful assistant.\n"
        "ai: What would you like to ask?\n"
        "human: Who owns MLflow?"
    )

    chain = ChatModel() | StrOutputParser()
    assert chain.invoke([HumanMessage(content="Who owns MLflow?")]) == "human: Who owns MLflow?"
    with pytest.raises(ValueError, match="Invalid input type"):
        chain.invoke(input_example)

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example=input_example
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert (
        loaded_model.invoke([HumanMessage(content="Who owns MLflow?")]) == "human: Who owns MLflow?"
    )

    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    expected_chat_response = {
        "id": None,
        "object": "chat.completion",
        "created": 1677858242,
        "model": "",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                },
                "finish_reason": None,
            }
        ],
        "usage": {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
    }

    with mock.patch("time.time", return_value=1677858242):
        result1 = pyfunc_loaded_model.predict(input_example)
        result1[0]["id"] = None
        assert result1 == [expected_chat_response]
        result2 = pyfunc_loaded_model.predict([input_example, input_example])
        result2[0]["id"] = None
        result2[1]["id"] = None
        assert result2 == [
            expected_chat_response,
            expected_chat_response,
        ]

    with pytest.raises(MlflowException, match="Unrecognized chat message role"):
        pyfunc_loaded_model.predict({"messages": [{"role": "foobar", "content": "test content"}]})


@skip_if_v1
def test_predict_with_builtin_pyfunc_chat_conversion_for_aimessage_response():
    class ChatModel(SimpleChatModel):
        def _call(self, messages, stop, run_manager, **kwargs):
            return "You own MLflow"

        @property
        def _llm_type(self) -> str:
            return "chat model"

    input_example = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "assistant", "content": "What would you like to ask?"},
            {"role": "user", "content": "Who owns MLflow?"},
        ]
    }

    chain = ChatModel()
    result = chain.invoke([HumanMessage(content="Who owns MLflow?")])
    assert isinstance(result, AIMessage)
    assert result.content == "You own MLflow"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example=input_example
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    result = loaded_model.invoke([HumanMessage(content="Who owns MLflow?")])
    assert isinstance(result, AIMessage)
    assert result.content == "You own MLflow"

    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with mock.patch("time.time", return_value=1677858242):
        result = pyfunc_loaded_model.predict(input_example)
        assert "id" in result[0], "Response message id is lost."
        result[0]["id"] = None
        assert result == [
            {
                "id": None,
                "object": "chat.completion",
                "created": 1677858242,
                "model": "",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "You own MLflow",
                        },
                        "finish_reason": None,
                    }
                ],
                "usage": {
                    "prompt_tokens": None,
                    "completion_tokens": None,
                    "total_tokens": None,
                },
            }
        ]


@skip_if_v1
def test_pyfunc_builtin_chat_request_conversion_fails_gracefully():
    chain = RunnablePassthrough() | itemgetter("messages")
    # Ensure we're testing that "messages" remains intact & unchanged even if it
    # doesn't appear explicitly in the chain's input schema
    assert "messages" not in chain.input_schema().model_fields

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(chain, name="model_path")
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    assert pyfunc_loaded_model.predict({"messages": "not an array"}) == "not an array"

    # Verify that messages aren't converted to LangChain format if extra keys are present,
    # under the assumption that additional keys can't be specified when calling LangChain
    # invoke() / batch() with chat messages
    assert pyfunc_loaded_model.predict({
        "messages": [{"role": "user", "content": "blah"}],
        "extrakey": "extra",
    }) == [
        {"role": "user", "content": "blah"},
    ]

    # Verify that messages aren't converted to LangChain format if role / content are missing
    # or extra keys are present in the message
    assert pyfunc_loaded_model.predict({
        "messages": [{"content": "blah"}],
    }) == [
        {"content": "blah"},
    ]
    assert pyfunc_loaded_model.predict({
        "messages": [{"role": "user", "content": "blah"}, {}],
    }) == [
        {"role": "user", "content": "blah"},
        {},
    ]
    assert pyfunc_loaded_model.predict({
        "messages": [{"role": "user", "content": 123}],
    }) == [
        {"role": "user", "content": 123},
    ]

    # Verify behavior for batches of message histories
    assert pyfunc_loaded_model.predict([
        {
            "messages": "not an array",
        },
        {
            "messages": [{"role": "user", "content": "content"}],
        },
    ]) == [
        "not an array",
        [{"role": "user", "content": "content"}],
    ]
    assert pyfunc_loaded_model.predict([
        {
            "messages": [{"role": "user", "content": "content"}],
        },
        {"messages": [{"role": "user", "content": "content"}], "extrakey": "extra"},
    ]) == [
        [{"role": "user", "content": "content"}],
        [{"role": "user", "content": "content"}],
    ]
    assert pyfunc_loaded_model.predict([
        {
            "messages": [{"role": "user", "content": "content"}],
        },
        {
            "messages": [
                {"role": "user", "content": "content"},
                {"role": "user", "content": 123},
            ],
        },
    ]) == [
        [{"role": "user", "content": "content"}],
        [{"role": "user", "content": "content"}, {"role": "user", "content": 123}],
    ]


@skip_if_v1
def test_save_load_chain_that_relies_on_pickle_serialization(monkeypatch, model_path):
    from langchain_community.llms.databricks import Databricks

    monkeypatch.setattr(
        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
        mock.MagicMock(),
    )
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")

    llm_kwargs = {"endpoint_name": "test-endpoint", "temperature": 0.9}
    if IS_PICKLE_SERIALIZATION_RESTRICTED:
        llm_kwargs["allow_dangerous_deserialization"] = True

    llm = Databricks(**llm_kwargs)
    prompt = PromptTemplate(input_variables=["question"], template="I have a question: {question}")
    chain = prompt | llm | StrOutputParser()

    # Not passing an input_example to avoid triggering prediction
    mlflow.langchain.save_model(chain, model_path)

    loaded_model = mlflow.langchain.load_model(model_path)

    # Check if the deserialized model has the same endpoint and temperature
    loaded_databricks_llm = loaded_model.middle[0]
    assert loaded_databricks_llm.endpoint_name == "test-endpoint"
    assert loaded_databricks_llm.temperature == 0.9
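
# The chain-as-code tests below log a path to a Python file instead of a chain
# object. At load time MLflow executes that file, which is expected to build
# the chain and register it, roughly like this (a sketch; see
# tests/langchain/sample_code/chain.py for the real definition):
#
#     model_config = mlflow.models.ModelConfig(development_config="config.yml")
#     chain = prompt | llm | StrOutputParser()
#     mlflow.models.set_model(chain)
#
# The model_config argument to log_model (a YAML path or a dict) is what that
# file reads back through mlflow.models.ModelConfig.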


def _get_message_content(predictions):
    return predictions[0]["choices"][0]["message"]["content"]


@pytest.mark.parametrize(
    ("chain_path", "model_config"),
    [
        (
            os.path.abspath("tests/langchain/sample_code/chain.py"),
            os.path.abspath("tests/langchain/sample_code/config.yml"),
        ),
        (
            "tests/langchain/../langchain/sample_code/chain.py",
            "tests/langchain/../langchain/sample_code/config.yml",
        ),
    ],
)
def test_save_load_chain_as_code(chain_path, model_config, monkeypatch):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    artifact_path = "model_path"
    with mlflow.start_run() as run:
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
            model_config=model_config,
        )

    client = mlflow.tracking.MlflowClient()
    run_id = run.info.run_id
    assert client.get_run(run_id).data.params == {
        "llm_prompt_template": "Answer the following question based on "
        "the context: {context}\nQuestion: {question}",
        "embedding_size": "5",
        "not_used_array": "[1, 2, 3]",
        "response": "Databricks",
    }

    assert mlflow.models.model_config.__mlflow_model_config__ is None
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)

    # During the loading process, MLflow executes the chain.py file to
    # load the model class. It should not generate any traces even if
    # the code enables autologging and invokes the chain.
    assert len(get_traces()) == 0

    assert mlflow.models.model_config.__mlflow_model_config__ is None
    answer = "Databricks"
    assert loaded_model.invoke(input_example) == answer
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    predictions = json.loads(response.content.decode("utf-8"))
    # Mock out the `created` timestamp as it is not deterministic
    expected = [{**try_transform_response_to_chat_format(answer), "created": mock.ANY}]
    assert expected == predictions

    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources["databricks"] == {
        "serving_endpoint": [{"name": "fake-endpoint"}]
    }
    assert reloaded_model.metadata["dependencies_schemas"] == {
        DependenciesSchemasType.RETRIEVERS.value: [
            {
                "doc_uri": "doc-uri",
                "name": "retriever",
                "other_columns": ["column1", "column2"],
                "primary_key": "primary-key",
                "text_column": "text-column",
            }
        ]
    }

    # Emulate the model serving environment
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true")
    mlflow.tracing.reset()

    request_id = "mock_request_id"
    tracer = MlflowLangchainTracer(prediction_context=Context(request_id))
    input_example = {"messages": [{"role": "user", "content": json.dumps(TEST_CONTENT)}]}
    response = pyfunc_loaded_model._model_impl._predict_with_callbacks(
        data=input_example, callback_handlers=[tracer]
    )
    assert response["choices"][0]["message"]["content"] == "Databricks"
    trace = pop_trace(request_id)
    assert trace["info"]["tags"][DependenciesSchemasType.RETRIEVERS.value] == json.dumps([
        {
            "doc_uri": "doc-uri",
            "name": "retriever",
            "other_columns": ["column1", "column2"],
            "primary_key": "primary-key",
            "text_column": "text-column",
        }
    ])


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/chain.py"),
        "tests/langchain/../langchain/sample_code/chain.py",
    ],
)
def test_save_load_chain_as_code_model_config_dict(chain_path):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name="model_path",
            input_example=input_example,
            model_config={
                "response": "modified response",
                "embedding_size": 5,
                "llm_prompt_template": "answer the question",
            },
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    answer = "modified response"
    assert loaded_model.invoke(input_example) == answer
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))


@pytest.mark.parametrize(
    "model_config",
    [
        os.path.abspath("tests/langchain/sample_code/config.yml"),
        "tests/langchain/../langchain/sample_code/config.yml",
    ],
)
def test_save_load_chain_as_code_with_different_names(tmp_path, model_config):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }

    # Read the contents of the original chain file
    with open("tests/langchain/sample_code/chain.py") as chain_file:
        chain_file_content = chain_file.read()

    temp_file = tmp_path / "model.py"
    temp_file.write_text(chain_file_content)

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            str(temp_file),
            name="model_path",
            input_example=input_example,
            model_config=model_config,
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    answer = "Databricks"
    assert loaded_model.invoke(input_example) == answer
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/chain.py"),
        "tests/langchain/../langchain/sample_code/chain.py",
    ],
)
@pytest.mark.parametrize(
    "model_config",
    [
        os.path.abspath("tests/langchain/sample_code/config.yml"),
        "tests/langchain/../langchain/sample_code/config.yml",
    ],
)
def test_save_load_chain_as_code_multiple_times(tmp_path, chain_path, model_config):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name="model_path",
            input_example=input_example,
            model_config=model_config,
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    with open(model_config) as f:
        base_config = yaml.safe_load(f)

    assert loaded_model.middle[0].messages[0].prompt.template == base_config["llm_prompt_template"]

    file_name = "config_updated.yml"
    new_config_file = str(tmp_path.joinpath(file_name))

    new_config = base_config.copy()
    new_config["llm_prompt_template"] = "new_template"
    with open(new_config_file, "w") as f:
        yaml.dump(new_config, f)

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name="model_path",
            input_example=input_example,
            model_config=new_config_file,
        )

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.middle[0].messages[0].prompt.template == new_config["llm_prompt_template"]


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/chain.py"),
        "tests/langchain/../langchain/sample_code/chain.py",
    ],
)
def test_save_load_chain_as_code_with_model_paths(chain_path):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    artifact_path = "model_path"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.langchain.model._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
            code_paths=[__file__],
            model_config={
                "response": "modified response",
                "embedding_size": 5,
                "llm_prompt_template": "answer the question",
            },
        )
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        answer = "modified response"
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.langchain.FLAVOR_NAME)
        assert loaded_model.invoke(input_example) == answer
        add_mock.assert_called()


@pytest.mark.parametrize("chain_path", [os.path.abspath("tests/langchain1/sample_code/chain.py")])
def test_save_load_chain_errors(chain_path):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    with mlflow.start_run():
        with pytest.raises(
            MlflowException,
            match=f"The provided model path '{chain_path}' does not exist. "
            "Ensure the file path is valid and try again.",
        ):
            mlflow.langchain.log_model(
                chain_path,
                name="model_path",
                input_example=input_example,
                model_config="tests/langchain/state_of_the_union.txt",
            )


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/no_config/chain.py"),
        "tests/langchain/../langchain/sample_code/no_config/chain.py",
    ],
)
def test_save_load_chain_as_code_optional_code_path(chain_path):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    artifact_path = "new_model_path"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
        )

    assert mlflow.models.model_config.__mlflow_model_config__ is None
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert mlflow.models.model_config.__mlflow_model_config__ is None
    answer = "Databricks"
    assert loaded_model.invoke(input_example) == answer
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert (
        pyfunc_loaded_model
        .predict(input_example)[0]
        .get("choices")[0]
        .get("message")
        .get("content")
        == answer
    )

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    # avoid minor diff of created time in the response
    prediction_result = json.loads(response.content.decode("utf-8"))
    prediction_result[0]["created"] = 123
    expected_prediction = try_transform_response_to_chat_format(answer)
    expected_prediction["created"] = 123
    assert prediction_result == [expected_prediction]

    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources["databricks"] == {
        "serving_endpoint": [{"name": "fake-endpoint"}]
    }
    assert reloaded_model.metadata is None
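
# Streaming: predict_stream surfaces each ChatGenerationChunk produced by the
# model's _stream() as an OpenAI-style "chat.completion.chunk" dict carrying
# the delta content and finish_reason; the tests below pin that chunk format.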


@pytest.fixture
def fake_chat_stream_model():
    class FakeChatStreamModel(SimpleChatModel):
        """Fake Chat Stream Model wrapper for testing purposes."""

        endpoint_name: str = "fake-stream-endpoint"

        def _call(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> str:
            return "Databricks"

        def _stream(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> Iterator[ChatGenerationChunk]:
            for chunk_content, finish_reason in [
                ("Da", None),
                ("tab", None),
                ("ricks", "stop"),
            ]:
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=chunk_content),
                    generation_info={"finish_reason": finish_reason},
                )
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)

                yield chunk

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    return FakeChatStreamModel(endpoint_name="fake-stream-endpoint")


@skip_if_v1
@pytest.mark.parametrize("provide_signature", [True, False])
def test_simple_chat_model_stream_inference(fake_chat_stream_model, provide_signature):
    # TODO: Migrate to models-from-code
    input_example = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "assistant", "content": "What would you like to ask?"},
            {"role": "user", "content": "Who owns MLflow?"},
        ]
    }
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            fake_chat_stream_model,
            name="model",
        )

    if provide_signature:
        signature = infer_signature(model_input=input_example)
        with mlflow.start_run():
            model_with_signature_info = mlflow.langchain.log_model(
                fake_chat_stream_model, name="model", signature=signature
            )
    else:
        with mlflow.start_run():
            model_with_signature_info = mlflow.langchain.log_model(
                fake_chat_stream_model, name="model", input_example=input_example
            )

    for model_uri in [model_info.model_uri, model_with_signature_info.model_uri]:
        loaded_model = mlflow.pyfunc.load_model(model_uri)

        chunk_iter = loaded_model.predict_stream(input_example)

        finish_reason = "stop"

        with mock.patch("time.time", return_value=1677858242):
            chunks = list(chunk_iter)

        for chunk in chunks:
            assert "id" in chunk, "chunk id is lost."
            chunk["id"] = None

        assert chunks == [
            {
                "id": None,
                "object": "chat.completion.chunk",
                "created": 1677858242,
                "model": "",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": None,
                        "delta": {"role": "assistant", "content": "Da"},
                    }
                ],
            },
            {
                "id": None,
                "object": "chat.completion.chunk",
                "created": 1677858242,
                "model": "",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": None,
                        "delta": {"role": "assistant", "content": "tab"},
                    }
                ],
            },
            {
                "id": None,
                "object": "chat.completion.chunk",
                "created": 1677858242,
                "model": "",
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": finish_reason,
                        "delta": {"role": "assistant", "content": "ricks"},
                    }
                ],
            },
        ]


@skip_if_v1
def test_simple_chat_model_stream_with_callbacks(fake_chat_stream_model):
    # TODO: Migrate to models-from-code
    class TestCallbackHandler(BaseCallbackHandler):
        def __init__(self):
            super().__init__()
            self.num_llm_start_calls = 0

        def on_llm_start(
            self,
            serialized: dict[str, Any],
            prompts: list[str],
            **kwargs: Any,
        ) -> Any:
            self.num_llm_start_calls += 1

    prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
    chain = prompt | fake_chat_stream_model | StrOutputParser()
    # Test the basic functionality of the chain
    assert chain.invoke({"industry": "tech"}) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example={"industry": "tech"}
        )

    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    callback_handler1 = TestCallbackHandler()
    callback_handler2 = TestCallbackHandler()

    # Ensure handlers have not been called yet
    assert callback_handler1.num_llm_start_calls == 0
    assert callback_handler2.num_llm_start_calls == 0

    stream = pyfunc_loaded_model._model_impl._predict_stream_with_callbacks(
        {"industry": "tech"},
        callback_handlers=[callback_handler1, callback_handler2],
    )
    assert list(stream) == ["Da", "tab", "ricks"]

    # Test that the callback handlers were called
    assert callback_handler1.num_llm_start_calls == 1
    assert callback_handler2.num_llm_start_calls == 1


@skip_if_v1
def test_langchain_model_save_load_with_listeners(fake_chat_model):
    # TODO: Migrate this to models-from-code
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant."),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    def retrieve_history(input):
        return {"history": [], "question": input["question"], "name": input["name"]}

    chain = (
        {"question": itemgetter("question"), "name": itemgetter("name")}
        | (RunnableLambda(retrieve_history) | prompt | fake_chat_model).with_listeners()
        | StrOutputParser()
        | RunnablePassthrough()
    )
    input_example = {"question": "Who owns MLflow?", "name": ""}
    assert chain.invoke(input_example) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example=input_example
        )
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(input_example) == "Databricks"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict(input_example) == ["Databricks"]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }
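
# In the Databricks model-serving environment traces are exported to the
# inference table buffer; when tracing is explicitly disabled via either env
# var below, the pyfunc wrapper must not inject its tracer callback at all.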


@skip_if_v1
def test_langchain_model_save_load_with_listeners(fake_chat_model):
    # TODO: Migrate this to models-from-code
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant."),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    def retrieve_history(input):
        return {"history": [], "question": input["question"], "name": input["name"]}

    chain = (
        {"question": itemgetter("question"), "name": itemgetter("name")}
        | (RunnableLambda(retrieve_history) | prompt | fake_chat_model).with_listeners()
        | StrOutputParser()
        | RunnablePassthrough()
    )
    input_example = {"question": "Who owns MLflow?", "name": ""}
    assert chain.invoke(input_example) == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain, name="model_path", input_example=input_example
        )
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke(input_example) == "Databricks"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict(input_example) == ["Databricks"]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }


@pytest.mark.parametrize("env_var", ["MLFLOW_ENABLE_TRACE_IN_SERVING", "ENABLE_MLFLOW_TRACING"])
def test_langchain_model_not_inject_callback_when_disabled(monkeypatch, model_path, env_var):
    # Emulate the model serving environment
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")

    # Disable tracing
    monkeypatch.setenv(env_var, "false")

    mlflow.langchain.save_model(SIMPLE_MODEL_CODE_PATH, model_path)

    loaded_model = mlflow.pyfunc.load_model(model_path)
    loaded_model.predict({"product": "shoe"})

    # No trace should be logged to the inference table because tracing is disabled
    from mlflow.tracing.export.inference_table import _TRACE_BUFFER

    assert _TRACE_BUFFER == {}


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/no_config/chain.py"),
        "tests/langchain/../langchain/sample_code/no_config/chain.py",
    ],
)
def test_save_model_as_code_correct_streamable(chain_path):
    input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
    answer = "Databricks"
    artifact_path = "model_path"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
        )

    assert model_info.flavors["langchain"]["streamable"] is True
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    with mock.patch("time.time", return_value=1677858242):
        assert pyfunc_loaded_model._model_impl._predict_with_callbacks(input_example) == {
            "id": None,
            "object": "chat.completion",
            "created": 1677858242,
            "model": "",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Databricks",
                    },
                    "finish_reason": None,
                }
            ],
            "usage": {
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None,
            },
        }

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    # Avoid a minor diff in the `created` timestamp of the response
    prediction_result = json.loads(response.content.decode("utf-8"))
    prediction_result[0]["created"] = 123
    expected_prediction = try_transform_response_to_chat_format(answer)
    expected_prediction["created"] = 123
    assert prediction_result == [expected_prediction]

    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources["databricks"] == {
        "serving_endpoint": [{"name": "fake-endpoint"}]
    }


@skip_if_v1
def test_save_load_langchain_binding(fake_chat_model):
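    # A RunnableBinding pairs a runnable with default call kwargs; constructing
    # it directly, as below, is equivalent to the more common
    # `fake_chat_model.bind(stop=["-"])` form (sketch for illustration).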
    runnable_binding = RunnableBinding(bound=fake_chat_model, kwargs={"stop": ["-"]})
    model = runnable_binding | StrOutputParser()
    assert model.invoke("Say something") == "Databricks"

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            model, name="model_path", input_example="Say something"
        )
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.first.kwargs == {"stop": ["-"]}
    assert loaded_model.invoke("hello") == "Databricks"
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict("hello") == ["Databricks"]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }


@skip_if_v1
def test_save_load_langchain_binding_llm_with_tool():
    from langchain_core.tools import tool

    # We need ChatOpenAI from langchain_openai, as the community version does
    # not support bind_tools
    from langchain_openai import ChatOpenAI

    @tool
    def add(a: int, b: int) -> int:
        """Adds a and b.

        Args:
            a: first int
            b: second int
        """
        return a + b

    runnable_binding = ChatOpenAI(temperature=0.9).bind_tools([add])
    model = runnable_binding | StrOutputParser()
    # The mock OpenAI endpoint echoes the request payload, so the chain output
    # is the serialized input messages
    expected_output = '[{"role": "user", "content": "hello"}]'
    assert model.invoke("hello") == expected_output

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(model, name="model_path", input_example="hello")

    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.invoke("hello") == expected_output
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict("hello") == [expected_output]


@skip_if_v1
def test_langchain_bindings_save_load_with_config_and_types(fake_chat_model):
    class CustomCallbackHandler(BaseCallbackHandler):
        def __init__(self):
            self.count = 0

        def on_chain_start(
            self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
        ) -> None:
            self.count += 1

        def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
            self.count += 1

    model = fake_chat_model | StrOutputParser()
    callback = CustomCallbackHandler()
    model = model.with_config(run_name="test_run", callbacks=[callback]).with_types(
        input_type=str, output_type=str
    )
    assert model.invoke("Say something") == "Databricks"
    assert callback.count == 4

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(model, name="model_path", input_example="hello")
    loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    assert loaded_model.config["run_name"] == "test_run"
    assert loaded_model.custom_input_type == str
    assert loaded_model.custom_output_type == str
    callback = loaded_model.config["callbacks"][0]
    assert loaded_model.invoke("hello") == "Databricks"
    # Accumulated count: the callbacks also fire during model logging
    assert callback.count > 8
    pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_loaded_model.predict("hello") == ["Databricks"]

    inference_payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
        "predictions": ["Databricks"]
    }


@pytest.mark.parametrize(
    "chain_path",
    [
        os.path.abspath("tests/langchain/sample_code/chain.py"),
        "tests/langchain/../langchain/sample_code/chain.py",
    ],
)
@pytest.mark.parametrize(
    "model_config",
    [
        os.path.abspath("tests/langchain/sample_code/config.yml"),
        "tests/langchain/../langchain/sample_code/config.yml",
    ],
)
def test_load_chain_with_model_config_overrides_saved_config(chain_path, model_config):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    artifact_path = "model_path"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
            model_config=model_config,
        )

    with mock.patch("mlflow.langchain.model._load_model_code_path") as load_model_code_path_mock:
        mlflow.pyfunc.load_model(model_info.model_uri, model_config={"embedding_size": 2})
        args, kwargs = load_model_code_path_mock.call_args
        assert args[1] == {
            "embedding_size": 2,
            "llm_prompt_template": "Answer the following question based on the "
            "context: {context}\nQuestion: {question}",
            "not_used_array": [
                1,
                2,
                3,
            ],
            "response": "Databricks",
        }


@skip_if_v1
@pytest.mark.parametrize("streamable", [True, False, None])
def test_langchain_model_streamable_param_in_log_model(streamable, fake_chat_model):
    # TODO: Migrate to models-from-code
    prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
    chain = prompt | fake_chat_model | StrOutputParser()

    runnable = RunnableParallel({"llm": lambda _: "completion"})

    for model in [chain, runnable]:
        with mock.patch("mlflow.langchain.model._save_model"), mlflow.start_run():
            model_info = mlflow.langchain.log_model(
                model,
                name="model",
                streamable=streamable,
                pip_requirements=[],
            )

        expected = (streamable is None) or streamable
        assert model_info.flavors["langchain"]["streamable"] is expected


@pytest.fixture
def model_type(request):
    return lc_runnables_types()[request.param]


@skip_if_v1
@pytest.mark.parametrize("streamable", [True, False, None])
@pytest.mark.parametrize("model_type", range(len(lc_runnables_types())), indirect=True)
def test_langchain_model_streamable_param_in_log_model_for_lc_runnable_types(
    streamable, model_type
):
    with mock.patch("mlflow.langchain.model._save_model"), mlflow.start_run():
        model = mock.MagicMock(spec=model_type)
        assert hasattr(model, "stream") is True
        model_info = mlflow.langchain.log_model(
            model,
            name="model",
            streamable=streamable,
            pip_requirements=[],
        )

        expected = (streamable is None) or streamable
        assert model_info.flavors["langchain"]["streamable"] is expected
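
        # When `streamable` is None, the flavor falls back to whether the model
        # object exposes a `stream` attribute: present above (the MagicMock specs
        # it), absent after the `del` below.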

        del model.stream
        assert hasattr(model, "stream") is False
        model_info = mlflow.langchain.log_model(
            model,
            name="model",
            streamable=streamable,
            pip_requirements=[],
        )
        assert model_info.flavors["langchain"]["streamable"] is bool(streamable)


@skip_if_v1
def test_agent_executor_model_with_messages_input():
    question = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            os.path.abspath("tests/langchain/agent_executor/chain.py"),
            name="model_path",
            input_example=question,
            model_config=os.path.abspath("tests/langchain/agent_executor/config.yml"),
        )
    native_model = mlflow.langchain.load_model(model_info.model_uri)
    assert native_model.invoke(question)["output"] == "Databricks"
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    # TODO: In the future we should fix this so the output is not wrapped.
    # The result is wrapped in a list because during signature enforcement we
    # convert the input data to a pandas DataFrame; inside _convert_llm_input_data
    # we then convert the DataFrame back to records, and a single row ends up
    # wrapped inside a list.
    assert pyfunc_model.predict(question) == ["Databricks"]
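
    # Illustrative sketch of the round trip described above (hypothetical
    # variable names, not the actual enforcement code):
    #   df = pd.DataFrame([raw_input])           # signature enforcement
    #   records = df.to_dict(orient="records")   # inside _convert_llm_input_data
    #   records == [raw_input]                   # one row -> one-element list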

    # Test stream output
    response = pyfunc_model.predict_stream(question)
    assert inspect.isgenerator(response)

    expected_response = [
        {
            "output": "Databricks",
            "messages": [
                {
                    "additional_kwargs": {},
                    "content": "Databricks",
                    "example": False,
                    "id": None,
                    "invalid_tool_calls": [],
                    "name": None,
                    "response_metadata": {},
                    "tool_calls": [],
                    "type": "ai",
                    "usage_metadata": None,
                }
            ],
        }
    ]
    assert list(response) == expected_response


def test_invoking_model_with_params():
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            os.path.abspath("tests/langchain/sample_code/model_with_config.py"),
            name="model",
        )
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    data = {"x": 0}
    pyfunc_model.predict(data)
    params = {"config": {"temperature": 3.0}}
    with mock.patch("mlflow.pyfunc._validate_prediction_input", return_value=(data, params)):
        # The model rejects temperature > 2, so this error proves the
        # temperature param reached the model
        with pytest.raises(MlflowException, match=r"Input should be less than or equal to 2"):
            pyfunc_model.predict(data=data, params=params)


def test_custom_resources(tmp_path):
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "What is a good name for a company that makes MLflow?",
            }
        ]
    }
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b"},
            ],
        },
    }
    artifact_path = "model_path"
    chain_path = "tests/langchain/sample_code/chain.py"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path,
            input_example=input_example,
            model_config="tests/langchain/sample_code/config.yml",
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
                DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
                DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
                DatabricksSQLWarehouse(warehouse_id="testid"),
                DatabricksFunction(function_name="rag.studio.test_function_a"),
                DatabricksFunction(function_name="rag.studio.test_function_b"),
            ],
        )

    model_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded_model = Model.load(os.path.join(model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources

    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
              vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
              serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
              sql_warehouse:
                - name: testid
              function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
            """
        )

    artifact_path_2 = "model_path_2"
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            chain_path,
            name=artifact_path_2,
            input_example=input_example,
            model_config="tests/langchain/sample_code/config.yml",
            resources=yaml_file,
        )

    model_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded_model = Model.load(os.path.join(model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources


def test_pyfunc_converts_chat_request_for_non_chat_model():
    input_example = {"messages": [{"role": "user", "content": "Hello"}]}
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            lc_model=SIMPLE_MODEL_CODE_PATH,
            input_example=input_example,
        )

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = pyfunc_model.predict(input_example)
    # Outputs are converted to the ChatResponse format
    assert isinstance(result[0]["choices"][0]["message"]["content"], str)

    response = pyfunc_model.predict_stream(input_example)
    assert inspect.isgenerator(response)
    assert isinstance(list(response)[0]["choices"][0]["delta"]["content"], str)
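

# A rough sketch of the conversion behind the test above and the one below:
# when MLFLOW_CONVERT_MESSAGES_DICT_FOR_LANGCHAIN is enabled (the default),
# a chat-request dict is turned into LangChain messages before the chain is
# invoked, roughly:
#   {"messages": [{"role": "user", "content": "Hello"}]}
#       -> [HumanMessage(content="Hello")]
# Setting the variable to "false" skips this conversion.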


@skip_if_v1
def test_pyfunc_should_not_convert_chat_request_if_env_var_is_set_to_false(monkeypatch):
    monkeypatch.setenv(MLFLOW_CONVERT_MESSAGES_DICT_FOR_LANGCHAIN.name, "false")

    # This model is an example where the model expects chat-request-format
    # input, but the input should not be converted to List[BaseMessage]
    model = RunnablePassthrough.assign(problem=lambda x: x["messages"][-1]["content"]) | itemgetter(
        "problem"
    )
    input_example = {"messages": [{"role": "user", "content": "Databricks"}]}
    assert model.invoke(input_example) == "Databricks"

    # The pyfunc model can accept the chat request format even when the chain
    # itself does not, but we need the correct input example to infer the
    # model signature
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(model, input_example=input_example)

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    result = pyfunc_model.predict(input_example)
    assert result == ["Databricks"]

    # Test stream output; materialize once, since the generator is exhausted
    # after the first pass
    response = pyfunc_model.predict_stream(input_example)
    assert inspect.isgenerator(response)
    chunks = list(response)
    assert chunks == ["Databricks"], chunks


def test_log_langchain_model_with_prompt():
    mlflow.register_prompt(
        name="qa_prompt",
        template="What is a good name for a company that makes {{product}}?",
        commit_message="Prompt for generating company names",
    )
    mlflow.set_prompt_alias("qa_prompt", alias="production", version=1)

    mlflow.register_prompt(name="another_prompt", template="Hi")

    # If the model code involves an `mlflow.load_prompt()` call, the prompt
    # version should be automatically logged to the Run
    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            os.path.abspath("tests/langchain/sample_code/chain_with_mlflow_prompt.py"),
            name="model",
            # Manually associate another prompt
            prompts=["prompts:/another_prompt/1"],
        )

    # Check that the prompts were linked to the run via the linkedPrompts tag
    from mlflow.tracing.constant import TraceTagKey

    run = mlflow.MlflowClient().get_run(model_info.run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert linked_prompts_tag is not None

    linked_prompts = json.loads(linked_prompts_tag)
    assert len(linked_prompts) == 2
    assert {p["name"] for p in linked_prompts} == {"qa_prompt", "another_prompt"}

    prompt = mlflow.load_prompt("qa_prompt", 1)
    assert prompt.aliases == ["production"]

    prompt = mlflow.load_prompt("another_prompt", 1)

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    response = pyfunc_model.predict({"product": "shoe"})
    # The fake OpenAI server echoes the input
    assert (
        response
        == '[{"role": "user", "content": "What is a good name for a company that makes shoe?"}]'
    )


def test_predict_with_callbacks_with_tracing(monkeypatch):
    # Simulate the model serving environment
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true")
    mlflow.tracing.reset()

    model_info = mlflow.langchain.log_model(
        os.path.abspath("tests/langchain/sample_code/workflow.py"),
        name="model_path",
        input_example={"messages": [{"role": "user", "content": "What is MLflow?"}]},
    )
    # The serving environment only reads the experiment ID from this environment variable
    monkeypatch.setenv("MLFLOW_EXPERIMENT_ID", mlflow.last_logged_model().experiment_id)

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

    request_id = "mock_request_id"
    tracer = MlflowLangchainTracer(prediction_context=Context(request_id))
    input_example = {"messages": [{"role": "user", "content": TEST_CONTENT}]}

    with mock.patch("mlflow.tracing.client.TracingClient.start_trace") as mock_start_trace:
        pyfunc_model._model_impl._predict_with_callbacks(
            data=input_example, callback_handlers=[tracer]
        )
        mlflow.flush_trace_async_logging()
        mock_start_trace.assert_called_once()
        trace_info = mock_start_trace.call_args[0][0]
        assert trace_info.client_request_id == request_id
        assert trace_info.request_metadata[TraceMetadataKey.MODEL_ID] == model_info.model_id


@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="This test is only for langchain>=1 versions")
def test_langchain_v1_save_model_as_pickle_error():
    model = create_openai_runnable()
    with mlflow.start_run():
        with pytest.raises(
            MlflowException,
            match="LangChain v1 onward only supports models-from-code",
        ):
            mlflow.langchain.log_model(
                model, name="langchain_model", input_example={"product": "MLflow"}
            )