test_jupyter_notebook_context.py
import json
from unittest import mock

import pytest

from mlflow.entities import SourceType
from mlflow.tracking.context.jupyter_notebook_context import (
    JupyterNotebookRunContext,
    _get_kernel_id,
    _get_notebook_name,
    _get_notebook_path_from_sessions,
    _get_running_servers,
    _get_sessions_notebook,
    _get_vscode_notebook_path,
    _is_in_jupyter_notebook,
)
from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE

MOCK_NOTEBOOK_NAME = "test_notebook.ipynb"
MOCK_NOTEBOOK_PATH = f"/path/to/{MOCK_NOTEBOOK_NAME}"
MOCK_KERNEL_ID = "abc123-def456"


@pytest.fixture(autouse=True)
def clear_notebook_name_cache():
    """Reset the memoized ``_get_notebook_name`` result around every test.

    ``_get_notebook_name`` exposes ``cache_clear``, so results are cached. Without this
    fixture, a value computed under one test's mocks would leak into subsequent tests,
    including tests in other modules that run in the same process.
    """
    _get_notebook_name.cache_clear()
    yield
    _get_notebook_name.cache_clear()


def _mock_sessions_response(sessions):
    """Build a mock that stands in for the object returned by ``urlopen``.

    The mock supports the context-manager protocol. Its ``read()`` returns
    ``sessions`` JSON-encoded as bytes, so the code under test parses a real payload
    rather than depending on a globally patched ``json.load``.
    """
    response = mock.Mock()
    response.__enter__ = mock.Mock(return_value=response)
    response.__exit__ = mock.Mock(return_value=False)
    response.read.return_value = json.dumps(sessions).encode()
    return response


@pytest.mark.parametrize(
    ("shell_name", "has_kernel", "is_in_ipython", "expected"),
    [
        ("ZMQInteractiveShell", False, True, True),
        ("SomeOtherShell", True, True, True),
        ("TerminalInteractiveShell", False, True, False),
        ("AnyShell", False, False, False),
    ],
)
def test_is_in_jupyter_notebook(shell_name, has_kernel, is_in_ipython, expected):
    # An empty spec makes attribute access (e.g. `.kernel`) raise AttributeError.
    mock_shell = mock.Mock(spec=["kernel"] if has_kernel else [])
    mock_shell.__class__.__name__ = shell_name
    if has_kernel:
        mock_shell.kernel = mock.Mock()

    mock_ipython = mock.Mock()
    mock_ipython.get_ipython.return_value = mock_shell

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.is_running_in_ipython_environment",
            return_value=is_in_ipython,
        ),
        mock.patch.dict("sys.modules", {"IPython": mock_ipython}),
    ):
        assert _is_in_jupyter_notebook() is expected


@pytest.mark.parametrize(
    ("user_ns", "is_in_ipython", "expected"),
    [
        ({"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, True, MOCK_NOTEBOOK_PATH),
        ({}, True, None),
        ({"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, False, None),
    ],
)
def test_get_vscode_notebook_path(user_ns, is_in_ipython, expected):
    mock_shell = mock.Mock()
    mock_shell.user_ns = user_ns

    mock_ipython = mock.Mock()
    mock_ipython.get_ipython.return_value = mock_shell

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.is_running_in_ipython_environment",
            return_value=is_in_ipython,
        ),
        mock.patch.dict("sys.modules", {"IPython": mock_ipython}),
    ):
        assert _get_vscode_notebook_path() == expected


def test_get_kernel_id_success():
    mock_ipykernel = mock.Mock()
    mock_ipykernel.get_connection_file.return_value = f"/path/to/kernel-{MOCK_KERNEL_ID}.json"

    with (
        mock.patch.dict("sys.modules", {"ipykernel": mock_ipykernel}),
        mock.patch("mlflow.tracking.context.jupyter_notebook_context.Path") as mock_path,
    ):
        mock_path.return_value.stem = f"kernel-{MOCK_KERNEL_ID}"
        assert _get_kernel_id() == MOCK_KERNEL_ID


def test_get_kernel_id_import_error():
    # A `None` entry in sys.modules makes `import ipykernel` raise ImportError.
    with mock.patch.dict("sys.modules", {"ipykernel": None}):
        assert _get_kernel_id() is None


def test_get_running_servers_finds_servers(tmp_path):
    server_info = {"url": "http://localhost:8888/", "token": "test_token"}
    server_file = tmp_path / "nbserver-12345.json"
    server_file.write_text(json.dumps(server_info))

    mock_jupyter_core = mock.Mock()
    mock_jupyter_core.paths.jupyter_runtime_dir.return_value = str(tmp_path)

    with (
        mock.patch.dict(
            "sys.modules",
            {"jupyter_core": mock_jupyter_core, "jupyter_core.paths": mock_jupyter_core.paths},
        ),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.Path",
            return_value=tmp_path,
        ),
    ):
        servers = list(_get_running_servers())

    # The server-info file written above must be discovered and parsed.
    assert servers == [server_info]


def test_get_running_servers_no_servers(tmp_path):
    mock_jupyter_core = mock.Mock()
    mock_jupyter_core.paths.jupyter_runtime_dir.return_value = str(tmp_path)

    with (
        mock.patch.dict(
            "sys.modules",
            {"jupyter_core": mock_jupyter_core, "jupyter_core.paths": mock_jupyter_core.paths},
        ),
        mock.patch("mlflow.tracking.context.jupyter_notebook_context.Path") as mock_path,
    ):
        mock_path_instance = mock.Mock()
        mock_path_instance.is_dir.return_value = True
        mock_path_instance.glob.return_value = []
        mock_path.return_value = mock_path_instance

        assert list(_get_running_servers()) == []


def test_get_running_servers_import_error():
    with mock.patch.dict("sys.modules", {"jupyter_core": None, "jupyter_core.paths": None}):
        assert list(_get_running_servers()) == []


def test_get_sessions_notebook_finds_notebook():
    server = {"url": "http://localhost:8888/", "token": "test_token"}
    sessions = [{"kernel": {"id": MOCK_KERNEL_ID}, "path": MOCK_NOTEBOOK_PATH}]

    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context.urlopen",
        return_value=_mock_sessions_response(sessions),
    ):
        assert _get_sessions_notebook(server, MOCK_KERNEL_ID) == MOCK_NOTEBOOK_PATH


