/ tests / genai / utils / test_display_utils.py
test_display_utils.py
 1  from unittest import mock
 2  
 3  import mlflow
 4  from mlflow.genai.utils import display_utils
 5  from mlflow.store.tracking.rest_store import RestStore
 6  from mlflow.tracking.client import MlflowClient
 7  from mlflow.utils.mlflow_tags import MLFLOW_DATABRICKS_WORKSPACE_URL
 8  
 9  
def test_display_outputs_jupyter(monkeypatch):
    """In a Jupyter environment the evaluation link is rendered through ``IPython.display``.

    The mocked tracking host carries a trailing slash, so asserting on a URL with a
    single ``/`` before ``#`` also guards against a doubled slash in the link.
    """
    mock_store = mock.MagicMock(spec=RestStore)
    # Delegate run lookups to a real client so the run started below can be resolved.
    mock_store.get_run = MlflowClient().get_run
    mock_store.get_host_creds = lambda: mock.MagicMock(host="https://mlflow.example.com/")

    with (
        mock.patch("IPython.display.display") as mock_display,
        mock.patch.object(display_utils, "_get_store", return_value=mock_store),
        mock.patch.object(display_utils, "_is_jupyter", return_value=True),
        mlflow.start_run() as run,
    ):
        display_utils.display_evaluation_output(run.info.run_id)

    # Fail with a clear assertion if nothing was displayed, instead of a TypeError
    # from subscripting a ``None`` call_args below.
    mock_display.assert_called()

    exp_id = run.info.experiment_id
    expected_url = (
        f"https://mlflow.example.com/#/experiments/{exp_id}"
        f"/evaluation-runs?selectedRunUuid={run.info.run_id}"
    )
    html_content = mock_display.call_args.args[0].data
    assert expected_url in html_content
27  
28  
def test_display_outputs_non_ipython(capsys):
    """Outside Jupyter the evaluation link is written to stdout instead of rendered as HTML."""
    mock_store = mock.MagicMock(spec=RestStore)
    # Use the directly imported MlflowClient, matching the other tests in this module.
    mock_store.get_run = MlflowClient().get_run
    mock_store.get_host_creds = lambda: mock.MagicMock(host="https://mlflow.example.com/")

    with (
        mock.patch.object(display_utils, "_get_store", return_value=mock_store),
        mock.patch.object(display_utils, "_is_jupyter", return_value=False),
        mlflow.start_run() as run,
    ):
        display_utils.display_evaluation_output(run.info.run_id)

    captured = capsys.readouterr().out
    exp_id = run.info.experiment_id
    expected_url = (
        f"https://mlflow.example.com/#/experiments/{exp_id}"
        f"/evaluation-runs?selectedRunUuid={run.info.run_id}"
    )
    assert expected_url in captured
45  
46  
def test_display_outputs_databricks(monkeypatch):
    """With a Databricks tracking URI the link uses the workspace URL and the ``/ml/`` path.

    The workspace URL is recorded on the run via the ``MLFLOW_DATABRICKS_WORKSPACE_URL``
    tag before the display helper is invoked.
    """
    host = "https://workspace.databricks.com"
    # Use the directly imported MlflowClient, matching the other tests in this module.
    client = MlflowClient()

    mock_store = mock.MagicMock(spec=RestStore)
    mock_store.get_run = client.get_run
    mock_store.get_host_creds = lambda: mock.MagicMock(host=host)

    with mlflow.start_run() as run:
        client.set_tag(run.info.run_id, MLFLOW_DATABRICKS_WORKSPACE_URL, host)

        with (
            mock.patch("IPython.display.display") as mock_display,
            mock.patch.object(display_utils, "_get_store", return_value=mock_store),
            mock.patch.object(display_utils, "_is_jupyter", return_value=True),
            mock.patch.object(display_utils, "is_databricks_uri", return_value=True),
        ):
            display_utils.display_evaluation_output(run.info.run_id)

    # Fail with a clear assertion if nothing was displayed, instead of a TypeError
    # from subscripting a ``None`` call_args below.
    mock_display.assert_called()

    exp_id = run.info.experiment_id
    expected_url = (
        f"{host}/ml/experiments/{exp_id}/evaluation-runs?selectedRunUuid={run.info.run_id}"
    )
    html_content = mock_display.call_args.args[0].data
    assert expected_url in html_content
72  
73  
def test_display_summary_with_local_store(capsys):
    """Against the default (non-mocked) store, the printed summary names the run and "Traces"."""
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id
        display_utils.display_evaluation_output(run_id)

    stdout = capsys.readouterr().out
    for expected_fragment in (run_id, "Traces"):
        assert expected_fragment in stdout