test_ipython.py
import json
from collections import defaultdict
from unittest.mock import Mock

import pytest

import mlflow
from mlflow.tracing.display import (
    IPythonTraceDisplayHandler,
    get_display_handler,
    get_notebook_iframe_html,
)

from tests.tracing.helper import create_trace, skip_module_when_testing_trace_sdk

skip_module_when_testing_trace_sdk()


class MockEventRegistry:
    """Minimal stand-in for an IPython shell's ``events`` registry.

    Callbacks are stored per event name and invoked in registration order
    by :meth:`trigger`.
    """

    def __init__(self):
        self.events = defaultdict(list)

    def register(self, event, callback):
        self.events[event].append(callback)

    def trigger(self, event):
        # Each callback receives a single ``None`` argument as a placeholder
        # for whatever payload a real shell would pass.
        for callback in self.events[event]:
            callback(None)


class MockIPython:
    """Fake interactive shell, returned by a monkeypatched ``IPython.get_ipython``."""

    def __init__(self):
        self.events = MockEventRegistry()

    def mock_run_cell(self):
        """Simulate the end of a cell execution by firing ``post_run_cell``."""
        self.events.trigger("post_run_cell")


@pytest.fixture
def _in_databricks(monkeypatch):
    """Set ``DATABRICKS_RUNTIME_VERSION`` so the test runs as if on Databricks."""
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "15.x")


@pytest.fixture(autouse=True)
def reset_singleton():
    """Give every test a fresh display handler and leave none behind.

    The handler singleton and the class-level ``disabled`` flag are reset both
    before and after each test. This keeps a handler bound to a ``MockIPython``,
    or a display a test disabled, from leaking into later tests or modules.
    """
    IPythonTraceDisplayHandler._instance = None
    IPythonTraceDisplayHandler.disabled = False
    yield
    IPythonTraceDisplayHandler._instance = None
    IPythonTraceDisplayHandler.disabled = False


# Marker applying the ``_in_databricks`` fixture to a test.
in_databricks = pytest.mark.usefixtures(_in_databricks.__name__)


@in_databricks
def test_display_is_not_called_without_ipython(monkeypatch):
    mock_display = Mock()
    monkeypatch.setattr("IPython.display.display", mock_display)

    # Outside IPython, `get_ipython()` returns None. Pin that explicitly so the
    # test does not depend on the environment pytest happens to run in.
    monkeypatch.setattr("IPython.get_ipython", lambda: None)
    handler = get_display_handler()

    handler.display_traces([create_trace("a")])
    assert mock_display.call_count == 0

    # In an IPython environment, the interactive shell will be returned.
    # For test purposes, a MockIPython stands in for that (non-None) shell.
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)

    # reset the singleton so the handler
    # can register the post-display hook
    IPythonTraceDisplayHandler._instance = None
    handler = get_display_handler()
    handler.display_traces([create_trace("b")])

    # simulate cell execution
    mock_ipython.mock_run_cell()

    assert mock_display.call_count == 1


