tests/langchain/test_langchain_model_export.py
   1  import inspect
   2  import json
   3  import os
   4  import shutil
   5  from operator import itemgetter
   6  from typing import Any, Iterator
   7  from unittest import mock
   8  
   9  import langchain
  10  import pytest
  11  import yaml
  12  from langchain_community.document_loaders import TextLoader
  13  from langchain_community.embeddings.fake import FakeEmbeddings
  14  from langchain_community.llms import OpenAI
  15  from langchain_community.utilities import TextRequestsWrapper
  16  from langchain_community.vectorstores import FAISS
  17  from langchain_core.callbacks.base import BaseCallbackHandler
  18  from langchain_core.callbacks.manager import CallbackManagerForLLMRun
  19  from langchain_core.language_models import SimpleChatModel
  20  from langchain_core.messages import (
  21      AIMessage,
  22      AIMessageChunk,
  23      BaseMessage,
  24      HumanMessage,
  25  )
  26  from langchain_core.output_parsers import StrOutputParser
  27  from langchain_core.outputs import ChatGenerationChunk
  28  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
  29  from langchain_core.runnables import (
  30      RunnableBinding,
  31      RunnableBranch,
  32      RunnableLambda,
  33      RunnableParallel,
  34      RunnablePassthrough,
  35      RunnableSequence,
  36  )
  37  from langchain_core.tools import Tool
  38  from langchain_openai import ChatOpenAI
  39  from langchain_text_splitters.character import CharacterTextSplitter
  40  from packaging import version
  41  from pydantic import BaseModel
  42  from pyspark.sql import SparkSession
  43  
  44  import mlflow
  45  import mlflow.models.model
  46  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
  47  from mlflow.deployments import PredictionsResponse
  48  from mlflow.environment_variables import MLFLOW_CONVERT_MESSAGES_DICT_FOR_LANGCHAIN
  49  from mlflow.exceptions import MlflowException
  50  from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
  51  from mlflow.langchain.utils.chat import (
  52      try_transform_response_to_chat_format,
  53  )
  54  from mlflow.langchain.utils.logging import (
  55      IS_PICKLE_SERIALIZATION_RESTRICTED,
  56      lc_runnables_types,
  57  )
  58  from mlflow.models import Model
  59  from mlflow.models.dependencies_schemas import DependenciesSchemasType
  60  from mlflow.models.resources import (
  61      DatabricksFunction,
  62      DatabricksServingEndpoint,
  63      DatabricksSQLWarehouse,
  64      DatabricksVectorSearchIndex,
  65  )
  66  from mlflow.models.signature import Schema, infer_signature
  67  from mlflow.models.utils import load_serving_example
  68  from mlflow.pyfunc.context import Context
  69  from mlflow.tracing.constant import TraceMetadataKey
  70  from mlflow.tracing.export.inference_table import pop_trace
  71  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  72  from mlflow.types.schema import Array, ColSpec, DataType, Object, Property
  73  
  74  from tests.helper_functions import _compare_logged_code_paths, pyfunc_serve_and_score_model
  75  from tests.langchain.conftest import DeterministicDummyEmbeddings
  76  from tests.tracing.helper import get_traces
  77  
  78  # The allow_dangerous_deserialization kwarg was added in langchain_community 0.0.27;
  79  # without it, loading pickled objects (such as a FAISS index) is blocked.
  80  VECTORSTORE_KWARGS = (
  81      {"allow_dangerous_deserialization": True} if IS_PICKLE_SERIALIZATION_RESTRICTED else {}
  82  )
  83  
  84  IS_LANGCHAIN_03 = version.parse(langchain.__version__) >= version.parse("0.3.0")
  85  IS_LANGCHAIN_v1 = version.parse(langchain.__version__).major >= 1
  86  LANGCHAIN_V1_SKIP_REASON = "Pickle serialization is not supported for LangChain v1"
  87  
  88  # Reusable decorator for skipping tests on LangChain v1
  89  skip_if_v1 = pytest.mark.skipif(IS_LANGCHAIN_v1, reason=LANGCHAIN_V1_SKIP_REASON)
  90  
  91  # The mock OpenAI completions endpoint echoes the request payload back as-is
  92  TEST_CONTENT = [{"role": "user", "content": "What is MLflow?"}]
  93  
  94  SIMPLE_MODEL_CODE_PATH = "tests/langchain/sample_code/simple_runnable.py"
  95  
  96  
  97  @pytest.fixture
  98  def model_path(tmp_path):
  99      return tmp_path / "model"
 100  
 101  
 102  @pytest.fixture(scope="module")
 103  def spark():
 104      with SparkSession.builder.master("local[*]").getOrCreate() as s:
 105          yield s
 106  
 107  
 108  def create_openai_runnable():
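          """Build a simple prompt | ChatOpenAI | StrOutputParser chain used across these tests."""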
 109      from langchain_core.output_parsers import StrOutputParser
 110  
 111      prompt = PromptTemplate(
 112          input_variables=["product"],
 113          template="What is {product}?",
 114      )
 115      return prompt | ChatOpenAI(temperature=0.9) | StrOutputParser()
 116  
 117  
 118  @pytest.fixture
 119  def fake_chat_model():
 120      class FakeChatModel(SimpleChatModel):
 121          """Fake Chat Model wrapper for testing purposes."""
 122  
 123          endpoint_name: str = "fake-endpoint"
 124  
 125          def _call(
 126              self,
 127              messages: list[BaseMessage],
 128              stop: list[str] | None = None,
 129              run_manager: CallbackManagerForLLMRun | None = None,
 130              **kwargs: Any,
 131          ) -> str:
 132              return "Databricks"
 133  
 134          @property
 135          def _llm_type(self) -> str:
 136              return "fake chat model"
 137  
 138      return FakeChatModel(endpoint_name="fake-endpoint")
 139  
 140  
 141  @pytest.fixture
 142  def fake_classifier_chat_model():
 143      class FakeMlflowClassifier(SimpleChatModel):
 144          """Fake Chat Model wrapper for testing purposes."""
 145  
 146          def _call(
 147              self,
 148              messages: list[BaseMessage],
 149              stop: list[str] | None = None,
 150              run_manager: CallbackManagerForLLMRun | None = None,
 151              **kwargs: Any,
 152          ) -> str:
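                  # Classify based on the text after the prompt's colon: "yes" for MLflow, "no" for cats.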
 153              if "MLflow" in messages[0].content.split(":")[1]:
 154                  return "yes"
 155              if "cat" in messages[0].content.split(":")[1]:
 156                  return "no"
 157              return "unknown"
 158  
 159          @property
 160          def _llm_type(self) -> str:
 161              return "fake mlflow classifier"
 162  
 163      return FakeMlflowClassifier()
 164  
 165  
 166  @skip_if_v1
 167  def test_langchain_native_log_and_load_model():
 168      model = create_openai_runnable()
 169  
 170      with mlflow.start_run():
 171          logged_model = mlflow.langchain.log_model(
 172              model, name="langchain_model", input_example={"product": "MLflow"}
 173          )
 174  
 175      loaded_model = mlflow.langchain.load_model(logged_model.model_uri)
 176  
 177      assert "langchain" in logged_model.flavors
 178      assert str(logged_model.signature.inputs) == "['product': string (required)]"
 179      assert str(logged_model.signature.outputs) == "[string (required)]"
 180  
 181      assert type(loaded_model) == RunnableSequence
 182      assert loaded_model.steps[0].template == "What is {product}?"
 183      assert type(loaded_model.steps[1]).__name__ == "ChatOpenAI"
 184  
 185      # Predict
 186      loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
 187      result = loaded_model.predict([{"product": "MLflow"}])
 188      assert result == [json.dumps(TEST_CONTENT)]
 189  
 190      # Predict stream
 191      result = loaded_model.predict_stream([{"product": "MLflow"}])
 192      assert inspect.isgenerator(result)
 193      assert list(result) == ["Hello", " world"]
 194  
 195  
 196  @skip_if_v1
 197  def test_pyfunc_spark_udf_with_langchain_model(spark):
 198      model = create_openai_runnable()
 199      with mlflow.start_run():
 200          logged_model = mlflow.langchain.log_model(
 201              model, name="langchain_model", input_example={"product": "MLflow"}
 202          )
 203      loaded_model = mlflow.pyfunc.spark_udf(spark, logged_model.model_uri, result_type="string")
 204      df = spark.createDataFrame([("MLflow",), ("Spark",)], ["product"])
 205      df = df.withColumn("answer", loaded_model())
 206      pdf = df.toPandas()
 207      assert pdf["answer"].tolist() == [
 208          '[{"role": "user", "content": "What is MLflow?"}]',
 209          '[{"role": "user", "content": "What is Spark?"}]',
 210      ]
 211  
 212  
 213  @pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in LangChain v0")
 214  def test_langchain_agent_model_predict(monkeypatch):
 215      input_example = {"input": "What is 2 * 3?"}
 216  
 217      with mlflow.start_run():
 218          logged_model = mlflow.langchain.log_model(
 219              # The OpenAI client (since 1.0) contains a thread-lock object that cannot
 220              # be pickled, so the AgentExecutor cannot be saved with the legacy
 221              # object-based logging; we use models-from-code logging instead.
 222              "tests/langchain/sample_code/openai_agent.py",
 223              name="langchain_model",
 224              input_example=input_example,
 225          )
 226  
 227      loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
 228  
 229      # Basic prediction
 230      response = loaded_model.predict([input_example])
 231      expected_output = "The result of 2 * 3 is 6."
 232  
 233      assert response[0]["messages"][-1]["content"] == expected_output
 234  
 235      # Stream prediction
 236      response = loaded_model.predict_stream([input_example])
 237      assert inspect.isgenerator(response)
 238      assert list(response) == [
 239          {"model": {"messages": [mock.ANY]}},
 240          {"tools": {"messages": [mock.ANY]}},
 241          {"model": {"messages": [mock.ANY]}},
 242      ]
 243  
 244      # Model serving
 245      inference_payload = load_serving_example(logged_model.model_uri)
 246      response = pyfunc_serve_and_score_model(
 247          logged_model.model_uri,
 248          data=inference_payload,
 249          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 250          extra_args=["--env-manager", "local"],
 251      )
 252      # TODO: The response is not wrapped in the "predictions" key. This is a bug in
 253      # output handling. The user input often contains an "input" key because it is
 254      # used in popular agent prompts from the hub, which causes the scoring server
 255      # to treat the request as an llm/v1/completions request.
 256      response = json.loads(response.content.decode("utf-8"))
 257      assert response[0]["messages"][-1]["content"] == expected_output
 258  
 259  
 260  def assert_equal_retrievers(retriever, expected_retriever):
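          """Assert that two retrievers match in type, tags, metadata, and search configuration."""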
 261      from langchain.schema.retriever import BaseRetriever
 262  
 263      assert isinstance(retriever, BaseRetriever)
 264      assert isinstance(retriever, type(expected_retriever))
 265      assert isinstance(retriever.vectorstore, type(expected_retriever.vectorstore))
 266      assert retriever.tags == expected_retriever.tags
 267      assert retriever.metadata == expected_retriever.metadata
 268      assert retriever.search_type == expected_retriever.search_type
 269      assert retriever.search_kwargs == expected_retriever.search_kwargs
 270  
 271  
 272  @skip_if_v1
 273  def test_log_and_load_retriever_chain(tmp_path):
 274      # Create the vector db, persist the db to a local fs folder
 275      loader = TextLoader("tests/langchain/state_of_the_union.txt")
 276      documents = loader.load()
 277      text_splitter = CharacterTextSplitter(chunk_size=256, chunk_overlap=0)
 278      docs = text_splitter.split_documents(documents)
 279      embeddings = DeterministicDummyEmbeddings(size=5)
 280      db = FAISS.from_documents(docs, embeddings)
 281      persist_dir = str(tmp_path / "faiss_index")
 282      db.save_local(persist_dir)
 283  
 284      # Define the loader_fn
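      # loader_fn is serialized with the logged model, so it defines everything it needs locally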
 285      def load_retriever(persist_directory):
 286          import numpy as np
 287          from langchain.embeddings.base import Embeddings
 288  
 289          class DeterministicDummyEmbeddings(Embeddings, BaseModel):
 290              size: int
 291  
 292              def _get_embedding(self, text: str) -> list[float]:
 293                  if isinstance(text, np.ndarray):
 294                      text = text.item()
 295                  seed = abs(hash(text)) % (10**8)
 296                  np.random.seed(seed)
 297                  return list(np.random.normal(size=self.size))
 298  
 299              def embed_documents(self, texts: list[str]) -> list[list[float]]:
 300                  return [self._get_embedding(t) for t in texts]
 301  
 302              def embed_query(self, text: str) -> list[float]:
 303                  return self._get_embedding(text)
 304  
 305          embeddings = DeterministicDummyEmbeddings(size=5)
 306          vectorstore = FAISS.load_local(
 307              persist_directory,
 308              embeddings,
 309              **VECTORSTORE_KWARGS,
 310          )
 311          return vectorstore.as_retriever()
 312  
 313      query = "What did the president say about Ketanji Brown Jackson"
 314      langchain_input = {"query": query}
 315      # Log the retriever
 316      with mlflow.start_run():
 317          logged_model = mlflow.langchain.log_model(
 318              db.as_retriever(),
 319              name="retriever",
 320              loader_fn=load_retriever,
 321              persist_dir=persist_dir,
 322              input_example=langchain_input,
 323          )
 324  
 325      # Remove the persist_dir
 326      shutil.rmtree(persist_dir)
 327  
 328      # Load the retriever
 329      loaded_model = mlflow.langchain.load_model(logged_model.model_uri)
 330      assert_equal_retrievers(loaded_model, db.as_retriever())
 331  
 332      loaded_pyfunc_model = mlflow.pyfunc.load_model(logged_model.model_uri)
 333      result = loaded_pyfunc_model.predict([langchain_input])
 334      expected_result = [
 335          {
 336              "page_content": doc.page_content,
 337              "metadata": doc.metadata,
 338              "type": "Document",
 339              "id": mock.ANY,
 340          }
 341          for doc in db.as_retriever().get_relevant_documents(query)
 342      ]
 343      assert result == [expected_result]
 344  
 345      # Serve the retriever
 346      inference_payload = load_serving_example(logged_model.model_uri)
 347      response = pyfunc_serve_and_score_model(
 348          logged_model.model_uri,
 349          data=inference_payload,
 350          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 351          extra_args=["--env-manager", "local"],
 352      )
 353      pred = PredictionsResponse.from_json(response.content.decode("utf-8"))["predictions"]
 354      assert type(pred) == list
 355      assert len(pred) == 1
 356      docs_list = pred[0]
 357      assert type(docs_list) == list
 358      assert len(docs_list) == 4
 359      # The returned docs are non-deterministic when used with dummy embeddings,
 360      # so we cannot assert pred == {"predictions": [expected_result]}
 361  
 362  
 363  def load_requests_wrapper(_):
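          """loader_fn stub that rebuilds a TextRequestsWrapper, ignoring the persist directory."""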
 364      return TextRequestsWrapper(headers=None, aiosession=None)
 365  
 366  
 367  @skip_if_v1
 368  def test_agent_with_unpicklable_tools(tmp_path):
 369      from langchain.agents import AgentType, initialize_agent
 370  
 371      tmp_file = tmp_path / "temp_file.txt"
 372      with open(tmp_file, mode="w") as temp_file:
 373          # an open file handle cannot be pickled, which makes this tool unpicklable
 374          tools = [
 375              Tool.from_function(
 376                  func=lambda: temp_file,
 377                  name="Write 0",
 378                  description="If you need to write 0 to a file",
 379              )
 380          ]
 381          agent = initialize_agent(
 382              llm=OpenAI(temperature=0),
 383              tools=tools,
 384              agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
 385          )
 386  
 387          with pytest.raises(
 388              MlflowException,
 389              match=(
 390                  "Error when attempting to pickle the AgentExecutor tools. "
 391                  "This model likely does not support serialization."
 392              ),
 393          ):
 394              with mlflow.start_run():
 395                  mlflow.langchain.log_model(agent, name="unpicklable_tools")
 396  
 397  
 398  @skip_if_v1
 399  def test_save_load_runnable_passthrough():
 400      runnable = RunnablePassthrough()
 401      assert runnable.invoke("hello") == "hello"
 402  
 403      input_example = "hello"
 404      with mlflow.start_run():
 405          model_info = mlflow.langchain.log_model(
 406              runnable, name="model_path", input_example=input_example
 407          )
 408  
 409      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 410      assert loaded_model.invoke(input_example) == "hello"
 411      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 412      assert pyfunc_loaded_model.predict(["hello"]) == ["hello"]
 413  
 414  
 415  @skip_if_v1
 416  def test_save_load_runnable_lambda(spark):
 417      def add_one(x: int) -> int:
 418          return x + 1
 419  
 420      runnable = RunnableLambda(add_one)
 421  
 422      assert runnable.invoke(1) == 2
 423      assert runnable.batch([1, 2, 3]) == [2, 3, 4]
 424  
 425      with mlflow.start_run():
 426          model_info = mlflow.langchain.log_model(
 427              runnable, name="runnable_lambda", input_example=[1, 2, 3]
 428          )
 429  
 430      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 431      assert loaded_model.invoke(1) == 2
 432      assert loaded_model.batch([1, 2, 3]) == [2, 3, 4]
 433  
 434      loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 435      assert loaded_model.predict(1) == [2]
 436      assert loaded_model.predict([1, 2, 3]) == [2, 3, 4]
 437  
 438      udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="long")
 439      df = spark.createDataFrame([(1,), (2,), (3,)], ["data"])
 440      df = df.withColumn("answer", udf("data"))
 441      pdf = df.toPandas()
 442      assert pdf["answer"].tolist() == [2, 3, 4]
 443  
 444  
 445  @skip_if_v1
 446  def test_save_load_runnable_lambda_in_sequence():
 447      def add_one(x):
 448          return x + 1
 449  
 450      def mul_two(x):
 451          return x * 2
 452  
 453      runnable_1 = RunnableLambda(add_one)
 454      runnable_2 = RunnableLambda(mul_two)
 455      sequence = runnable_1 | runnable_2
 456      assert sequence.invoke(1) == 4
 457  
 458      with mlflow.start_run():
 459          model_info = mlflow.langchain.log_model(
 460              sequence, name="model_path", input_example=[1, 2, 3]
 461          )
 462  
 463      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 464      assert loaded_model.invoke(1) == 4
 465      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 466      assert pyfunc_loaded_model.predict(1) == [4]
 467      assert pyfunc_loaded_model.predict([1, 2, 3]) == [4, 6, 8]
 468  
 469      inference_payload = load_serving_example(model_info.model_uri)
 470      response = pyfunc_serve_and_score_model(
 471          model_info.model_uri,
 472          data=inference_payload,
 473          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 474          extra_args=["--env-manager", "local"],
 475      )
 476      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 477          "predictions": [4, 6, 8]
 478      }
 479  
 480  
 481  @skip_if_v1
 482  def test_predict_with_callbacks(fake_chat_model):
 483      class TestCallbackHandler(BaseCallbackHandler):
 484          def __init__(self):
 485              super().__init__()
 486              self.num_llm_start_calls = 0
 487  
 488          def on_llm_start(
 489              self,
 490              serialized: dict[str, Any],
 491              prompts: list[str],
 492              **kwargs: Any,
 493          ) -> Any:
 494              self.num_llm_start_calls += 1
 495  
 496      prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
 497      chain = prompt | fake_chat_model | StrOutputParser()
 498      # Test the basic functionality of the chain
 499      assert chain.invoke({"industry": "tech"}) == "Databricks"
 500  
 501      with mlflow.start_run():
 502          model_info = mlflow.langchain.log_model(
 503              chain, name="model_path", input_example={"industry": "tech"}
 504          )
 505  
 506      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 507  
 508      callback_handler1 = TestCallbackHandler()
 509      callback_handler2 = TestCallbackHandler()
 510  
 511      # Ensure handlers have not been called yet
 512      assert callback_handler1.num_llm_start_calls == 0
 513      assert callback_handler2.num_llm_start_calls == 0
 514  
 515      assert (
 516          pyfunc_loaded_model._model_impl._predict_with_callbacks(
 517              {"industry": "tech"},
 518              callback_handlers=[callback_handler1, callback_handler2],
 519          )
 520          == "Databricks"
 521      )
 522  
 523      # Test that the callback handlers were called
 524      assert callback_handler1.num_llm_start_calls == 1
 525      assert callback_handler2.num_llm_start_calls == 1
 526  
 527      inference_payload = load_serving_example(model_info.model_uri)
 528      response = pyfunc_serve_and_score_model(
 529          model_info.model_uri,
 530          data=inference_payload,
 531          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 532          extra_args=["--env-manager", "local"],
 533      )
 534      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 535          "predictions": ["Databricks"]
 536      }
 537  
 538  
 539  @skip_if_v1
 540  def test_predict_with_callbacks_supports_chat_response_conversion(fake_chat_model):
 541      prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
 542      chain = prompt | fake_chat_model | StrOutputParser()
 543      # Test the basic functionality of the chain
 544      assert chain.invoke({"industry": "tech"}) == "Databricks"
 545  
 546      with mlflow.start_run():
 547          model_info = mlflow.langchain.log_model(
 548              chain, name="model_path", input_example={"industry": "tech"}
 549          )
 550  
 551      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 552      expected_chat_response = {
 553          "id": None,
 554          "object": "chat.completion",
 555          "created": 1677858242,
 556          "model": "",
 557          "choices": [
 558              {
 559                  "index": 0,
 560                  "message": {
 561                      "role": "assistant",
 562                      "content": "Databricks",
 563                  },
 564                  "finish_reason": None,
 565              }
 566          ],
 567          "usage": {
 568              "prompt_tokens": None,
 569              "completion_tokens": None,
 570              "total_tokens": None,
 571          },
 572      }
 573      with mock.patch("time.time", return_value=1677858242):
 574          assert (
 575              pyfunc_loaded_model._model_impl._predict_with_callbacks(
 576                  {"industry": "tech"},
 577                  convert_chat_responses=True,
 578              )
 579              == expected_chat_response
 580          )
 581  
 582          assert (
 583              pyfunc_loaded_model._model_impl._predict_with_callbacks(
 584                  {"industry": "tech"},
 585                  convert_chat_responses=False,
 586              )
 587              == "Databricks"
 588          )
 589  
 590  
 591  @skip_if_v1
 592  def test_save_load_runnable_parallel():
 593      runnable = RunnableParallel({"llm": create_openai_runnable()})
 594      expected_result = {"llm": json.dumps(TEST_CONTENT)}
 595      assert runnable.invoke({"product": "MLflow"}) == expected_result
 596      with mlflow.start_run():
 597          model_info = mlflow.langchain.log_model(
 598              runnable, name="model_path", input_example=[{"product": "MLflow"}]
 599          )
 600      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 601      assert loaded_model.invoke({"product": "MLflow"}) == expected_result
 602      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 603      assert pyfunc_loaded_model.predict([{"product": "MLflow"}]) == [expected_result]
 604  
 605      inference_payload = load_serving_example(model_info.model_uri)
 606      response = pyfunc_serve_and_score_model(
 607          model_info.model_uri,
 608          data=inference_payload,
 609          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 610          extra_args=["--env-manager", "local"],
 611      )
 612      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 613          "predictions": [expected_result]
 614      }
 615  
 616  
 617  @skip_if_v1
 618  def test_save_load_chain_with_model_paths():
 619      prompt1 = PromptTemplate.from_template("what is the city {person} is from?")
 620      llm = ChatOpenAI(temperature=0.9)
 621      model = prompt1 | llm | StrOutputParser()
 622  
 623      with mlflow.start_run():
 624          model_info = mlflow.langchain.log_model(model, name="model_path")
 625      artifact_path = "model_path"
 626      with (
 627          mlflow.start_run(),
 628          mock.patch("mlflow.langchain.model._add_code_from_conf_to_system_path") as add_mock,
 629      ):
 630          model_info = mlflow.langchain.log_model(model, name=artifact_path, code_paths=[__file__])
 631          mlflow.langchain.load_model(model_info.model_uri)
 632          model_uri = model_info.model_uri
 633          _compare_logged_code_paths(__file__, model_uri, mlflow.langchain.FLAVOR_NAME)
 634          add_mock.assert_called()
 635  
 636  
 637  @skip_if_v1
 638  def test_save_load_rag(tmp_path, spark, fake_chat_model):
 639      # TODO: Migrate to models-from-code
 640      # Create the vector db, persist the db to a local fs folder
 641      loader = TextLoader("tests/langchain/state_of_the_union.txt")
 642      documents = loader.load()
 643      text_splitter = CharacterTextSplitter(chunk_size=10, chunk_overlap=0)
 644      docs = text_splitter.split_documents(documents)
 645      embeddings = DeterministicDummyEmbeddings(size=5)
 646      db = FAISS.from_documents(docs, embeddings)
 647      persist_dir = str(tmp_path / "faiss_index")
 648      db.save_local(persist_dir)
 649      retriever = db.as_retriever()
 650  
 651      def load_retriever(persist_directory):
 652          embeddings = FakeEmbeddings(size=5)
 653          vectorstore = FAISS.load_local(
 654              persist_directory,
 655              embeddings,
 656              **VECTORSTORE_KWARGS,
 657          )
 658          return vectorstore.as_retriever()
 659  
 660      prompt = ChatPromptTemplate.from_template(
 661          "Answer the following question based on the context: {context}\nQuestion: {question}"
 662      )
 663      retrieval_chain = (
 664          {
 665              "context": retriever,
 666              "question": RunnablePassthrough(),
 667          }
 668          | prompt
 669          | fake_chat_model
 670          | StrOutputParser()
 671      )
 672      question = "What is a good name for a company that makes MLflow?"
 673      answer = "Databricks"
 674      assert retrieval_chain.invoke(question) == answer
 675      with mlflow.start_run():
 676          model_info = mlflow.langchain.log_model(
 677              retrieval_chain,
 678              name="model_path",
 679              loader_fn=load_retriever,
 680              persist_dir=persist_dir,
 681              input_example=question,
 682          )
 683  
 684      # Remove the persist_dir
 685      shutil.rmtree(persist_dir)
 686  
 687      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 688      assert loaded_model.invoke(question) == answer
 689      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 690      assert pyfunc_loaded_model.predict(question) == [answer]
 691  
 692      udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="string")
 693      df = spark.createDataFrame([(question,), (question,)], ["question"])
 694      df = df.withColumn("answer", udf("question"))
 695      pdf = df.toPandas()
 696      assert pdf["answer"].tolist() == [answer, answer]
 697  
 698      inference_payload = load_serving_example(model_info.model_uri)
 699      response = pyfunc_serve_and_score_model(
 700          model_info.model_uri,
 701          data=inference_payload,
 702          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 703          extra_args=["--env-manager", "local"],
 704      )
 705      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 706          "predictions": [answer]
 707      }
 708  
 709  
 710  @skip_if_v1
 711  def test_runnable_branch_save_load():
 712      branch = RunnableBranch(
 713          (lambda x: isinstance(x, str), lambda x: x.upper()),
 714          (lambda x: isinstance(x, int), lambda x: x + 1),
 715          (lambda x: isinstance(x, float), lambda x: x * 2),
 716          lambda x: "goodbye",
 717      )
 718  
 719      assert branch.invoke("hello") == "HELLO"
 720      assert branch.invoke({}) == "goodbye"
 721  
 722      with mlflow.start_run():
 723          # Only a single input format is supported for now, so we do not
 724          # save a signature for a RunnableBranch that accepts multiple
 725          # input types
 726          model_info = mlflow.langchain.log_model(branch, name="model_path")
 727  
 728      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 729      assert loaded_model.invoke("hello") == "HELLO"
 730      assert loaded_model.invoke({}) == "goodbye"
 731      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 732      assert pyfunc_loaded_model.predict("hello") == "HELLO"
 733      assert pyfunc_loaded_model.predict({}) == "goodbye"
 734  
 735      response = pyfunc_serve_and_score_model(
 736          model_info.model_uri,
 737          data=json.dumps({"inputs": "hello"}),
 738          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 739          extra_args=["--env-manager", "local"],
 740      )
 741      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 742          "predictions": "HELLO"
 743      }
 744  
 745  
 746  @skip_if_v1
 747  def test_complex_runnable_branch_save_load(fake_chat_model, fake_classifier_chat_model):
 748      prompt = ChatPromptTemplate.from_template("{question_is_relevant}\n{query}")
 749      # The prompt is needed here because the chat model does not accept dict input
 750      answer_model = prompt | fake_chat_model
 751  
 752      decline_to_answer = RunnableLambda(
 753          lambda x: "I cannot answer questions that are not about MLflow."
 754      )
 755      something_went_wrong = RunnableLambda(lambda x: "Something went wrong.")
 756  
 757      is_question_about_mlflow_prompt = ChatPromptTemplate.from_template(
 758          "You are classifying documents to know if this question "
 759          "is related with MLflow. Only answer with yes or no. The question is: {query}"
 760      )
 761  
 762      branch_node = RunnableBranch(
 763          (lambda x: x["question_is_relevant"].lower() == "yes", answer_model),
 764          (lambda x: x["question_is_relevant"].lower() == "no", decline_to_answer),
 765          something_went_wrong,
 766      )
 767  
 768      chain = (
 769          {
 770              "question_is_relevant": is_question_about_mlflow_prompt
 771              | fake_classifier_chat_model
 772              | StrOutputParser(),
 773              "query": itemgetter("query"),
 774          }
 775          | branch_node
 776          | StrOutputParser()
 777      )
 778  
 779      assert chain.invoke({"query": "Who owns MLflow?"}) == "Databricks"
 780      assert (
 781          chain.invoke({"query": "Do you like cat?"})
 782          == "I cannot answer questions that are not about MLflow."
 783      )
 784      assert chain.invoke({"query": "Are you happy today?"}) == "Something went wrong."
 785  
 786      with mlflow.start_run():
 787          model_info = mlflow.langchain.log_model(
 788              chain, name="model_path", input_example={"query": "Who owns MLflow?"}
 789          )
 790  
 791      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 792      assert loaded_model.invoke({"query": "Who owns MLflow?"}) == "Databricks"
 793      assert (
 794          loaded_model.invoke({"query": "Do you like cat?"})
 795          == "I cannot answer questions that are not about MLflow."
 796      )
 797      assert loaded_model.invoke({"query": "Are you happy today?"}) == "Something went wrong."
 798      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 799      assert pyfunc_loaded_model.predict({"query": "Who owns MLflow?"}) == ["Databricks"]
 800      assert pyfunc_loaded_model.predict({"query": "Do you like cat?"}) == [
 801          "I cannot answer questions that are not about MLflow."
 802      ]
 803      assert pyfunc_loaded_model.predict({"query": "Are you happy today?"}) == [
 804          "Something went wrong."
 805      ]
 806  
 807      inference_payload = load_serving_example(model_info.model_uri)
 808      response = pyfunc_serve_and_score_model(
 809          model_info.model_uri,
 810          data=inference_payload,
 811          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 812          extra_args=["--env-manager", "local"],
 813      )
 814      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
 815          "predictions": ["Databricks"]
 816      }
 817  
 818  
 819  @skip_if_v1
 820  def test_chat_with_history(spark, fake_chat_model):
 821      prompt_with_history_str = """
 822      Here is a history between you and a human: {chat_history}
 823  
 824      Now, please answer this question: {question}
 825      """
 826  
 827      prompt_with_history = PromptTemplate(
 828          input_variables=["chat_history", "question"], template=prompt_with_history_str
 829      )
 830  
 831      def extract_question(input):
 832          return input[-1]["content"]
 833  
 834      def extract_history(input):
 835          return input[:-1]
 836  
 837      chain_with_history = (
 838          {
 839              "question": itemgetter("messages") | RunnableLambda(extract_question),
 840              "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
 841          }
 842          | prompt_with_history
 843          | fake_chat_model
 844          | StrOutputParser()
 845      )
 846  
 847      input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
 848      assert chain_with_history.invoke(input_example) == "Databricks"
 849  
 850      with mlflow.start_run():
 851          model_info = mlflow.langchain.log_model(
 852              chain_with_history, name="model_path", input_example=input_example
 853          )
 854      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 855      assert loaded_model.invoke(input_example) == "Databricks"
 856      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 857      input_schema = pyfunc_loaded_model.metadata.get_input_schema()
 858      assert input_schema == Schema([
 859          ColSpec(
 860              Array(
 861                  Object([
 862                      Property("role", DataType.string),
 863                      Property("content", DataType.string),
 864                  ])
 865              ),
 866              "messages",
 867          )
 868      ])
 869      assert pyfunc_loaded_model.predict(input_example) == ["Databricks"]
 870  
 871      udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, result_type="string")
 872      df = spark.createDataFrame([(input_example["messages"],)], ["messages"])
 873      df = df.withColumn("answer", udf("messages"))
 874      pdf = df.toPandas()
 875      assert pdf["answer"].tolist() == ["Databricks"]
 876  
 877      inference_payload = load_serving_example(model_info.model_uri)
 878      response = pyfunc_serve_and_score_model(
 879          model_info.model_uri,
 880          data=inference_payload,
 881          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 882          extra_args=["--env-manager", "local"],
 883      )
 884      assert json.loads(response.content.decode("utf-8")) == ["Databricks"]
 885  
 886  
 887  class ChatModel(SimpleChatModel):
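          """Fake chat model that echoes the conversation as "<type>: <content>" lines."""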
 888      def _call(self, messages, stop, run_manager, **kwargs):
 889          return "\n".join([f"{message.type}: {message.content}" for message in messages])
 890  
 891      @property
 892      def _llm_type(self) -> str:
 893          return "chat model"
 894  
 895  
 896  @skip_if_v1
 897  def test_predict_with_builtin_pyfunc_chat_conversion(spark):
 898      # TODO: Migrate to models-from-code
 899      input_example = {
 900          "messages": [
 901              {"role": "system", "content": "You are a helpful assistant."},
 902              {"role": "assistant", "content": "What would you like to ask?"},
 903              {"role": "user", "content": "Who owns MLflow?"},
 904          ]
 905      }
 906      content = (
 907          "system: You are a helpful assistant.\n"
 908          "ai: What would you like to ask?\n"
 909          "human: Who owns MLflow?"
 910      )
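      # The "system"/"ai"/"human" prefixes mirror LangChain's message.type values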
 911  
 912      chain = ChatModel() | StrOutputParser()
 913      assert chain.invoke([HumanMessage(content="Who owns MLflow?")]) == "human: Who owns MLflow?"
 914      with pytest.raises(ValueError, match="Invalid input type"):
 915          chain.invoke(input_example)
 916  
 917      with mlflow.start_run():
 918          model_info = mlflow.langchain.log_model(
 919              chain, name="model_path", input_example=input_example
 920          )
 921  
 922      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 923      assert (
 924          loaded_model.invoke([HumanMessage(content="Who owns MLflow?")]) == "human: Who owns MLflow?"
 925      )
 926  
 927      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
 928      expected_chat_response = {
 929          "id": None,
 930          "object": "chat.completion",
 931          "created": 1677858242,
 932          "model": "",
 933          "choices": [
 934              {
 935                  "index": 0,
 936                  "message": {
 937                      "role": "assistant",
 938                      "content": content,
 939                  },
 940                  "finish_reason": None,
 941              }
 942          ],
 943          "usage": {
 944              "prompt_tokens": None,
 945              "completion_tokens": None,
 946              "total_tokens": None,
 947          },
 948      }
 949  
 950      with mock.patch("time.time", return_value=1677858242):
 951          result1 = pyfunc_loaded_model.predict(input_example)
 952          result1[0]["id"] = None
 953          assert result1 == [expected_chat_response]
 954          result2 = pyfunc_loaded_model.predict([input_example, input_example])
 955          result2[0]["id"] = None
 956          result2[1]["id"] = None
 957          assert result2 == [
 958              expected_chat_response,
 959              expected_chat_response,
 960          ]
 961  
 962      with pytest.raises(MlflowException, match="Unrecognized chat message role"):
 963          pyfunc_loaded_model.predict({"messages": [{"role": "foobar", "content": "test content"}]})
 964  
 965  
 966  @skip_if_v1
 967  def test_predict_with_builtin_pyfunc_chat_conversion_for_aimessage_response():
 968      class ChatModel(SimpleChatModel):
 969          def _call(self, messages, stop, run_manager, **kwargs):
 970              return "You own MLflow"
 971  
 972          @property
 973          def _llm_type(self) -> str:
 974              return "chat model"
 975  
 976      input_example = {
 977          "messages": [
 978              {"role": "system", "content": "You are a helpful assistant."},
 979              {"role": "assistant", "content": "What would you like to ask?"},
 980              {"role": "user", "content": "Who owns MLflow?"},
 981          ]
 982      }
 983  
 984      chain = ChatModel()
 985      result = chain.invoke([HumanMessage(content="Who owns MLflow?")])
 986      assert isinstance(result, AIMessage)
 987      assert result.content == "You own MLflow"
 988  
 989      with mlflow.start_run():
 990          model_info = mlflow.langchain.log_model(
 991              chain, name="model_path", input_example=input_example
 992          )
 993  
 994      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
 995      result = loaded_model.invoke([HumanMessage(content="Who owns MLflow?")])
 996      assert isinstance(result, AIMessage)
 997      assert result.content == "You own MLflow"
 998  
 999      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1000      with mock.patch("time.time", return_value=1677858242):
1001          result = pyfunc_loaded_model.predict(input_example)
1002          assert "id" in result[0], "Response message id is lost."
1003          result[0]["id"] = None
1004          assert result == [
1005              {
1006                  "id": None,
1007                  "object": "chat.completion",
1008                  "created": 1677858242,
1009                  "model": "",
1010                  "choices": [
1011                      {
1012                          "index": 0,
1013                          "message": {
1014                              "role": "assistant",
1015                              "content": "You own MLflow",
1016                          },
1017                          "finish_reason": None,
1018                      }
1019                  ],
1020                  "usage": {
1021                      "prompt_tokens": None,
1022                      "completion_tokens": None,
1023                      "total_tokens": None,
1024                  },
1025              }
1026          ]
1027  
1028  
1029  @skip_if_v1
1030  def test_pyfunc_builtin_chat_request_conversion_fails_gracefully():
1031      chain = RunnablePassthrough() | itemgetter("messages")
1032      # Verify that "messages" is not an explicit field of the chain's input schema,
1033      # so this test covers the case where it must remain intact and unchanged
1034      assert "messages" not in chain.input_schema().model_fields
1035  
1036      with mlflow.start_run():
1037          model_info = mlflow.langchain.log_model(chain, name="model_path")
1038          pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1039  
1040      assert pyfunc_loaded_model.predict({"messages": "not an array"}) == "not an array"
1041  
1042      # Verify that messages aren't converted to LangChain format if extra keys are present,
1043      # under the assumption that additional keys can't be specified when calling LangChain invoke()
1044      # / batch() with chat messages
1045      assert pyfunc_loaded_model.predict({
1046          "messages": [{"role": "user", "content": "blah"}],
1047          "extrakey": "extra",
1048      }) == [
1049          {"role": "user", "content": "blah"},
1050      ]
1051  
1052      # Verify that messages aren't converted to LangChain format if role / content are missing
1053      # or extra keys are present in the message
1054      assert pyfunc_loaded_model.predict({
1055          "messages": [{"content": "blah"}],
1056      }) == [
1057          {"content": "blah"},
1058      ]
1059      assert pyfunc_loaded_model.predict({
1060          "messages": [{"role": "user", "content": "blah"}, {}],
1061      }) == [
1062          {"role": "user", "content": "blah"},
1063          {},
1064      ]
1065      assert pyfunc_loaded_model.predict({
1066          "messages": [{"role": "user", "content": 123}],
1067      }) == [
1068          {"role": "user", "content": 123},
1069      ]
1070  
1071      # Verify behavior for batches of message histories
1072      assert pyfunc_loaded_model.predict([
1073          {
1074              "messages": "not an array",
1075          },
1076          {
1077              "messages": [{"role": "user", "content": "content"}],
1078          },
1079      ]) == [
1080          "not an array",
1081          [{"role": "user", "content": "content"}],
1082      ]
1083      assert pyfunc_loaded_model.predict([
1084          {
1085              "messages": [{"role": "user", "content": "content"}],
1086          },
1087          {"messages": [{"role": "user", "content": "content"}], "extrakey": "extra"},
1088      ]) == [
1089          [{"role": "user", "content": "content"}],
1090          [{"role": "user", "content": "content"}],
1091      ]
1092      assert pyfunc_loaded_model.predict([
1093          {
1094              "messages": [{"role": "user", "content": "content"}],
1095          },
1096          {
1097              "messages": [
1098                  {"role": "user", "content": "content"},
1099                  {"role": "user", "content": 123},
1100              ],
1101          },
1102      ]) == [
1103          [{"role": "user", "content": "content"}],
1104          [{"role": "user", "content": "content"}, {"role": "user", "content": 123}],
1105      ]
1106  
1107  
1108  @skip_if_v1
1109  def test_save_load_chain_that_relies_on_pickle_serialization(monkeypatch, model_path):
1110      from langchain_community.llms.databricks import Databricks
1111  
1112      monkeypatch.setattr(
1113          "langchain_community.llms.databricks._DatabricksServingEndpointClient",
1114          mock.MagicMock(),
1115      )
1116      monkeypatch.setenv("DATABRICKS_HOST", "test-host")
1117      monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")
1118  
1119      llm_kwargs = {"endpoint_name": "test-endpoint", "temperature": 0.9}
1120      if IS_PICKLE_SERIALIZATION_RESTRICTED:
1121          llm_kwargs["allow_dangerous_deserialization"] = True
1122  
1123      llm = Databricks(**llm_kwargs)
1124      prompt = PromptTemplate(input_variables=["question"], template="I have a question: {question}")
1125      chain = prompt | llm | StrOutputParser()
1126  
1127      # Not passing an input_example to avoid triggering prediction
1128      mlflow.langchain.save_model(chain, model_path)
1129  
1130      loaded_model = mlflow.langchain.load_model(model_path)
1131  
1132      # Check if the deserialized model has the same endpoint and temperature
1133      loaded_databricks_llm = loaded_model.middle[0]
1134      assert loaded_databricks_llm.endpoint_name == "test-endpoint"
1135      assert loaded_databricks_llm.temperature == 0.9
1136  
1137  
1138  def _get_message_content(predictions):
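          """Extract the assistant message content from a chat-completions-style prediction list."""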
1139      return predictions[0]["choices"][0]["message"]["content"]
1140  
1141  
1142  @pytest.mark.parametrize(
1143      ("chain_path", "model_config"),
1144      [
1145          (
1146              os.path.abspath("tests/langchain/sample_code/chain.py"),
1147              os.path.abspath("tests/langchain/sample_code/config.yml"),
1148          ),
1149          (
1150              "tests/langchain/../langchain/sample_code/chain.py",
1151              "tests/langchain/../langchain/sample_code/config.yml",
1152          ),
1153      ],
1154  )
1155  def test_save_load_chain_as_code(chain_path, model_config, monkeypatch):
1156      input_example = {
1157          "messages": [
1158              {
1159                  "role": "user",
1160                  "content": "What is a good name for a company that makes MLflow?",
1161              }
1162          ]
1163      }
1164      artifact_path = "model_path"
1165      with mlflow.start_run() as run:
1166          model_info = mlflow.langchain.log_model(
1167              chain_path,
1168              name=artifact_path,
1169              input_example=input_example,
1170              model_config=model_config,
1171          )
1172  
1173      client = mlflow.tracking.MlflowClient()
1174      run_id = run.info.run_id
1175      assert client.get_run(run_id).data.params == {
1176          "llm_prompt_template": "Answer the following question based on "
1177          "the context: {context}\nQuestion: {question}",
1178          "embedding_size": "5",
1179          "not_used_array": "[1, 2, 3]",
1180          "response": "Databricks",
1181      }
1182  
1183      assert mlflow.models.model_config.__mlflow_model_config__ is None
1184      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1185  
1186      # During the loading process, MLflow executes the chain.py file to
1187      # load the model. This should not generate any traces, even if
1188      # the code enables autologging and invokes the chain.
1189      assert len(get_traces()) == 0
1190  
1191      assert mlflow.models.model_config.__mlflow_model_config__ is None
1192      answer = "Databricks"
1193      assert loaded_model.invoke(input_example) == answer
1194      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1195      assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))
1196  
1197      inference_payload = load_serving_example(model_info.model_uri)
1198      response = pyfunc_serve_and_score_model(
1199          model_info.model_uri,
1200          data=inference_payload,
1201          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1202          extra_args=["--env-manager", "local"],
1203      )
1204      predictions = json.loads(response.content.decode("utf-8"))
1205      # Mock out the `created` timestamp as it is not deterministic
1206      expected = [{**try_transform_response_to_chat_format(answer), "created": mock.ANY}]
1207      assert expected == predictions
1208  
1209      pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
1210      reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
1211      assert reloaded_model.resources["databricks"] == {
1212          "serving_endpoint": [{"name": "fake-endpoint"}]
1213      }
1214      assert reloaded_model.metadata["dependencies_schemas"] == {
1215          DependenciesSchemasType.RETRIEVERS.value: [
1216              {
1217                  "doc_uri": "doc-uri",
1218                  "name": "retriever",
1219                  "other_columns": ["column1", "column2"],
1220                  "primary_key": "primary-key",
1221                  "text_column": "text-column",
1222              }
1223          ]
1224      }
1225  
1226      # Emulate the model serving environment
1227      monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
1228      monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true")
1229      mlflow.tracing.reset()
1230  
1231      request_id = "mock_request_id"
1232      tracer = MlflowLangchainTracer(prediction_context=Context(request_id))
1233      input_example = {"messages": [{"role": "user", "content": json.dumps(TEST_CONTENT)}]}
1234      response = pyfunc_loaded_model._model_impl._predict_with_callbacks(
1235          data=input_example, callback_handlers=[tracer]
1236      )
1237      assert response["choices"][0]["message"]["content"] == "Databricks"
1238      trace = pop_trace(request_id)
1239      assert trace["info"]["tags"][DependenciesSchemasType.RETRIEVERS.value] == json.dumps([
1240          {
1241              "doc_uri": "doc-uri",
1242              "name": "retriever",
1243              "other_columns": ["column1", "column2"],
1244              "primary_key": "primary-key",
1245              "text_column": "text-column",
1246          }
1247      ])
1248  
1249  
1250  @pytest.mark.parametrize(
1251      "chain_path",
1252      [
1253          os.path.abspath("tests/langchain/sample_code/chain.py"),
1254          "tests/langchain/../langchain/sample_code/chain.py",
1255      ],
1256  )
1257  def test_save_load_chain_as_code_model_config_dict(chain_path):
1258      input_example = {
1259          "messages": [
1260              {
1261                  "role": "user",
1262                  "content": "What is a good name for a company that makes MLflow?",
1263              }
1264          ]
1265      }
1266      with mlflow.start_run():
1267          model_info = mlflow.langchain.log_model(
1268              chain_path,
1269              name="model_path",
1270              input_example=input_example,
1271              model_config={
1272                  "response": "modified response",
1273                  "embedding_size": 5,
1274                  "llm_prompt_template": "answer the question",
1275              },
1276          )
1277  
1278      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1279      answer = "modified response"
1280      assert loaded_model.invoke(input_example) == answer
1281      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1282      assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))
1283  
1284  
1285  @pytest.mark.parametrize(
1286      "model_config",
1287      [
1288          os.path.abspath("tests/langchain/sample_code/config.yml"),
1289          "tests/langchain/../langchain/sample_code/config.yml",
1290      ],
1291  )
1292  def test_save_load_chain_as_code_with_different_names(tmp_path, model_config):
1293      input_example = {
1294          "messages": [
1295              {
1296                  "role": "user",
1297                  "content": "What is a good name for a company that makes MLflow?",
1298              }
1299          ]
1300      }
1301  
1302      # Read the contents of the original chain file
1303      with open("tests/langchain/sample_code/chain.py") as chain_file:
1304          chain_file_content = chain_file.read()
1305  
1306      temp_file = tmp_path / "model.py"
1307      temp_file.write_text(chain_file_content)
1308  
1309      with mlflow.start_run():
1310          model_info = mlflow.langchain.log_model(
1311              str(temp_file),
1312              name="model_path",
1313              input_example=input_example,
1314              model_config=model_config,
1315          )
1316  
1317      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1318      answer = "Databricks"
1319      assert loaded_model.invoke(input_example) == answer
1320      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1321      assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))
1322  
1323  
1324  @pytest.mark.parametrize(
1325      "chain_path",
1326      [
1327          os.path.abspath("tests/langchain/sample_code/chain.py"),
1328          "tests/langchain/../langchain/sample_code/chain.py",
1329      ],
1330  )
1331  @pytest.mark.parametrize(
1332      "model_config",
1333      [
1334          os.path.abspath("tests/langchain/sample_code/config.yml"),
1335          "tests/langchain/../langchain/sample_code/config.yml",
1336      ],
1337  )
1338  def test_save_load_chain_as_code_multiple_times(tmp_path, chain_path, model_config):
1339      input_example = {
1340          "messages": [
1341              {
1342                  "role": "user",
1343                  "content": "What is a good name for a company that makes MLflow?",
1344              }
1345          ]
1346      }
1347      with mlflow.start_run():
1348          model_info = mlflow.langchain.log_model(
1349              chain_path,
1350              name="model_path",
1351              input_example=input_example,
1352              model_config=model_config,
1353          )
1354  
1355      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1356      with open(model_config) as f:
1357          base_config = yaml.safe_load(f)
1358  
1359      assert loaded_model.middle[0].messages[0].prompt.template == base_config["llm_prompt_template"]
1360  
1361      file_name = "config_updated.yml"
1362      new_config_file = str(tmp_path.joinpath(file_name))
1363  
1364      new_config = base_config.copy()
1365      new_config["llm_prompt_template"] = "new_template"
1366      with open(new_config_file, "w") as f:
1367          yaml.dump(new_config, f)
1368  
1369      with mlflow.start_run():
1370          model_info = mlflow.langchain.log_model(
1371              chain_path,
1372              name="model_path",
1373              input_example=input_example,
1374              model_config=new_config_file,
1375          )
1376  
1377      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1378      assert loaded_model.middle[0].messages[0].prompt.template == new_config["llm_prompt_template"]
1379  
1380  
1381  @pytest.mark.parametrize(
1382      "chain_path",
1383      [
1384          os.path.abspath("tests/langchain/sample_code/chain.py"),
1385          "tests/langchain/../langchain/sample_code/chain.py",
1386      ],
1387  )
1388  def test_save_load_chain_as_code_with_model_paths(chain_path):
1389      input_example = {
1390          "messages": [
1391              {
1392                  "role": "user",
1393                  "content": "What is a good name for a company that makes MLflow?",
1394              }
1395          ]
1396      }
1397      artifact_path = "model_path"
1398      with (
1399          mlflow.start_run(),
1400          mock.patch("mlflow.langchain.model._add_code_from_conf_to_system_path") as add_mock,
1401      ):
1402          model_info = mlflow.langchain.log_model(
1403              chain_path,
1404              name=artifact_path,
1405              input_example=input_example,
1406              code_paths=[__file__],
1407              model_config={
1408                  "response": "modified response",
1409                  "embedding_size": 5,
1410                  "llm_prompt_template": "answer the question",
1411              },
1412          )
1413          loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1414          answer = "modified response"
1415          _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.langchain.FLAVOR_NAME)
1416          assert loaded_model.invoke(input_example) == answer
1417          add_mock.assert_called()
1418  
1419  
1420  @pytest.mark.parametrize("chain_path", [os.path.abspath("tests/langchain1/sample_code/chain.py")])
1421  def test_save_load_chain_errors(chain_path):
1422      input_example = {
1423          "messages": [
1424              {
1425                  "role": "user",
1426                  "content": "What is a good name for a company that makes MLflow?",
1427              }
1428          ]
1429      }
1430      with mlflow.start_run():
1431          with pytest.raises(
1432              MlflowException,
1433              match=f"The provided model path '{chain_path}' does not exist. "
1434              "Ensure the file path is valid and try again.",
1435          ):
1436              mlflow.langchain.log_model(
1437                  chain_path,
1438                  name="model_path",
1439                  input_example=input_example,
1440                  model_config="tests/langchain/state_of_the_union.txt",
1441              )
1442  
1443  
1444  @pytest.mark.parametrize(
1445      "chain_path",
1446      [
1447          os.path.abspath("tests/langchain/sample_code/no_config/chain.py"),
1448          "tests/langchain/../langchain/sample_code/no_config/chain.py",
1449      ],
1450  )
1451  def test_save_load_chain_as_code_optional_code_path(chain_path):
1452      input_example = {
1453          "messages": [
1454              {
1455                  "role": "user",
1456                  "content": "What is a good name for a company that makes MLflow?",
1457              }
1458          ]
1459      }
1460      artifact_path = "new_model_path"
1461      with mlflow.start_run():
1462          model_info = mlflow.langchain.log_model(
1463              chain_path,
1464              name=artifact_path,
1465              input_example=input_example,
1466          )
1467  
1468      assert mlflow.models.model_config.__mlflow_model_config__ is None
1469      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1470      assert mlflow.models.model_config.__mlflow_model_config__ is None
1471      answer = "Databricks"
1472      assert loaded_model.invoke(input_example) == answer
1473      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1474      assert answer == _get_message_content(pyfunc_loaded_model.predict(input_example))
1482  
1483      inference_payload = load_serving_example(model_info.model_uri)
1484      response = pyfunc_serve_and_score_model(
1485          model_info.model_uri,
1486          data=inference_payload,
1487          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1488          extra_args=["--env-manager", "local"],
1489      )
1490      # Normalize the `created` timestamp to avoid a minor diff in the response
1491      prediction_result = json.loads(response.content.decode("utf-8"))
1492      prediction_result[0]["created"] = 123
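    # try_transform_response_to_chat_format is expected to wrap the plain string answer
    # into an OpenAI-style chat completion payload (choices/message/usage fields).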
1493      expected_prediction = try_transform_response_to_chat_format(answer)
1494      expected_prediction["created"] = 123
1495      assert prediction_result == [expected_prediction]
1496  
1497      pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
1498      reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
1499      assert reloaded_model.resources["databricks"] == {
1500          "serving_endpoint": [{"name": "fake-endpoint"}]
1501      }
1502      assert reloaded_model.metadata is None
1503  
1504  
1505  @pytest.fixture
1506  def fake_chat_stream_model():
1507      class FakeChatStreamModel(SimpleChatModel):
1508          """Fake Chat Stream Model wrapper for testing purposes."""
1509  
1510          endpoint_name: str = "fake-stream-endpoint"
1511  
1512          def _call(
1513              self,
1514              messages: list[BaseMessage],
1515              stop: list[str] | None = None,
1516              run_manager: CallbackManagerForLLMRun | None = None,
1517              **kwargs: Any,
1518          ) -> str:
1519              return "Databricks"
1520  
1521          def _stream(
1522              self,
1523              messages: list[BaseMessage],
1524              stop: list[str] | None = None,
1525              run_manager: CallbackManagerForLLMRun | None = None,
1526              **kwargs: Any,
1527          ) -> Iterator[ChatGenerationChunk]:
1528              for chunk_content, finish_reason in [
1529                  ("Da", None),
1530                  ("tab", None),
1531                  ("ricks", "stop"),
1532              ]:
1533                  chunk = ChatGenerationChunk(
1534                      message=AIMessageChunk(content=chunk_content),
1535                      generation_info={"finish_reason": finish_reason},
1536                  )
1537                  if run_manager:
1538                      run_manager.on_llm_new_token(chunk.text, chunk=chunk)
1539  
1540                  yield chunk
1541  
1542          @property
1543          def _llm_type(self) -> str:
1544              return "fake chat model"
1545  
1546      return FakeChatStreamModel(endpoint_name="fake-stream-endpoint")
1547  
1548  
1549  @skip_if_v1
1550  @pytest.mark.parametrize("provide_signature", [True, False])
1551  def test_simple_chat_model_stream_inference(fake_chat_stream_model, provide_signature):
1552      # TODO: Migrate to models-from-code
1553      input_example = {
1554          "messages": [
1555              {"role": "system", "content": "You are a helpful assistant."},
1556              {"role": "assistant", "content": "What would you like to ask?"},
1557              {"role": "user", "content": "Who owns MLflow?"},
1558          ]
1559      }
1560      with mlflow.start_run():
1561          model_info = mlflow.langchain.log_model(
1562              fake_chat_stream_model,
1563              name="model",
1564          )
1565  
1566      if provide_signature:
1567          signature = infer_signature(model_input=input_example)
1568          with mlflow.start_run():
1569              model_with_signature_info = mlflow.langchain.log_model(
1570                  fake_chat_stream_model, name="model", signature=signature
1571              )
1572      else:
1573          with mlflow.start_run():
1574              model_with_signature_info = mlflow.langchain.log_model(
1575                  fake_chat_stream_model, name="model", input_example=input_example
1576              )
1577  
1578      for model_uri in [model_info.model_uri, model_with_signature_info.model_uri]:
1579          loaded_model = mlflow.pyfunc.load_model(model_uri)
1580  
1581          chunk_iter = loaded_model.predict_stream(input_example)
1582  
1583          finish_reason = "stop"
1584  
1585          with mock.patch("time.time", return_value=1677858242):
1586              chunks = list(chunk_iter)
1587  
1588              for chunk in chunks:
1589                  assert "id" in chunk, "chunk id is lost."
1590                  chunk["id"] = None
1591  
1592              assert chunks == [
1593                  {
1594                      "id": None,
1595                      "object": "chat.completion.chunk",
1596                      "created": 1677858242,
1597                      "model": "",
1598                      "choices": [
1599                          {
1600                              "index": 0,
1601                              "finish_reason": None,
1602                              "delta": {"role": "assistant", "content": "Da"},
1603                          }
1604                      ],
1605                  },
1606                  {
1607                      "id": None,
1608                      "object": "chat.completion.chunk",
1609                      "created": 1677858242,
1610                      "model": "",
1611                      "choices": [
1612                          {
1613                              "index": 0,
1614                              "finish_reason": None,
1615                              "delta": {"role": "assistant", "content": "tab"},
1616                          }
1617                      ],
1618                  },
1619                  {
1620                      "id": None,
1621                      "object": "chat.completion.chunk",
1622                      "created": 1677858242,
1623                      "model": "",
1624                      "choices": [
1625                          {
1626                              "index": 0,
1627                              "finish_reason": finish_reason,
1628                              "delta": {"role": "assistant", "content": "ricks"},
1629                          }
1630                      ],
1631                  },
1632              ]
1633  
1634  
1635  @skip_if_v1
1636  def test_simple_chat_model_stream_with_callbacks(fake_chat_stream_model):
1637      # TODO: Migrate to models-from-code
1638      class TestCallbackHandler(BaseCallbackHandler):
1639          def __init__(self):
1640              super().__init__()
1641              self.num_llm_start_calls = 0
1642  
1643          def on_llm_start(
1644              self,
1645              serialized: dict[str, Any],
1646              prompts: list[str],
1647              **kwargs: Any,
1648          ) -> Any:
1649              self.num_llm_start_calls += 1
1650  
1651      prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
1652      chain = prompt | fake_chat_stream_model | StrOutputParser()
1653      # Test the basic functionality of the chain
1654      assert chain.invoke({"industry": "tech"}) == "Databricks"
1655  
1656      with mlflow.start_run():
1657          model_info = mlflow.langchain.log_model(
1658              chain, name="model_path", input_example={"industry": "tech"}
1659          )
1660  
1661      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1662  
1663      callback_handler1 = TestCallbackHandler()
1664      callback_handler2 = TestCallbackHandler()
1665  
1666      # Ensure handlers have not been called yet
1667      assert callback_handler1.num_llm_start_calls == 0
1668      assert callback_handler2.num_llm_start_calls == 0
1669  
1670      stream = pyfunc_loaded_model._model_impl._predict_stream_with_callbacks(
1671          {"industry": "tech"},
1672          callback_handlers=[callback_handler1, callback_handler2],
1673      )
1674      assert list(stream) == ["Da", "tab", "ricks"]
1675  
1676      # Test that the callback handlers were called
1677      assert callback_handler1.num_llm_start_calls == 1
1678      assert callback_handler2.num_llm_start_calls == 1
1679  
1680  
1681  @skip_if_v1
1682  def test_langchain_model_save_load_with_listeners(fake_chat_model):
1683      # TODO: Migrate this to models-from-code
1684      prompt = ChatPromptTemplate.from_messages([
1685          ("system", "You are a helpful assistant."),
1686          MessagesPlaceholder(variable_name="history"),
1687          ("human", "{question}"),
1688      ])
1689  
1690      def retrieve_history(input):
1691          return {"history": [], "question": input["question"], "name": input["name"]}
1692  
1693      chain = (
1694          {"question": itemgetter("question"), "name": itemgetter("name")}
1695          | (RunnableLambda(retrieve_history) | prompt | fake_chat_model).with_listeners()
1696          | StrOutputParser()
1697          | RunnablePassthrough()
1698      )
1699      input_example = {"question": "Who owns MLflow?", "name": ""}
1700      assert chain.invoke(input_example) == "Databricks"
1701  
1702      with mlflow.start_run():
1703          model_info = mlflow.langchain.log_model(
1704              chain, name="model_path", input_example=input_example
1705          )
1706      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1707      assert loaded_model.invoke(input_example) == "Databricks"
1708      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1709      assert pyfunc_loaded_model.predict(input_example) == ["Databricks"]
1710  
1711      inference_payload = load_serving_example(model_info.model_uri)
1712      response = pyfunc_serve_and_score_model(
1713          model_info.model_uri,
1714          data=inference_payload,
1715          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1716          extra_args=["--env-manager", "local"],
1717      )
1718      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
1719          "predictions": ["Databricks"]
1720      }
1721  
1722  
1723  @pytest.mark.parametrize("env_var", ["MLFLOW_ENABLE_TRACE_IN_SERVING", "ENABLE_MLFLOW_TRACING"])
1724  def test_langchain_model_not_inject_callback_when_disabled(monkeypatch, model_path, env_var):
1725      # Emulate the model serving environment
1726      monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
1727  
1728      # Disable tracing
1729      monkeypatch.setenv(env_var, "false")
1730  
1731      mlflow.langchain.save_model(SIMPLE_MODEL_CODE_PATH, model_path)
1732  
1733      loaded_model = mlflow.pyfunc.load_model(model_path)
1734      loaded_model.predict({"product": "shoe"})
1735  
1736      # No trace should be logged to the inference table since tracing is disabled
1737      from mlflow.tracing.export.inference_table import _TRACE_BUFFER
1738  
1739      assert _TRACE_BUFFER == {}
1740  
1741  
1742  @pytest.mark.parametrize(
1743      "chain_path",
1744      [
1745          os.path.abspath("tests/langchain/sample_code/no_config/chain.py"),
1746          "tests/langchain/../langchain/sample_code/no_config/chain.py",
1747      ],
1748  )
1749  def test_save_model_as_code_correct_streamable(chain_path):
1750      input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
1751      answer = "Databricks"
1752      artifact_path = "model_path"
1753      with mlflow.start_run():
1754          model_info = mlflow.langchain.log_model(
1755              chain_path,
1756              name=artifact_path,
1757              input_example=input_example,
1758          )
1759  
1760      assert model_info.flavors["langchain"]["streamable"] is True
1761      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1762  
1763      with mock.patch("time.time", return_value=1677858242):
1764          assert pyfunc_loaded_model._model_impl._predict_with_callbacks(input_example) == {
1765              "id": None,
1766              "object": "chat.completion",
1767              "created": 1677858242,
1768              "model": "",
1769              "choices": [
1770                  {
1771                      "index": 0,
1772                      "message": {
1773                          "role": "assistant",
1774                          "content": "Databricks",
1775                      },
1776                      "finish_reason": None,
1777                  }
1778              ],
1779              "usage": {
1780                  "prompt_tokens": None,
1781                  "completion_tokens": None,
1782                  "total_tokens": None,
1783              },
1784          }
1785  
1786      inference_payload = load_serving_example(model_info.model_uri)
1787      response = pyfunc_serve_and_score_model(
1788          model_info.model_uri,
1789          data=inference_payload,
1790          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1791          extra_args=["--env-manager", "local"],
1792      )
1793      # Normalize the `created` timestamp to avoid a minor diff in the response
1794      prediction_result = json.loads(response.content.decode("utf-8"))
1795      prediction_result[0]["created"] = 123
1796      expected_prediction = try_transform_response_to_chat_format(answer)
1797      expected_prediction["created"] = 123
1798      assert prediction_result == [expected_prediction]
1799  
1800      pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
1801      reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
1802      assert reloaded_model.resources["databricks"] == {
1803          "serving_endpoint": [{"name": "fake-endpoint"}]
1804      }
1805  
1806  
1807  @skip_if_v1
1808  def test_save_load_langchain_binding(fake_chat_model):
1809      runnable_binding = RunnableBinding(bound=fake_chat_model, kwargs={"stop": ["-"]})
1810      model = runnable_binding | StrOutputParser()
1811      assert model.invoke("Say something") == "Databricks"
1812  
1813      with mlflow.start_run():
1814          model_info = mlflow.langchain.log_model(
1815              model, name="model_path", input_example="Say something"
1816          )
1817      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1818      assert loaded_model.first.kwargs == {"stop": ["-"]}
1819      assert loaded_model.invoke("hello") == "Databricks"
1820      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1821      assert pyfunc_loaded_model.predict("hello") == ["Databricks"]
1822  
1823      inference_payload = load_serving_example(model_info.model_uri)
1824      response = pyfunc_serve_and_score_model(
1825          model_info.model_uri,
1826          data=inference_payload,
1827          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1828          extra_args=["--env-manager", "local"],
1829      )
1830      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
1831          "predictions": ["Databricks"]
1832      }
1833  
1834  
1835  @skip_if_v1
1836  def test_save_load_langchain_binding_llm_with_tool():
1837      from langchain_core.tools import tool
1838  
1839      # We need ChatOpenAI from langchain_openai, as the community one does not support bind_tools
1840      from langchain_openai import ChatOpenAI
1841  
1842      @tool
1843      def add(a: int, b: int) -> int:
1844          """Adds a and b.
1845  
1846          Args:
1847              a: first int
1848              b: second int
1849          """
1850          return a + b
1851  
1852      runnable_binding = ChatOpenAI(temperature=0.9).bind_tools([add])
1853      model = runnable_binding | StrOutputParser()
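    # The fake OpenAI endpoint used by these tests echoes the request messages back,
    # so the expected output below is just the serialized input message list.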
1854      expected_output = '[{"role": "user", "content": "hello"}]'
1855      assert model.invoke("hello") == expected_output
1856  
1857      with mlflow.start_run():
1858          model_info = mlflow.langchain.log_model(model, name="model_path", input_example="hello")
1859  
1860      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1861      assert loaded_model.invoke("hello") == expected_output
1862      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1863      assert pyfunc_loaded_model.predict("hello") == [expected_output]
1864  
1865  
1866  @skip_if_v1
1867  def test_langchain_bindings_save_load_with_config_and_types(fake_chat_model):
1868      class CustomCallbackHandler(BaseCallbackHandler):
1869          def __init__(self):
1870              self.count = 0
1871  
1872          def on_chain_start(
1873              self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
1874          ) -> None:
1875              self.count += 1
1876  
1877          def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
1878              self.count += 1
1879  
1880      model = fake_chat_model | StrOutputParser()
1881      callback = CustomCallbackHandler()
1882      model = model.with_config(run_name="test_run", callbacks=[callback]).with_types(
1883          input_type=str, output_type=str
1884      )
1885      assert model.invoke("Say something") == "Databricks"
1886      assert callback.count == 4
1887  
1888      with mlflow.start_run():
1889          model_info = mlflow.langchain.log_model(model, name="model_path", input_example="hello")
1890      loaded_model = mlflow.langchain.load_model(model_info.model_uri)
1891      assert loaded_model.config["run_name"] == "test_run"
1892      assert loaded_model.custom_input_type == str
1893      assert loaded_model.custom_output_type == str
1894      callback = loaded_model.config["callbacks"][0]
1895      assert loaded_model.invoke("hello") == "Databricks"
1896      assert callback.count > 8  # accumulated count (callbacks also fire during model logging)
1897      pyfunc_loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
1898      assert pyfunc_loaded_model.predict("hello") == ["Databricks"]
1899  
1900      inference_payload = load_serving_example(model_info.model_uri)
1901      response = pyfunc_serve_and_score_model(
1902          model_info.model_uri,
1903          data=inference_payload,
1904          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1905          extra_args=["--env-manager", "local"],
1906      )
1907      assert PredictionsResponse.from_json(response.content.decode("utf-8")) == {
1908          "predictions": ["Databricks"]
1909      }
1910  
1911  
1912  @pytest.mark.parametrize(
1913      "chain_path",
1914      [
1915          os.path.abspath("tests/langchain/sample_code/chain.py"),
1916          "tests/langchain/../langchain/sample_code/chain.py",
1917      ],
1918  )
1919  @pytest.mark.parametrize(
1920      "model_config",
1921      [
1922          os.path.abspath("tests/langchain/sample_code/config.yml"),
1923          "tests/langchain/../langchain/sample_code/config.yml",
1924      ],
1925  )
1926  def test_load_chain_with_model_config_overrides_saved_config(chain_path, model_config):
1927      input_example = {
1928          "messages": [
1929              {
1930                  "role": "user",
1931                  "content": "What is a good name for a company that makes MLflow?",
1932              }
1933          ]
1934      }
1935      artifact_path = "model_path"
1936      with mlflow.start_run():
1937          model_info = mlflow.langchain.log_model(
1938              chain_path,
1939              name=artifact_path,
1940              input_example=input_example,
1941              model_config=model_config,
1942          )
1943  
1944      with mock.patch("mlflow.langchain.model._load_model_code_path") as load_model_code_path_mock:
1945          mlflow.pyfunc.load_model(model_info.model_uri, model_config={"embedding_size": 2})
1946          args, kwargs = load_model_code_path_mock.call_args
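        # Only `embedding_size` is overridden at load time; the remaining keys should
        # still come from the config.yml saved with the model.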
1947          assert args[1] == {
1948              "embedding_size": 2,
1949              "llm_prompt_template": "Answer the following question based on the "
1950              "context: {context}\nQuestion: {question}",
1951              "not_used_array": [
1952                  1,
1953                  2,
1954                  3,
1955              ],
1956              "response": "Databricks",
1957          }
1958  
1959  
1960  @skip_if_v1
1961  @pytest.mark.parametrize("streamable", [True, False, None])
1962  def test_langchain_model_streamable_param_in_log_model(streamable, fake_chat_model):
1963      # TODO: Migrate to models-from-code
1964      prompt = ChatPromptTemplate.from_template("What's your favorite {industry} company?")
1965      chain = prompt | fake_chat_model | StrOutputParser()
1966  
1967      runnable = RunnableParallel({"llm": lambda _: "completion"})
1968  
1969      for model in [chain, runnable]:
1970          with mock.patch("mlflow.langchain.model._save_model"), mlflow.start_run():
1971              model_info = mlflow.langchain.log_model(
1972                  model,
1973                  name="model",
1974                  streamable=streamable,
1975                  pip_requirements=[],
1976              )
1977  
1978              expected = (streamable is None) or streamable
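            # When `streamable` is not passed, a chain that exposes `stream` is expected
            # to default to streamable=True.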
1979              assert model_info.flavors["langchain"]["streamable"] is expected
1980  
1981  
1982  @pytest.fixture
1983  def model_type(request):
1984      return lc_runnables_types()[request.param]
1985  
1986  
1987  @skip_if_v1
1988  @pytest.mark.parametrize("streamable", [True, False, None])
1989  @pytest.mark.parametrize("model_type", range(len(lc_runnables_types())), indirect=True)
1990  def test_langchain_model_streamable_param_in_log_model_for_lc_runnable_types(
1991      streamable, model_type
1992  ):
1993      with mock.patch("mlflow.langchain.model._save_model"), mlflow.start_run():
1994          model = mock.MagicMock(spec=model_type)
1995          assert hasattr(model, "stream") is True
1996          model_info = mlflow.langchain.log_model(
1997              model,
1998              name="model",
1999              streamable=streamable,
2000              pip_requirements=[],
2001          )
2002  
2003          expected = (streamable is None) or streamable
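        # A MagicMock spec'd on an lc runnable type exposes `stream`, so an unspecified
        # `streamable` should default to True here.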
2004          assert model_info.flavors["langchain"]["streamable"] is expected
2005  
2006          del model.stream
2007          assert hasattr(model, "stream") is False
2008          model_info = mlflow.langchain.log_model(
2009              model,
2010              name="model",
2011              streamable=streamable,
2012              pip_requirements=[],
2013          )
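        # Without a `stream` attribute, an unspecified `streamable` should fall back to
        # False, while explicit values are kept as-is.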
2014          assert model_info.flavors["langchain"]["streamable"] is bool(streamable)
2015  
2016  
2017  @skip_if_v1
2018  def test_agent_executor_model_with_messages_input():
2019      question = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
2020  
2021      with mlflow.start_run():
2022          model_info = mlflow.langchain.log_model(
2023              os.path.abspath("tests/langchain/agent_executor/chain.py"),
2024              name="model_path",
2025              input_example=question,
2026              model_config=os.path.abspath("tests/langchain/agent_executor/config.yml"),
2027          )
2028      native_model = mlflow.langchain.load_model(model_info.model_uri)
2029      assert native_model.invoke(question)["output"] == "Databricks"
2030      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2031      # TODO: in the future we should fix this so the output isn't wrapped.
2032      # The result is wrapped in a list because signature enforcement converts the
2033      # input data to a pandas DataFrame, and _convert_llm_input_data then converts the
2034      # DataFrame back to records, so a single row ends up wrapped in a list.
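    # A minimal sketch of that round trip (illustrative only, not the exact code path):
    # pd.DataFrame([question]).to_dict(orient="records") yields a one-element list of
    # dicts, which is why the single prediction surfaces as ["Databricks"] below.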
2036      assert pyfunc_model.predict(question) == ["Databricks"]
2037  
2038      # Test stream output
2039      response = pyfunc_model.predict_stream(question)
2040      assert inspect.isgenerator(response)
2041  
2042      expected_response = [
2043          {
2044              "output": "Databricks",
2045              "messages": [
2046                  {
2047                      "additional_kwargs": {},
2048                      "content": "Databricks",
2049                      "example": False,
2050                      "id": None,
2051                      "invalid_tool_calls": [],
2052                      "name": None,
2053                      "response_metadata": {},
2054                      "tool_calls": [],
2055                      "type": "ai",
2056                      "usage_metadata": None,
2057                  }
2058              ],
2059          }
2060      ]
2061      assert list(response) == expected_response
2062  
2063  
2064  def test_invoking_model_with_params():
2065      with mlflow.start_run():
2066          model_info = mlflow.langchain.log_model(
2067              os.path.abspath("tests/langchain/sample_code/model_with_config.py"),
2068              name="model",
2069          )
2070      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2071      data = {"x": 0}
2072      pyfunc_model.predict(data)
2073      params = {"config": {"temperature": 3.0}}
2074      with mock.patch("mlflow.pyfunc._validate_prediction_input", return_value=(data, params)):
2075          # This proves the temperature is passed to the model
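        # (The sample model_with_config.py is assumed to validate `temperature` and
        # reject values above 2, which produces the validation error matched below.)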
2076          with pytest.raises(MlflowException, match=r"Input should be less than or equal to 2"):
2077              pyfunc_model.predict(data=data, params=params)
2078  
2079  
2080  def test_custom_resources(tmp_path):
2081      input_example = {
2082          "messages": [
2083              {
2084                  "role": "user",
2085                  "content": "What is a good name for a company that makes MLflow?",
2086              }
2087          ]
2088      }
2089      expected_resources = {
2090          "api_version": "1",
2091          "databricks": {
2092              "serving_endpoint": [
2093                  {"name": "databricks-mixtral-8x7b-instruct"},
2094                  {"name": "databricks-bge-large-en"},
2095                  {"name": "azure-eastus-model-serving-2_vs_endpoint"},
2096              ],
2097              "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
2098              "sql_warehouse": [{"name": "testid"}],
2099              "function": [
2100                  {"name": "rag.studio.test_function_a"},
2101                  {"name": "rag.studio.test_function_b"},
2102              ],
2103          },
2104      }
2105      artifact_path = "model_path"
2106      chain_path = "tests/langchain/sample_code/chain.py"
2107      with mlflow.start_run():
2108          model_info = mlflow.langchain.log_model(
2109              chain_path,
2110              name=artifact_path,
2111              input_example=input_example,
2112              model_config="tests/langchain/sample_code/config.yml",
2113              resources=[
2114                  DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
2115                  DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
2116                  DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
2117                  DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
2118                  DatabricksSQLWarehouse(warehouse_id="testid"),
2119                  DatabricksFunction(function_name="rag.studio.test_function_a"),
2120                  DatabricksFunction(function_name="rag.studio.test_function_b"),
2121              ],
2122          )
2123  
2124          model_path = _download_artifact_from_uri(model_info.model_uri)
2125          reloaded_model = Model.load(os.path.join(model_path, "MLmodel"))
2126          assert reloaded_model.resources == expected_resources
2127  
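    # The same resources can alternatively be supplied as a YAML file; the resulting
    # MLmodel resources should be identical.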
2128      yaml_file = tmp_path.joinpath("resources.yaml")
2129      with open(yaml_file, "w") as f:
2130          f.write(
2131              """
2132              api_version: "1"
2133              databricks:
2134                  vector_search_index:
2135                  - name: rag.studio_bugbash.databricks_docs_index
2136                  serving_endpoint:
2137                  - name: databricks-mixtral-8x7b-instruct
2138                  - name: databricks-bge-large-en
2139                  - name: azure-eastus-model-serving-2_vs_endpoint
2140                  sql_warehouse:
2141                  - name: testid
2142                  function:
2143                  - name: rag.studio.test_function_a
2144                  - name: rag.studio.test_function_b
2145              """
2146          )
2147  
2148      artifact_path_2 = "model_path_2"
2149      with mlflow.start_run():
2150          model_info = mlflow.langchain.log_model(
2151              chain_path,
2152              name=artifact_path_2,
2153              input_example=input_example,
2154              model_config="tests/langchain/sample_code/config.yml",
2155              resources=yaml_file,
2156          )
2157  
2158          model_path = _download_artifact_from_uri(model_info.model_uri)
2159          reloaded_model = Model.load(os.path.join(model_path, "MLmodel"))
2160          assert reloaded_model.resources == expected_resources
2161  
2162  
2163  def test_pyfunc_converts_chat_request_for_non_chat_model():
2164      input_example = {"messages": [{"role": "user", "content": "Hello"}]}
2165      with mlflow.start_run():
2166          model_info = mlflow.langchain.log_model(
2167              lc_model=SIMPLE_MODEL_CODE_PATH,
2168              input_example=input_example,
2169          )
2170  
2171      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2172      result = pyfunc_model.predict(input_example)
2173      # Outputs are converted to the chat response format
2174      assert isinstance(result[0]["choices"][0]["message"]["content"], str)
2175  
2176      response = pyfunc_model.predict_stream(input_example)
2177      assert inspect.isgenerator(response)
2178      assert isinstance(list(response)[0]["choices"][0]["delta"]["content"], str)
2179  
2180  
2181  @skip_if_v1
2182  def test_pyfunc_should_not_convert_chat_request_if_env_var_is_set_to_false(monkeypatch):
2183      monkeypatch.setenv(MLFLOW_CONVERT_MESSAGES_DICT_FOR_LANGCHAIN.name, "false")
2184  
2185      # This model is an example where the model expects chat-request-format input,
2186      # but the input should not be converted to List[BaseMessage]
2187      model = RunnablePassthrough.assign(problem=lambda x: x["messages"][-1]["content"]) | itemgetter(
2188          "problem"
2189      )
2190      input_example = {"messages": [{"role": "user", "content": "Databricks"}]}
2191      assert model.invoke(input_example) == "Databricks"
2192  
2193      # The pyfunc model can accept the chat request format even if the chain
2194      # itself does not accept it, but we need to use the correct
2195      # input example to infer the model signature
2196      with mlflow.start_run():
2197          model_info = mlflow.langchain.log_model(model, input_example=input_example)
2198  
2199      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2200      result = pyfunc_model.predict(input_example)
2201      assert result == ["Databricks"]
2202  
2203      # Test stream output
2204      response = pyfunc_model.predict_stream(input_example)
2205      assert inspect.isgenerator(response)
2206      assert list(response) == ["Databricks"], list(response)
2207  
2208  
2209  def test_log_langchain_model_with_prompt():
2210      mlflow.register_prompt(
2211          name="qa_prompt",
2212          template="What is a good name for a company that makes {{product}}?",
2213          commit_message="Prompt for generating company names",
2214      )
2215      mlflow.set_prompt_alias("qa_prompt", alias="production", version=1)
2216  
2217      mlflow.register_prompt(name="another_prompt", template="Hi")
2218  
2219      # If the model code involves an `mlflow.load_prompt()` call, the prompt version
2220      # should be automatically logged to the Run
2221      with mlflow.start_run():
2222          model_info = mlflow.langchain.log_model(
2223              os.path.abspath("tests/langchain/sample_code/chain_with_mlflow_prompt.py"),
2224              name="model",
2225              # Manually associate another prompt
2226              prompts=["prompts:/another_prompt/1"],
2227          )
2228  
2229      # Check that prompts were linked to the run via the linkedPrompts tag
2230      from mlflow.tracing.constant import TraceTagKey
2231  
2232      run = mlflow.MlflowClient().get_run(model_info.run_id)
2233      linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
2234      assert linked_prompts_tag is not None
2235  
2236      linked_prompts = json.loads(linked_prompts_tag)
2237      assert len(linked_prompts) == 2
2238      assert {p["name"] for p in linked_prompts} == {"qa_prompt", "another_prompt"}
2239  
2240      prompt = mlflow.load_prompt("qa_prompt", 1)
2241      assert prompt.aliases == ["production"]
2242  
2243      prompt = mlflow.load_prompt("another_prompt", 1)
2244  
2245      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2246      response = pyfunc_model.predict({"product": "shoe"})
2247      # The fake OpenAI server echoes the input
2248      assert (
2249          response
2250          == '[{"role": "user", "content": "What is a good name for a company that makes shoe?"}]'
2251      )
2252  
2253  
2254  def test_predict_with_callbacks_with_tracing(monkeypatch):
2255      # Simulate the model serving environment
2256      monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
2257      monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true")
2258      mlflow.tracing.reset()
2259  
2260      model_info = mlflow.langchain.log_model(
2261          os.path.abspath("tests/langchain/sample_code/workflow.py"),
2262          name="model_path",
2263          input_example={"messages": [{"role": "user", "content": "What is MLflow?"}]},
2264      )
2265      # The serving environment only reads the experiment ID from this environment variable
2266      monkeypatch.setenv("MLFLOW_EXPERIMENT_ID", mlflow.last_logged_model().experiment_id)
2267  
2268      pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
2269  
2270      request_id = "mock_request_id"
2271      tracer = MlflowLangchainTracer(prediction_context=Context(request_id))
2272      input_example = {"messages": [{"role": "user", "content": TEST_CONTENT}]}
2273  
2274      with mock.patch("mlflow.tracing.client.TracingClient.start_trace") as mock_start_trace:
2275          pyfunc_model._model_impl._predict_with_callbacks(
2276              data=input_example, callback_handlers=[tracer]
2277          )
2278          mlflow.flush_trace_async_logging()
2279          mock_start_trace.assert_called_once()
2280          trace_info = mock_start_trace.call_args[0][0]
2281          assert trace_info.client_request_id == request_id
2282          assert trace_info.request_metadata[TraceMetadataKey.MODEL_ID] == model_info.model_id
2283  
2284  
2285  @pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="The test is only for langchain>=1 versions")
2286  def test_langchain_v1_save_model_as_pickle_error():
2287      model = create_openai_runnable()
2288      with mlflow.start_run():
2289          with pytest.raises(
2290              MlflowException,
2291              match="LangChain v1 onward only supports models-from-code",
2292          ):
2293              mlflow.langchain.log_model(
2294                  model, name="langchain_model", input_example={"product": "MLflow"}
2295              )