/ tests / tracing / display / test_ipython.py
test_ipython.py
  1  import json
  2  from collections import defaultdict
  3  from unittest.mock import Mock
  4  
  5  import pytest
  6  
  7  import mlflow
  8  from mlflow.tracing.display import (
  9      IPythonTraceDisplayHandler,
 10      get_display_handler,
 11      get_notebook_iframe_html,
 12  )
 13  
 14  from tests.tracing.helper import create_trace, skip_module_when_testing_trace_sdk
 15  
# Module-level guard imported from tests.tracing.helper; it runs when this
# module is imported, before any test below is executed.
# NOTE(review): judging by its name, it skips the whole module when the test
# run targets the tracing SDK only -- confirm against the helper itself.
skip_module_when_testing_trace_sdk()
 17  
 18  
 19  class MockEventRegistry:
 20      def __init__(self):
 21          self.events = defaultdict(list)
 22  
 23      def register(self, event, callback):
 24          self.events[event].append(callback)
 25  
 26      def trigger(self, event):
 27          for callback in self.events[event]:
 28              callback(None)
 29  
 30  
 31  class MockIPython:
 32      def __init__(self):
 33          self.events = MockEventRegistry()
 34  
 35      def mock_run_cell(self):
 36          self.events.trigger("post_run_cell")
 37  
 38  
 39  @pytest.fixture
 40  def _in_databricks(monkeypatch):
 41      monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "15.x")
 42  
 43  
 44  @pytest.fixture(autouse=True)
 45  def reset_singleton():
 46      IPythonTraceDisplayHandler._instance = None
 47      IPythonTraceDisplayHandler.disabled = False
 48  
 49  
# Reusable marker: decorating a test with `@in_databricks` makes it request the
# `_in_databricks` fixture. The fixture name is read from `__name__` instead of
# being repeated as a string literal.
in_databricks = pytest.mark.usefixtures(_in_databricks.__name__)
 51  
 52  
 53  @in_databricks
 54  def test_display_is_not_called_without_ipython(monkeypatch):
 55      # in an IPython environment, the interactive shell will
 56      # be returned. however, for test purposes, just mock that
 57      # the value is not None.
 58      mock_display = Mock()
 59      monkeypatch.setattr("IPython.display.display", mock_display)
 60      handler = get_display_handler()
 61  
 62      handler.display_traces([create_trace("a")])
 63      assert mock_display.call_count == 0
 64  
 65      mock_ipython = MockIPython()
 66      monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
 67  
 68      # reset the singleton so the handler
 69      # can register the post-display hook
 70      IPythonTraceDisplayHandler._instance = None
 71      handler = get_display_handler()
 72      handler.display_traces([create_trace("b")])
 73  
 74      # simulate cell execution
 75      mock_ipython.mock_run_cell()
 76  
 77      assert mock_display.call_count == 1
 78  
 79  
 80  @in_databricks
 81  def test_ipython_client_clears_display_after_execution(monkeypatch):
 82      mock_ipython = MockIPython()
 83      monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
 84      handler = get_display_handler()
 85  
 86      mock_display_handle = Mock()
 87      mock_display = Mock(return_value=mock_display_handle)
 88      monkeypatch.setattr("IPython.display.display", mock_display)
 89      handler.display_traces([create_trace("a")])
 90      handler.display_traces([create_trace("b")])
 91      handler.display_traces([create_trace("c")])
 92  
 93      mock_ipython.mock_run_cell()
 94      # despite many calls to `display_traces`,
 95      # there should only be one call to `display`
 96      assert mock_display.call_count == 1
 97  
 98      mock_ipython.mock_run_cell()
 99      # expect that display is not called,