@in_databricks
def test_ipython_client_clears_display_after_execution(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    handler = get_display_handler()

    mock_display_handle = Mock()
    mock_display = Mock(return_value=mock_display_handle)
    monkeypatch.setattr("IPython.display.display", mock_display)
    handler.display_traces([create_trace("a")])
    handler.display_traces([create_trace("b")])
    handler.display_traces([create_trace("c")])

    mock_ipython.mock_run_cell()
    # despite many calls to `display_traces`,
    # there should only be one call to `display`
    assert mock_display.call_count == 1

    mock_ipython.mock_run_cell()
    # expect that display is not called,
    # since no traces should be present
    assert mock_display.call_count == 1


@in_databricks
def test_display_is_called_in_correct_functions(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    mock_display_handle = Mock()
    mock_display = Mock(return_value=mock_display_handle)
    monkeypatch.setattr("IPython.display.display", mock_display)

    @mlflow.trace
    def foo():
        return 3

    # display should be called after trace creation
    foo()
    mlflow.flush_trace_async_logging()
    mock_ipython.mock_run_cell()
    assert mock_display.call_count == 1


@in_databricks
def test_display_deduplicates_traces(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    handler = get_display_handler()

    mock_display = Mock()
    monkeypatch.setattr("IPython.display.display", mock_display)

    trace_a = create_trace("a")
    trace_b = create_trace("b")
    trace_c = create_trace("c")

    # The display client should dedupe traces to display and only display 3 (not 6).
    handler.display_traces([trace_a])
    handler.display_traces([trace_b])
    handler.display_traces([trace_c])
    handler.display_traces([trace_a, trace_b, trace_c])
    mock_ipython.mock_run_cell()

    expected = [trace_a, trace_b, trace_c]

    assert mock_display.call_count == 1
    assert mock_display.call_args[0][0] == {
        "application/databricks.mlflow.trace": json.dumps([
            json.loads(t._serialize_for_mimebundle()) for t in expected
        ]),
        "text/plain": repr(expected),
    }


@in_databricks
def test_display_respects_max_limit(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    handler = get_display_handler()

    mock_display = Mock()
    monkeypatch.setattr("IPython.display.display", mock_display)

    monkeypatch.setenv("MLFLOW_MAX_TRACES_TO_DISPLAY_IN_NOTEBOOK", "1")

    trace_a = create_trace("a")
    trace_b = create_trace("b")
    trace_c = create_trace("c")
    handler.display_traces([trace_a, trace_b, trace_c])
    mock_ipython.mock_run_cell()

    assert mock_display.call_count == 1
    assert mock_display.call_args[0][0] == {
        "application/databricks.mlflow.trace": trace_a._serialize_for_mimebundle(),
        "text/plain": repr(trace_a),
    }


@in_databricks
def test_enable_and_disable_display(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    mock_display_handle = Mock()
    mock_display = Mock(return_value=mock_display_handle)
    monkeypatch.setattr("IPython.display.display", mock_display)
    trace_a = create_trace("a")

    # test that disabling the display handler prevents display() from being called
    mlflow.tracing.disable_notebook_display()
    handler = get_display_handler()
    handler.display_traces([trace_a])
    mock_ipython.mock_run_cell()

    mock_display.assert_not_called()

    # test that re-enabling it will make things display again
    mlflow.tracing.enable_notebook_display()
    handler = get_display_handler()
    handler.display_traces([trace_a])
    mock_ipython.mock_run_cell()

    assert mock_display.call_count == 1
    assert mock_display.call_args[0][0] == {
        "application/databricks.mlflow.trace": trace_a._serialize_for_mimebundle(),
        "text/plain": repr(trace_a),
    }


@in_databricks
def test_mimebundle_in_databricks():
    # by default, it should contain the metadata
    # necessary for rendering the trace UI
    trace = create_trace("a")
    assert trace._repr_mimebundle_() == {
        "application/databricks.mlflow.trace": trace._serialize_for_mimebundle(),
        "text/plain": repr(trace),
    }

    # if trace display is disabled, only "text/plain" should exist
    mlflow.tracing.disable_notebook_display()
    assert trace._repr_mimebundle_() == {
        "text/plain": repr(trace),
    }

    # re-enabling should bring the metadata back
    mlflow.tracing.enable_notebook_display()
    assert trace._repr_mimebundle_() == {
        "application/databricks.mlflow.trace": trace._serialize_for_mimebundle(),
        "text/plain": repr(trace),
    }


def test_mimebundle_in_oss():
    # if the user is not using a tracking server, it should only contain text/plain
    trace = create_trace("a")
    assert trace._repr_mimebundle_() == {
        "text/plain": repr(trace),
    }

    # if the user is using a tracking server, it
    # should contain an iframe in the text/html key
    mlflow.set_tracking_uri("http://localhost:5000")
    assert trace._repr_mimebundle_() == {
        "text/plain": repr(trace),
        "text/html": get_notebook_iframe_html([trace]),
    }

    # disabling should remove this key, even if tracking server is used
    mlflow.tracing.disable_notebook_display()
    assert trace._repr_mimebundle_() == {
        "text/plain": repr(trace),
    }


def test_notebook_trace_renderer_base_url_override(monkeypatch):
    trace = create_trace("a")
    mlflow.set_tracking_uri("http://mlflow:5000")
    monkeypatch.setenv("MLFLOW_NOTEBOOK_TRACE_RENDERER_BASE_URL", "http://localhost:5000")

    html = get_notebook_iframe_html([trace])
    assert "http://localhost:5000/static-files/lib/notebook-trace-renderer/index.html" in html
    assert "http://mlflow:5000/static-files/lib/notebook-trace-renderer/index.html" not in html


def test_notebook_iframe_includes_workspace_query_param(monkeypatch):
    trace = create_trace("a")
    mlflow.set_tracking_uri("http://localhost:5000")

    # Make sure no workspace leaks in from the surrounding environment.
    monkeypatch.delenv("MLFLOW_WORKSPACE", raising=False)

    # Without workspace set, the query string should not contain workspace
    html = get_notebook_iframe_html([trace])
    assert "workspace=" not in html

    # With workspace set, the query string should contain workspace
    monkeypatch.setenv("MLFLOW_WORKSPACE", "my-workspace")
    html = get_notebook_iframe_html([trace])
    assert "workspace=my-workspace" in html


def test_display_in_oss(monkeypatch):
    mock_ipython = MockIPython()
    monkeypatch.setattr("IPython.get_ipython", lambda: mock_ipython)
    mock_display_handle = Mock()
    mock_display = Mock(return_value=mock_display_handle)
    monkeypatch.setattr("IPython.display.display", mock_display)
    monkeypatch.setattr("IPython.display.HTML", Mock(side_effect=lambda html: html))

    handler = get_display_handler()
    handler.display_traces([create_trace("a")])

    mock_ipython.mock_run_cell()

    # default tracking uri is sqlite, so no display call should be made
    assert mock_display.call_count == 0

    # after setting an HTTP tracking URI, it should work
    mlflow.set_tracking_uri("http://localhost:5000")

    handler = get_display_handler()
    handler.display_traces([create_trace("a")])

    mock_ipython.mock_run_cell()

    assert mock_display.call_count == 1
    assert "<iframe" in mock_display.call_args[0][0]["text/html"]