/ tests / tracking / context / test_jupyter_notebook_context.py
test_jupyter_notebook_context.py
  1  import json
  2  from unittest import mock
  3  
  4  import pytest
  5  
  6  from mlflow.entities import SourceType
  7  from mlflow.tracking.context.jupyter_notebook_context import (
  8      JupyterNotebookRunContext,
  9      _get_kernel_id,
 10      _get_notebook_name,
 11      _get_notebook_path_from_sessions,
 12      _get_running_servers,
 13      _get_sessions_notebook,
 14      _get_vscode_notebook_path,
 15      _is_in_jupyter_notebook,
 16  )
 17  from mlflow.utils.mlflow_tags import MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE
 18  
# Notebook file name the tests expect to see reported (e.g. as the source-name tag).
MOCK_NOTEBOOK_NAME = "test_notebook.ipynb"
# Full notebook path whose final component is MOCK_NOTEBOOK_NAME.
MOCK_NOTEBOOK_PATH = f"/path/to/{MOCK_NOTEBOOK_NAME}"
# Kernel id shared by the mocked connection-file name and the mocked session payloads.
MOCK_KERNEL_ID = "abc123-def456"
 22  
 23  
 24  @pytest.mark.parametrize(
 25      ("shell_name", "has_kernel", "is_in_ipython", "expected"),
 26      [
 27          ("ZMQInteractiveShell", False, True, True),
 28          ("SomeOtherShell", True, True, True),
 29          ("TerminalInteractiveShell", False, True, False),
 30          ("AnyShell", False, False, False),
 31      ],
 32  )
 33  def test_is_in_jupyter_notebook(shell_name, has_kernel, is_in_ipython, expected):
 34      mock_shell = mock.Mock(spec=["kernel"] if has_kernel else [])
 35      mock_shell.__class__.__name__ = shell_name
 36      if has_kernel:
 37          mock_shell.kernel = mock.Mock()
 38  
 39      mock_ipython = mock.Mock()
 40      mock_ipython.get_ipython.return_value = mock_shell
 41  
 42      with (
 43          mock.patch(
 44              "mlflow.tracking.context.jupyter_notebook_context.is_running_in_ipython_environment",
 45              return_value=is_in_ipython,
 46          ),
 47          mock.patch.dict("sys.modules", {"IPython": mock_ipython}),
 48      ):
 49          assert _is_in_jupyter_notebook() is expected
 50  
 51  
 52  @pytest.mark.parametrize(
 53      ("user_ns", "is_in_ipython", "expected"),
 54      [
 55          ({"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, True, MOCK_NOTEBOOK_PATH),
 56          ({}, True, None),
 57          ({"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, False, None),
 58      ],
 59  )
 60  def test_get_vscode_notebook_path(user_ns, is_in_ipython, expected):
 61      mock_shell = mock.Mock()
 62      mock_shell.user_ns = user_ns
 63  
 64      mock_ipython = mock.Mock()
 65      mock_ipython.get_ipython.return_value = mock_shell
 66  
 67      with (
 68          mock.patch(
 69              "mlflow.tracking.context.jupyter_notebook_context.is_running_in_ipython_environment",
 70              return_value=is_in_ipython,
 71          ),
 72          mock.patch.dict("sys.modules", {"IPython": mock_ipython}),
 73      ):
 74          assert _get_vscode_notebook_path() == expected
 75  
 76  
 77  def test_get_kernel_id_success():
 78      mock_ipykernel = mock.Mock()
 79      mock_ipykernel.get_connection_file.return_value = f"/path/to/kernel-{MOCK_KERNEL_ID}.json"
 80  
 81      with (
 82          mock.patch.dict("sys.modules", {"ipykernel": mock_ipykernel}),
 83          mock.patch("mlflow.tracking.context.jupyter_notebook_context.Path") as mock_path,
 84      ):
 85          mock_path.return_value.stem = f"kernel-{MOCK_KERNEL_ID}"
 86          result = _get_kernel_id()
 87          assert result == MOCK_KERNEL_ID
 88  
 89  
 90  def test_get_kernel_id_import_error():
 91      with mock.patch.dict("sys.modules", {"ipykernel": None}):
 92          result = _get_kernel_id()
 93          assert result is None
 94  
 95  
 96  def test_get_running_servers_finds_servers(tmp_path):
 97      server_info = {"url": "http://localhost:8888/", "token": "test_token"}
 98      server_file = tmp_path / "nbserver-12345.json"
 99      server_file.write_text(json.dumps(server_info))
100  
101      mock_jupyter_core = mock.Mock()
102      mock_jupyter_core.paths.jupyter_runtime_dir.return_value = str(tmp_path)
103  
104      with (
105          mock.patch.dict(
106              "sys.modules",
107              {"jupyter_core": mock_jupyter_core, "jupyter_core.paths": mock_jupyter_core.paths},
108          ),
109          mock.patch(
110              "mlflow.tracking.context.jupyter_notebook_context.Path",
111              return_value=tmp_path,
112          ),
113      ):
114          list(_get_running_servers())
115  
116  
def test_get_running_servers_no_servers(tmp_path):
    fake_jupyter_core = mock.Mock()
    fake_jupyter_core.paths.jupyter_runtime_dir.return_value = str(tmp_path)
    fake_modules = {
        "jupyter_core": fake_jupyter_core,
        "jupyter_core.paths": fake_jupyter_core.paths,
    }

    # Runtime directory that exists but matches no server files
    empty_runtime_dir = mock.Mock()
    empty_runtime_dir.is_dir.return_value = True
    empty_runtime_dir.glob.return_value = []

    with (
        mock.patch.dict("sys.modules", fake_modules),
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.Path",
            return_value=empty_runtime_dir,
        ),
    ):
        found = list(_get_running_servers())

    assert found == []
135  
136  
def test_get_running_servers_import_error():
    # None entries in sys.modules make importing jupyter_core raise ImportError.
    blocked_modules = {"jupyter_core": None, "jupyter_core.paths": None}
    with mock.patch.dict("sys.modules", blocked_modules):
        found = list(_get_running_servers())

    assert found == []
141  
142  
def test_get_sessions_notebook_finds_notebook():
    sessions_payload = [{"kernel": {"id": MOCK_KERNEL_ID}, "path": MOCK_NOTEBOOK_PATH}]
    server = {"url": "http://localhost:8888/", "token": "test_token"}

    # A plain Mock only works as a context manager once __enter__/__exit__ are set
    response = mock.Mock()
    response.__enter__ = mock.Mock(return_value=response)
    response.__exit__ = mock.Mock(return_value=False)
    response.read.return_value = json.dumps(sessions_payload).encode()

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.urlopen",
            return_value=response,
        ),
        mock.patch("json.load", return_value=sessions_payload),
    ):
        notebook_path = _get_sessions_notebook(server, MOCK_KERNEL_ID)

    assert notebook_path == MOCK_NOTEBOOK_PATH
157  
158  
def test_get_sessions_notebook_no_matching_kernel():
    # The only session belongs to a different kernel, so nothing should match.
    sessions_payload = [{"kernel": {"id": "different_kernel"}, "path": "other_notebook.ipynb"}]
    server = {"url": "http://localhost:8888/", "token": "test_token"}

    response = mock.Mock()
    response.__enter__ = mock.Mock(return_value=response)
    response.__exit__ = mock.Mock(return_value=False)

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.urlopen",
            return_value=response,
        ),
        mock.patch("json.load", return_value=sessions_payload),
    ):
        notebook_path = _get_sessions_notebook(server, MOCK_KERNEL_ID)

    assert notebook_path is None