100      # since no traces should be present
101      assert mock_display.call_count == 1
102  
103  
@in_databricks
def test_display_is_called_in_correct_functions(monkeypatch):
    """Running a function decorated with ``mlflow.trace`` leads to one ``display`` call."""
    fake_shell = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: fake_shell)
    display_spy = Mock(return_value=Mock())
    monkeypatch.setattr("IPython.display.display", display_spy)

    @mlflow.trace
    def foo():
        return 3

    # create a trace, flush async logging, then simulate
    # the end of the cell: the trace must be displayed once
    foo()
    mlflow.flush_trace_async_logging()
    fake_shell.mock_run_cell()
    display_spy.assert_called_once()
121  
122  
@in_databricks
def test_display_deduplicates_traces(monkeypatch):
    """A trace queued more than once within a cell is displayed only once."""
    fake_shell = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: fake_shell)
    handler = get_display_handler()

    display_spy = Mock()
    monkeypatch.setattr("IPython.display.display", display_spy)

    trace_a, trace_b, trace_c = (create_trace(name) for name in ("a", "b", "c"))
    unique_traces = [trace_a, trace_b, trace_c]

    # Queue each trace on its own and then all of them together: six entries
    # in total, of which only the three distinct traces should be displayed.
    for batch in ([trace_a], [trace_b], [trace_c], list(unique_traces)):
        handler.display_traces(batch)
    fake_shell.mock_run_cell()

    display_spy.assert_called_once()
    displayed_bundle = display_spy.call_args.args[0]
    assert displayed_bundle == {
        "application/databricks.mlflow.trace": json.dumps([
            json.loads(trace._serialize_for_mimebundle()) for trace in unique_traces
        ]),
        "text/plain": repr(unique_traces),
    }
152  
153  
@in_databricks
def test_display_respects_max_limit(monkeypatch):
    """With the display limit set to 1, only the first queued trace is displayed."""
    fake_shell = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: fake_shell)
    handler = get_display_handler()

    display_spy = Mock()
    monkeypatch.setattr("IPython.display.display", display_spy)

    monkeypatch.setenv("MLFLOW_MAX_TRACES_TO_DISPLAY_IN_NOTEBOOK", "1")

    first, second, third = (create_trace(name) for name in ("a", "b", "c"))
    handler.display_traces([first, second, third])
    fake_shell.mock_run_cell()

    display_spy.assert_called_once()
    displayed_bundle = display_spy.call_args.args[0]
    # the bundle is expected to describe the first trace alone, not a list
    assert displayed_bundle == {
        "application/databricks.mlflow.trace": first._serialize_for_mimebundle(),
        "text/plain": repr(first),
    }
176  
177  
@in_databricks
def test_enable_and_disable_display(monkeypatch):
    """Disabling notebook display suppresses ``display``; re-enabling restores it."""
    fake_shell = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: fake_shell)
    display_spy = Mock(return_value=Mock())
    monkeypatch.setattr("IPython.display.display", display_spy)
    trace_a = create_trace("a")

    def queue_trace_and_run_cell():
        # fetch the handler anew each time, then simulate one cell execution
        get_display_handler().display_traces([trace_a])
        fake_shell.mock_run_cell()

    # while the display is disabled, display() must not be called
    mlflow.tracing.disable_notebook_display()
    queue_trace_and_run_cell()
    display_spy.assert_not_called()

    # once it is re-enabled, the trace must be displayed again
    mlflow.tracing.enable_notebook_display()
    queue_trace_and_run_cell()

    display_spy.assert_called_once()
    displayed_bundle = display_spy.call_args.args[0]
    assert displayed_bundle == {
        "application/databricks.mlflow.trace": trace_a._serialize_for_mimebundle(),
        "text/plain": repr(trace_a),
    }