def test_get_sessions_notebook_no_matching_kernel():
    server = {"url": "http://localhost:8888/", "token": "test_token"}
    sessions = [{"kernel": {"id": "different_kernel"}, "path": "other_notebook.ipynb"}]

    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context.urlopen",
        return_value=_mock_sessions_response(sessions),
    ) as mock_urlopen:
        assert _get_sessions_notebook(server, MOCK_KERNEL_ID) is None

    # Guard against passing vacuously: the sessions endpoint must have been queried.
    mock_urlopen.assert_called()


def test_get_sessions_notebook_connection_error():
    server = {"url": "http://localhost:8888/", "token": "test_token"}

    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context.urlopen",
        side_effect=Exception("Connection refused"),
    ):
        assert _get_sessions_notebook(server, MOCK_KERNEL_ID) is None


def test_get_sessions_notebook_with_jupyterhub_token(monkeypatch):
    # An empty server token should fall back to the JupyterHub API token.
    server = {"url": "http://localhost:8888/", "token": ""}
    sessions = [{"kernel": {"id": MOCK_KERNEL_ID}, "path": MOCK_NOTEBOOK_PATH}]

    monkeypatch.setenv("JUPYTERHUB_API_TOKEN", "hub_token")

    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context.urlopen",
        return_value=_mock_sessions_response(sessions),
    ) as mock_urlopen:
        result = _get_sessions_notebook(server, MOCK_KERNEL_ID)

    assert result == MOCK_NOTEBOOK_PATH
    requested_url = mock_urlopen.call_args.args[0]
    assert "hub_token" in requested_url


@pytest.mark.parametrize(
    ("vscode_path", "env_vars", "sessions_path", "expected"),
    [
        (MOCK_NOTEBOOK_PATH, {}, None, MOCK_NOTEBOOK_NAME),
        (None, {"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, None, MOCK_NOTEBOOK_NAME),
        (None, {"IPYNB_FILE": MOCK_NOTEBOOK_PATH}, None, MOCK_NOTEBOOK_NAME),
        (None, {}, MOCK_NOTEBOOK_PATH, MOCK_NOTEBOOK_NAME),
        (None, {}, None, None),
    ],
)
def test_get_notebook_name(vscode_path, env_vars, sessions_path, expected, monkeypatch):
    # Clear the env vars the code checks so the host environment cannot leak in.
    monkeypatch.delenv("__vsc_ipynb_file__", raising=False)
    monkeypatch.delenv("IPYNB_FILE", raising=False)

    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_vscode_notebook_path",
            return_value=vscode_path,
        ),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_notebook_path_from_sessions",
            return_value=sessions_path,
        ),
    ):
        assert _get_notebook_name() == expected


def test_get_notebook_name_is_cached(monkeypatch):
    monkeypatch.delenv("__vsc_ipynb_file__", raising=False)
    monkeypatch.delenv("IPYNB_FILE", raising=False)

    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_vscode_notebook_path",
        return_value=MOCK_NOTEBOOK_PATH,
    ) as mock_vscode_path:
        results = [_get_notebook_name() for _ in range(3)]

    assert results == [MOCK_NOTEBOOK_NAME] * 3
    # Only the first call performs the lookup; later calls are served from the cache.
    mock_vscode_path.assert_called_once()


def test_get_notebook_path_from_sessions_success():
    mock_server = {"url": "http://localhost:8888/", "token": "test_token"}

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
            return_value=MOCK_KERNEL_ID,
        ),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_running_servers",
            return_value=[mock_server],
        ),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_sessions_notebook",
            return_value=MOCK_NOTEBOOK_PATH,
        ),
    ):
        assert _get_notebook_path_from_sessions() == MOCK_NOTEBOOK_PATH


def test_get_notebook_path_from_sessions_no_kernel_id():
    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
        return_value=None,
    ):
        assert _get_notebook_path_from_sessions() is None


def test_get_notebook_path_from_sessions_no_servers():
    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
            return_value=MOCK_KERNEL_ID,
        ),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_running_servers",
            return_value=[],
        ),
    ):
        assert _get_notebook_path_from_sessions() is None


@pytest.mark.parametrize(
    ("is_in_jupyter", "expected"),
    [
        (True, True),
        (False, False),
    ],
)
def test_jupyter_notebook_run_context_in_context(is_in_jupyter, expected):
    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._is_in_jupyter_notebook",
        return_value=is_in_jupyter,
    ):
        assert JupyterNotebookRunContext().in_context() is expected


@pytest.mark.parametrize(
    ("notebook_name", "expected_tags"),
    [
        (
            MOCK_NOTEBOOK_NAME,
            {
                MLFLOW_SOURCE_NAME: MOCK_NOTEBOOK_NAME,
                MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            },
        ),
        (
            None,
            {
                MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            },
        ),
    ],
)
def test_jupyter_notebook_run_context_tags(notebook_name, expected_tags):
    with mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_notebook_name",
        return_value=notebook_name,
    ):
        assert JupyterNotebookRunContext().tags() == expected_tags