172  
173  
def test_get_sessions_notebook_connection_error():
    # A failing urlopen must be swallowed and reported as "no notebook".
    server = {"url": "http://localhost:8888/", "token": "test_token"}
    failing_urlopen = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context.urlopen",
        side_effect=Exception("Connection refused"),
    )

    with failing_urlopen:
        notebook_path = _get_sessions_notebook(server, MOCK_KERNEL_ID)

    assert notebook_path is None
183  
184  
def test_get_sessions_notebook_with_jupyterhub_token(monkeypatch):
    # The server entry has an empty token; the environment supplies one instead.
    monkeypatch.setenv("JUPYTERHUB_API_TOKEN", "hub_token")
    server = {"url": "http://localhost:8888/", "token": ""}
    sessions_payload = [{"kernel": {"id": MOCK_KERNEL_ID}, "path": MOCK_NOTEBOOK_PATH}]

    response = mock.Mock()
    response.__enter__ = mock.Mock(return_value=response)
    response.__exit__ = mock.Mock(return_value=False)

    with (
        mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context.urlopen",
            return_value=response,
        ) as urlopen_mock,
        mock.patch("json.load", return_value=sessions_payload),
    ):
        _get_sessions_notebook(server, MOCK_KERNEL_ID)
        first_positional_arg = urlopen_mock.call_args.args[0]

    assert "hub_token" in first_positional_arg