206  
207  
@in_databricks
def test_mimebundle_in_databricks():
    """The trace mimebundle carries the Databricks trace key only while display is enabled."""
    trace = create_trace("a")

    def rich_bundle():
        # bundle holding both the trace UI payload and the plain-text repr
        return {
            "application/databricks.mlflow.trace": trace._serialize_for_mimebundle(),
            "text/plain": repr(trace),
        }

    # by default, the metadata needed for rendering the trace UI is present
    assert trace._repr_mimebundle_() == rich_bundle()

    # with trace display disabled, only "text/plain" should remain
    mlflow.tracing.disable_notebook_display()
    assert trace._repr_mimebundle_() == {"text/plain": repr(trace)}

    # re-enabling should bring the metadata back
    mlflow.tracing.enable_notebook_display()
    assert trace._repr_mimebundle_() == rich_bundle()
230  
231  
def test_mimebundle_in_oss():
    """Outside Databricks, the bundle gains ``text/html`` only with an HTTP tracking URI."""
    trace = create_trace("a")

    def text_only_bundle():
        return {"text/plain": repr(trace)}

    # if the user is not using a tracking server, only text/plain is expected
    assert trace._repr_mimebundle_() == text_only_bundle()

    # if the user is using a tracking server, the bundle
    # should also contain an iframe under the text/html key
    mlflow.set_tracking_uri("http://localhost:5000")
    assert trace._repr_mimebundle_() == {
        "text/plain": repr(trace),
        "text/html": get_notebook_iframe_html([trace]),
    }

    # disabling should remove this key, even if a tracking server is used
    mlflow.tracing.disable_notebook_display()
    assert trace._repr_mimebundle_() == text_only_bundle()
252  
253  
def test_notebook_trace_renderer_base_url_override(monkeypatch):
    """``MLFLOW_NOTEBOOK_TRACE_RENDERER_BASE_URL`` replaces the tracking URI in the iframe URL."""
    trace = create_trace("a")
    mlflow.set_tracking_uri("http://mlflow:5000")
    monkeypatch.setenv("MLFLOW_NOTEBOOK_TRACE_RENDERER_BASE_URL", "http://localhost:5000")

    iframe_html = get_notebook_iframe_html([trace])

    renderer_path = "/static-files/lib/notebook-trace-renderer/index.html"
    # the overridden base URL must be used and the tracking URI must not
    assert f"http://localhost:5000{renderer_path}" in iframe_html
    assert f"http://mlflow:5000{renderer_path}" not in iframe_html
262  
263  
def test_notebook_iframe_includes_workspace_query_param(monkeypatch):
    """The iframe HTML mentions ``workspace=`` only when ``MLFLOW_WORKSPACE`` is set."""
    trace = create_trace("a")
    mlflow.set_tracking_uri("http://localhost:5000")

    # Without a workspace set, no workspace query parameter is expected
    html_without_workspace = get_notebook_iframe_html([trace])
    assert "workspace=" not in html_without_workspace

    # With a workspace set, the query string should name it
    monkeypatch.setenv("MLFLOW_WORKSPACE", "my-workspace")
    html_with_workspace = get_notebook_iframe_html([trace])
    assert "workspace=my-workspace" in html_with_workspace
276  
277  
def test_display_in_oss(monkeypatch):
    """Outside Databricks, traces are displayed as an iframe only with an HTTP tracking URI."""
    fake_shell = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: fake_shell)
    display_spy = Mock(return_value=Mock())
    monkeypatch.setattr("IPython.display.display", display_spy)
    # make `HTML(...)` hand back its argument unchanged
    monkeypatch.setattr("IPython.display.HTML", Mock(side_effect=lambda html: html))

    def queue_trace_and_run_cell():
        # fetch the handler anew each time, then simulate one cell execution
        get_display_handler().display_traces([create_trace("a")])
        fake_shell.mock_run_cell()

    queue_trace_and_run_cell()

    # default tracking uri is sqlite, so no display call should be made
    display_spy.assert_not_called()

    # after setting an HTTP tracking URI, it should work
    mlflow.set_tracking_uri("http://localhost:5000")

    queue_trace_and_run_cell()

    display_spy.assert_called_once()
    displayed_bundle = display_spy.call_args.args[0]
    assert "<iframe" in displayed_bundle["text/html"]