202  
203  
@pytest.mark.parametrize(
    ("vscode_path", "env_vars", "sessions_path", "expected"),
    [
        (MOCK_NOTEBOOK_PATH, {}, None, MOCK_NOTEBOOK_NAME),
        (None, {"__vsc_ipynb_file__": MOCK_NOTEBOOK_PATH}, None, MOCK_NOTEBOOK_NAME),
        (None, {"IPYNB_FILE": MOCK_NOTEBOOK_PATH}, None, MOCK_NOTEBOOK_NAME),
        (None, {}, MOCK_NOTEBOOK_PATH, MOCK_NOTEBOOK_NAME),
        (None, {}, None, None),
    ],
)
def test_get_notebook_name(vscode_path, env_vars, sessions_path, expected, monkeypatch):
    # Each case enables exactly one source of the notebook path (VS Code lookup,
    # an environment variable, or the sessions lookup) or none at all.

    # Clear relevant env vars that the code checks
    monkeypatch.delenv("__vsc_ipynb_file__", raising=False)
    monkeypatch.delenv("IPYNB_FILE", raising=False)

    # Set the test env vars
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    # _get_notebook_name is memoized; drop any value cached by an earlier test
    _get_notebook_name.cache_clear()
    try:
        with (
            mock.patch(
                "mlflow.tracking.context.jupyter_notebook_context._get_vscode_notebook_path",
                return_value=vscode_path,
            ),
            mock.patch(
                "mlflow.tracking.context.jupyter_notebook_context._get_notebook_path_from_sessions",
                return_value=sessions_path,
            ),
        ):
            assert _get_notebook_name() == expected
    finally:
        # Do not leave a mocked result in the cache for tests that run later
        _get_notebook_name.cache_clear()
236  
237  
def test_get_notebook_name_is_cached():
    # Repeated calls must return the same name while the underlying lookup runs
    # only once.
    _get_notebook_name.cache_clear()
    try:
        with mock.patch(
            "mlflow.tracking.context.jupyter_notebook_context._get_vscode_notebook_path",
            return_value=MOCK_NOTEBOOK_PATH,
        ) as mock_vscode_path:
            results = [_get_notebook_name() for _ in range(3)]

        assert results == [MOCK_NOTEBOOK_NAME] * 3
        # The mock records its own calls, so no hand-rolled counter is needed
        assert mock_vscode_path.call_count == 1
    finally:
        # Do not leave the mocked name in the cache for tests that run later
        _get_notebook_name.cache_clear()
260  
261  
def test_get_notebook_path_from_sessions_success():
    # With a kernel id, one running server and a matching session, the session's
    # notebook path is returned.
    server = {"url": "http://localhost:8888/", "token": "test_token"}

    kernel_id_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
        return_value=MOCK_KERNEL_ID,
    )
    servers_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_running_servers",
        return_value=[server],
    )
    sessions_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_sessions_notebook",
        return_value=MOCK_NOTEBOOK_PATH,
    )
    with kernel_id_patch, servers_patch, sessions_patch:
        notebook_path = _get_notebook_path_from_sessions()

    assert notebook_path == MOCK_NOTEBOOK_PATH
281  
282  
def test_get_notebook_path_from_sessions_no_kernel_id():
    # Without a kernel id there is nothing to match sessions against.
    kernel_id_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
        return_value=None,
    )
    with kernel_id_patch:
        notebook_path = _get_notebook_path_from_sessions()

    assert notebook_path is None
290  
291  
def test_get_notebook_path_from_sessions_no_servers():
    # A kernel id is available but no running server is found.
    kernel_id_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_kernel_id",
        return_value=MOCK_KERNEL_ID,
    )
    servers_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_running_servers",
        return_value=[],
    )
    with kernel_id_patch, servers_patch:
        notebook_path = _get_notebook_path_from_sessions()

    assert notebook_path is None
305  
306  
@pytest.mark.parametrize(
    ("is_in_jupyter", "expected"),
    [
        (True, True),
        (False, False),
    ],
)
def test_jupyter_notebook_run_context_in_context(is_in_jupyter, expected):
    # in_context() is expected to mirror the notebook-detection helper.
    detection_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._is_in_jupyter_notebook",
        return_value=is_in_jupyter,
    )
    with detection_patch:
        run_context = JupyterNotebookRunContext()
        assert run_context.in_context() is expected
320  
321  
@pytest.mark.parametrize(
    ("notebook_name", "expected_tags"),
    [
        (
            MOCK_NOTEBOOK_NAME,
            {
                MLFLOW_SOURCE_NAME: MOCK_NOTEBOOK_NAME,
                MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            },
        ),
        (
            None,
            {
                MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
            },
        ),
    ],
)
def test_jupyter_notebook_run_context_tags(notebook_name, expected_tags):
    # The source-type tag is always expected; the source-name tag only when a
    # notebook name could be resolved.
    _get_notebook_name.cache_clear()

    name_patch = mock.patch(
        "mlflow.tracking.context.jupyter_notebook_context._get_notebook_name",
        return_value=notebook_name,
    )
    with name_patch:
        actual_tags = JupyterNotebookRunContext().tags()

    assert actual_tags == expected_tags
347          tags = JupyterNotebookRunContext().tags()
348          assert tags == expected_tags