/ tests / utils / test_validation.py
test_validation.py
import copy
import json
import re
import socket
from unittest.mock import patch

import pytest

from mlflow.entities import Metric, Param, RunTag
from mlflow.environment_variables import MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.utils.os import is_windows
from mlflow.utils.validation import (
    MAX_TAG_VAL_LENGTH,
    _is_numeric,
    _validate_batch_log_data,
    _validate_batch_log_limits,
    _validate_db_type_string,
    _validate_experiment_artifact_location,
    _validate_experiment_artifact_location_length,
    _validate_experiment_name,
    _validate_list_param,
    _validate_metric_name,
    _validate_model_alias_name,
    _validate_model_alias_name_reserved,
    _validate_model_name,
    _validate_model_renaming,
    _validate_param_name,
    _validate_run_id,
    _validate_tag_name,
    _validate_webhook_url,
    path_not_unique,
)
 33  
# Names accepted by the metric, param, and tag name validators (used by the *_name_good tests).
GOOD_METRIC_OR_PARAM_NAMES = [
    "a",
    "Ab-5_",
    "a/b/c",
    "a.b.c",
    # Dots inside or at the edges of a component are fine; compare the "." / ".." entries
    # in BAD_METRIC_OR_PARAM_NAMES.
    ".a",
    "b.",
    "a..a/._./o_O/.e.",
    # Spaces are accepted here, unlike in alias names ("a b" is in BAD_ALIAS_NAMES).
    "a b/c d",
]
# Names the metric, param, and tag validators must reject with INVALID_PARAMETER_VALUE.
BAD_METRIC_OR_PARAM_NAMES = [
    "",
    ".",
    "/",
    "..",
    "//",
    "a//b",
    "a/./b",
    "/a",
    "a/",
    "\\",  # a single backslash character
    "./",
    "/./",
]

# Alias names accepted by _validate_model_alias_name.
GOOD_ALIAS_NAMES = [
    "a",
    "Ab-5_",
    "test-alias",
    "1a2b5cDeFgH",
    "a" * 255,  # longest accepted length; "a" * 256 is in BAD_ALIAS_NAMES
    # Values that resemble, but are not, the reserved aliases exercised in
    # test_validate_model_alias_name_reserved.
    "lates",  # spellchecker: disable-line
    "v123_temp",
    "123",
    "123v",
    "temp_V123",
]

# Alias names _validate_model_alias_name must reject with INVALID_PARAMETER_VALUE.
BAD_ALIAS_NAMES = [
    "",
    ".",
    "/",
    "..",
    "//",
    "a b",  # whitespace, which metric/param names allow
    "a/./b",
    "/a",
    "a/",
    ":",
    "\\",  # a single backslash character
    "./",
    "/./",
    "a" * 256,  # one character longer than the longest accepted alias
    None,  # non-string input
    "$dgs",
]
 90  
 91  
 92  @pytest.mark.parametrize(
 93      ("path", "expected"),
 94      [
 95          ("a", False),
 96          ("a/b/c", False),
 97          ("a.b/c", False),
 98          (".a", False),
 99          # Not unique paths
100          ("./a", True),
101          ("a/b/../c", True),
102          (".", True),
103          ("../a/b", True),
104          ("/a/b/c", True),
105      ],
106  )
107  def test_path_not_unique(path, expected):
108      assert path_not_unique(path) is expected
109  
110  
@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (0, True),
        (0.0, True),
        # Non-numeric cases; bools are excluded even though bool subclasses int
        (True, False),
        (False, False),
        ("0", False),
        (None, False),
    ],
)
def test_is_numeric(value, expected):
    """_is_numeric accepts ints and floats but rejects bools, strings, and None."""
    outcome = _is_numeric(value)
    assert outcome is expected
125  
126  
@pytest.mark.parametrize("metric_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_metric_name_good(metric_name):
    """Well-formed metric names pass validation without raising."""
    _validate_metric_name(metric_name)
130  
131  
132  def _bad_parameter_pattern(name):
133      if name == "\\":
134          return r"Invalid value \"\\\\\" for parameter"  # Manually handle the backslash case
135      elif name == "*****":
136          return r"Invalid value \"\*\*\*\*\*\" for parameter"
137      else:
138          return f'Invalid value "{name}" for parameter'
139  
140  
@pytest.mark.parametrize("metric_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_metric_name_bad(metric_name):
    """Malformed metric names raise INVALID_PARAMETER_VALUE naming the bad value."""
    expected = _bad_parameter_pattern(metric_name)
    with pytest.raises(MlflowException, match=expected) as exc_info:
        _validate_metric_name(metric_name)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
149  
150  
@pytest.mark.parametrize("param_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_param_name_good(param_name):
    """Well-formed param names pass validation without raising."""
    _validate_param_name(param_name)
154  
155  
@pytest.mark.parametrize("param_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_param_name_bad(param_name):
    """Malformed param names raise INVALID_PARAMETER_VALUE naming the bad value."""
    expected = _bad_parameter_pattern(param_name)
    with pytest.raises(MlflowException, match=expected) as exc_info:
        _validate_param_name(param_name)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
161  
162  
@pytest.mark.skipif(
    not is_windows(), reason="Colons in param and metric names are only rejected on Windows"
)
@pytest.mark.parametrize(
    "param_name",
    [
        ":",
        "aa:bb:cc",
    ],
)
def test_validate_colon_name_bad_windows(param_name):
    """On Windows, param names containing a colon raise INVALID_PARAMETER_VALUE."""
    with pytest.raises(MlflowException, match=_bad_parameter_pattern(param_name)) as e:
        _validate_param_name(param_name)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
175  
176  
@pytest.mark.parametrize("tag_name", GOOD_METRIC_OR_PARAM_NAMES)
def test_validate_tag_name_good(tag_name):
    """Well-formed tag names pass validation without raising."""
    _validate_tag_name(tag_name)
180  
181  
@pytest.mark.parametrize("tag_name", BAD_METRIC_OR_PARAM_NAMES)
def test_validate_tag_name_bad(tag_name):
    """Malformed tag names raise INVALID_PARAMETER_VALUE naming the bad value."""
    expected = _bad_parameter_pattern(tag_name)
    with pytest.raises(MlflowException, match=expected) as exc_info:
        _validate_tag_name(tag_name)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
187  
188  
@pytest.mark.parametrize("alias_name", GOOD_ALIAS_NAMES)
def test_validate_model_alias_name_good(alias_name):
    """Valid alias names pass validation without raising."""
    _validate_model_alias_name(alias_name)
192  
193  
@pytest.mark.parametrize("alias_name", BAD_ALIAS_NAMES)
def test_validate_model_alias_name_bad(alias_name):
    """Invalid alias names raise INVALID_PARAMETER_VALUE mentioning the alias name."""
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    with pytest.raises(MlflowException, match="alias name") as exc_info:
        _validate_model_alias_name(alias_name)
    assert exc_info.value.error_code == expected_code
199  
200  
@pytest.mark.parametrize("alias_name", ["latest", "LATEST", "Latest", "v123", "V1"])
def test_validate_model_alias_name_reserved(alias_name):
    """Reserved aliases ("latest" in any case, "v<digits>") are refused."""
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    with pytest.raises(MlflowException, match="reserved") as exc_info:
        _validate_model_alias_name_reserved(alias_name)
    assert exc_info.value.error_code == expected_code
206  
207  
@pytest.mark.parametrize(
    "run_id",
    [
        # 32-character IDs
        "a" * 32,
        "f0" * 16,
        "abcdef0123456789" * 2,
        # Other lengths (33, 31, 256)
        "a" * 33,
        "a" * 31,
        "a" * 256,
        # Uppercase, non-hex letters, and underscores
        "A" * 32,
        "g" * 32,
        "a_" * 32,
        "abcdefghijklmnopqrstuvqxyz",
    ],
)
def test_validate_run_id_good(run_id):
    """Run IDs do not have to be 32-character hex strings to be accepted."""
    _validate_run_id(run_id)
225  
226  
@pytest.mark.parametrize("run_id", ["a/bc" * 8, "", "a" * 400, "*" * 5])
def test_validate_run_id_bad(run_id):
    """Run IDs with slashes, empty, over-long (400 chars), or made of '*' are rejected."""
    pattern = _bad_parameter_pattern(run_id)
    with pytest.raises(MlflowException, match=pattern) as exc_info:
        _validate_run_id(run_id)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
232  
233  
def test_validate_batch_log_limits():
    """Batch logging rejects requests that exceed per-type or aggregate entity limits."""
    metrics = [Metric(f"metric-key-{idx}", 1, 0, idx * 2) for idx in range(1001)]
    params = [Param(f"param-key-{idx}", "b") for idx in range(101)]
    tags = [RunTag(f"tag-key-{idx}", "b") for idx in range(101)]
    limit_msg = r"A batch logging request can contain at most \d+"

    # Each entity type on its own is one over its individual limit.
    for field_name, entities in [("metrics", metrics), ("params", params), ("tags", tags)]:
        kwargs = {"metrics": [], "params": [], "tags": []}
        kwargs[field_name] = entities
        with pytest.raises(MlflowException, match=limit_msg):
            _validate_batch_log_limits(**kwargs)

    # 900 + 51 + 50 entities: too many entities in aggregate.
    with pytest.raises(MlflowException, match=limit_msg):
        _validate_batch_log_limits(metrics[:900], params[:51], tags[:50])

    # Exactly at each per-type limit is accepted.
    _validate_batch_log_limits(metrics[:1000], [], [])
    _validate_batch_log_limits([], params[:100], [])
    _validate_batch_log_limits([], [], tags[:100])
259  
260  
def test_validate_batch_log_data(monkeypatch):
    """Each malformed metric, param, or tag batch is rejected; a well-formed batch passes."""
    oversized_key = "super-long-bad-key" * 1000
    # Metric(key, value, timestamp, step); Param/RunTag(key, value).
    invalid_batches = [
        # Metrics: over-long key, non-numeric or bool value, bad/negative timestamp, bad step
        (
            "metrics",
            [Metric("good-metric-key", 1.0, 0, 0), Metric(oversized_key, 4.0, 0, 0)],
        ),
        ("metrics", [Metric("good-metric-key", "not-a-double-val", 0, 0)]),
        ("metrics", [Metric("good-metric-key", True, 0, 0)]),
        ("metrics", [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)]),
        ("metrics", [Metric("good-metric-key", 1.0, -123, 0)]),
        ("metrics", [Metric("good-metric-key", 1.0, 0, "not-a-step")]),
        # Params: over-long key or over-long value
        (
            "params",
            [Param("good-param-key", "hi"), Param(oversized_key, "but-good-val")],
        ),
        (
            "params",
            [Param("good-param-key", "hi"), Param("another-good-key", "but-bad-val" * 1000)],
        ),
        # Tags: over-long key, or value one character past MAX_TAG_VAL_LENGTH
        (
            "tags",
            [RunTag("good-tag-key", "hi"), RunTag(oversized_key, "but-good-val")],
        ),
        (
            "tags",
            [
                RunTag("good-tag-key", "hi"),
                RunTag("another-good-key", "a" * (MAX_TAG_VAL_LENGTH + 1)),
            ],
        ),
    ]
    empty_batch = {"metrics": [], "params": [], "tags": []}
    # NOTE(review): presumably turns off truncation so over-long values raise instead of
    # being shortened — confirm against MLFLOW_TRUNCATE_LONG_VALUES handling.
    monkeypatch.setenv("MLFLOW_TRUNCATE_LONG_VALUES", "false")
    for field_name, entities in invalid_batches:
        kwargs = copy.deepcopy(empty_batch)
        kwargs[field_name] = entities
        with pytest.raises(MlflowException, match=r".+"):
            _validate_batch_log_data(**kwargs)

    # A batch made only of well-formed entities is accepted.
    _validate_batch_log_data(
        metrics=[Metric("metric-key", 1.0, 0, 0)],
        params=[Param("param-key", "param-val")],
        tags=[RunTag("tag-key", "tag-val")],
    )
311          tags=[RunTag("tag-key", "tag-val")],
312      )
313  
314  
@pytest.mark.parametrize("location", ["abcde", None])
def test_validate_experiment_artifact_location_good(location):
    """A plain location string, or no location at all, is accepted."""
    _validate_experiment_artifact_location(location)
318  
319  
@pytest.mark.parametrize("location", ["runs:/blah/bleh/blergh"])
def test_validate_experiment_artifact_location_bad(location):
    """A runs:/ URI cannot serve as an experiment artifact location."""
    expected = "Artifact location cannot be a runs:/ URI"
    with pytest.raises(MlflowException, match=expected):
        _validate_experiment_artifact_location(location)
324  
325  
@pytest.mark.parametrize("experiment_name", ["validstring", b"test byte string".decode("utf-8")])
def test_validate_experiment_name_good(experiment_name):
    """Non-empty string names, including ones decoded from bytes, are accepted."""
    _validate_experiment_name(experiment_name)
329  
330  
@pytest.mark.parametrize("experiment_name", ["", 12, 12.7, None, {}, []])
def test_validate_experiment_name_bad(experiment_name):
    """Empty strings and non-string values are rejected as experiment names."""
    expected = "Invalid experiment name"
    with pytest.raises(MlflowException, match=expected):
        _validate_experiment_name(experiment_name)
335  
336  
@pytest.mark.parametrize("db_type", ["mysql", "mssql", "postgresql", "sqlite"])
def test_validate_db_type_string_good(db_type):
    """Supported lowercase database engine names pass validation."""
    _validate_db_type_string(db_type)
340  
341  
@pytest.mark.parametrize("db_type", ["MySQL", "mongo", "cassandra", "sql", ""])
def test_validate_db_type_string_bad(db_type):
    """Unknown, empty, or differently-cased engine names are rejected."""
    expected = "Invalid database engine"
    with pytest.raises(MlflowException, match=expected) as exc_info:
        _validate_db_type_string(db_type)
    assert expected in exc_info.value.message
347  
348  
@pytest.mark.parametrize(
    "artifact_location",
    [
        "s3://test-bucket/",
        "file:///path/to/artifacts",
        "mlflow-artifacts:/path/to/artifacts",
        "dbfs:/databricks/mlflow-tracking/some-id",
    ],
)
def test_validate_experiment_artifact_location_length_good(artifact_location):
    """Short artifact URIs across several schemes pass the length check."""
    _validate_experiment_artifact_location_length(artifact_location)
361  
@pytest.mark.parametrize(
    "artifact_location",
    ["s3://test-bucket/" + "a" * 10000, "file:///path/to/" + "directory" * 1111],
    ids=["s3_long_path", "file_long_path"],
)
def test_validate_experiment_artifact_location_length_bad(artifact_location):
    """Artifact URIs roughly 10,000 characters long fail the length check."""
    expected = "Invalid artifact path length"
    with pytest.raises(MlflowException, match=expected):
        _validate_experiment_artifact_location_length(artifact_location)
370  
371  
def test_setting_experiment_artifact_location_env_var_works(monkeypatch):
    """MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH controls the maximum accepted length."""
    location = "file://aaaa"  # 11 characters
    env_name = MLFLOW_ARTIFACT_LOCATION_MAX_LENGTH.name

    # Accepted under the default limit.
    _validate_experiment_artifact_location_length(location)

    # A limit of 10 is one character too short.
    monkeypatch.setenv(env_name, "10")
    with pytest.raises(MlflowException, match="Invalid artifact path length"):
        _validate_experiment_artifact_location_length(location)

    # A limit equal to the length is inclusive, so validation passes again.
    monkeypatch.setenv(env_name, "11")
    _validate_experiment_artifact_location_length(location)
386  
387  
@pytest.mark.parametrize(
    "param_value",
    [
        ["1", "2", "3"],
        [],
        [1, 2, 3],
    ],
)
def test_validate_list_param_with_valid_list(param_value):
    """Lists of any element type, including empty lists, are accepted."""
    _validate_list_param("experiment_ids", param_value)
398  
399  
def test_validate_list_param_with_none_not_allowed():
    """None is rejected when allow_none is False."""
    expected = "experiment_ids must be a list"
    with pytest.raises(MlflowException, match=expected):
        _validate_list_param("experiment_ids", None, allow_none=False)
403  
404  
def test_validate_list_param_with_none_allowed():
    """None passes validation when allow_none is True."""
    _validate_list_param("experiment_ids", None, allow_none=True)
407  
408  
@pytest.mark.parametrize(
    ("param_name", "param_value", "expected_type"),
    [
        ("experiment_ids", 4, "int"),
        ("param_name", "value", "str"),
        ("my_param", {"key": "value"}, "dict"),
    ],
)
def test_validate_list_param_with_invalid_type(param_name, param_value, expected_type):
    """Non-list values raise with the actual type name and a list-wrapping suggestion."""
    expected_msg = rf"{param_name} must be a list, got {expected_type}"
    with pytest.raises(MlflowException, match=expected_msg) as err:
        _validate_list_param(param_name, param_value)
    hint = f"Did you mean to use {param_name}=[{param_value!r}]?"
    assert hint in str(err.value)
    assert err.value.error_code == "INVALID_PARAMETER_VALUE"
424  
425  
426  # -- _validate_webhook_url tests --
427  
428  
def _mock_getaddrinfo(ip_str):
    """Build a stand-in for ``socket.getaddrinfo`` that always resolves to ``ip_str``.

    The returned callable ignores its arguments and returns a single 5-tuple
    ``(family, type, proto, canonname, sockaddr)`` whose sockaddr is ``(ip_str, 0)``.
    """

    def fake_getaddrinfo(host, port, *args, **kwargs):
        return [(None, None, None, None, (ip_str, 0))]

    return fake_getaddrinfo
431  
432  
@pytest.mark.parametrize(
    ("url", "expected_match"),
    [
        # Wrong type or blank
        (123, "Webhook URL must be a string"),
        ("", "Webhook URL cannot be empty"),
        ("   ", "Webhook URL cannot be empty"),
        # Non-HTTPS schemes
        ("ftp://example.com", "Invalid webhook URL scheme"),
        ("http://example.com", "Invalid webhook URL scheme"),
        # No hostname
        ("https://", "must include a hostname"),
    ],
)
def test_validate_webhook_url_rejects_invalid_input(url, expected_match):
    """Structurally invalid webhook URLs raise with a descriptive message."""
    with pytest.raises(MlflowException, match=expected_match):
        _validate_webhook_url(url)
447  
448  
@pytest.mark.parametrize(
    ("url", "resolved_ip"),
    [
        # Loopback
        ("https://127.0.0.1/callback", "127.0.0.1"),
        ("https://localhost/callback", "127.0.0.1"),
        # Private IPv4 ranges
        ("https://internal.corp/hook", "10.0.0.1"),
        ("https://internal.corp/hook", "172.16.0.1"),
        ("https://internal.corp/hook", "192.168.1.1"),
        # Link-local metadata endpoint and CGNAT space
        ("https://metadata.internal/hook", "169.254.169.254"),
        ("https://cgnat.internal/hook", "100.64.0.1"),
        # IPv6 loopback and private
        ("https://ipv6-loopback.internal/hook", "::1"),
        ("https://ipv6-private.internal/hook", "fc00::1"),
    ],
)
def test_validate_webhook_url_rejects_private_ips(url, resolved_ip):
    """URLs whose hostname resolves to a non-public address are rejected."""
    resolver = _mock_getaddrinfo(resolved_ip)
    target = "mlflow.utils.validation.socket.getaddrinfo"
    expected = "must not resolve to a non-public"
    with patch(target, side_effect=resolver), pytest.raises(MlflowException, match=expected):
        _validate_webhook_url(url)
470  
471  
def test_validate_webhook_url_rejects_unresolvable_hostname():
    """A DNS resolution failure surfaces as an MlflowException."""
    dns_failure = socket.gaierror("Name or service not known")
    target = "mlflow.utils.validation.socket.getaddrinfo"
    expected = "Cannot resolve webhook URL hostname"
    with patch(target, side_effect=dns_failure), pytest.raises(MlflowException, match=expected):
        _validate_webhook_url("https://does-not-exist.invalid/hook")
479  
480  
def test_validate_webhook_url_rejects_if_any_resolved_address_is_private():
    """One private address among the resolved addresses is enough to reject the URL."""

    def resolve_public_and_private(host, port, *args, **kwargs):
        public_entry = (None, None, None, None, ("8.8.8.8", 0))
        private_entry = (None, None, None, None, ("10.0.0.1", 0))
        return [public_entry, private_entry]

    target = "mlflow.utils.validation.socket.getaddrinfo"
    expected = "must not resolve to a non-public"
    with patch(target, side_effect=resolve_public_and_private), pytest.raises(
        MlflowException, match=expected
    ):
        _validate_webhook_url("https://dual-homed.example.com/hook")
491  
492  
def test_validate_webhook_url_accepts_public_ip():
    """An HTTPS URL resolving only to a public address passes validation."""
    resolver = _mock_getaddrinfo("8.8.8.8")
    target = "mlflow.utils.validation.socket.getaddrinfo"
    with patch(target, side_effect=resolver):
        _validate_webhook_url("https://example.com/webhook")
499  
500  
def test_validate_webhook_url_allow_private_ips_env_var(monkeypatch):
    """MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS=true lets a loopback-resolving URL through."""
    monkeypatch.setenv("MLFLOW_WEBHOOK_ALLOW_PRIVATE_IPS", "true")
    resolver = _mock_getaddrinfo("127.0.0.1")
    target = "mlflow.utils.validation.socket.getaddrinfo"
    with patch(target, side_effect=resolver):
        _validate_webhook_url("https://localhost/callback")
508  
509  
@pytest.mark.parametrize("invalid_name", ["my/model", "model:v1", "name/with:both"])
def test_validate_model_name_invalid_chars(invalid_name):
    """Model names containing '/' or ':' are rejected with INVALID_PARAMETER_VALUE."""
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    with pytest.raises(
        MlflowException,
        match="Names cannot contain '/' or ':'",
        check=lambda exc: exc.error_code == expected_code,
    ):
        _validate_model_name(invalid_name)
518  
519  
@pytest.mark.parametrize("invalid_name", ["my/model", "model:v1", "name/with:both"])
def test_validate_model_renaming_invalid_chars(invalid_name):
    """Renaming a model to a name with '/' or ':' is rejected with INVALID_PARAMETER_VALUE."""
    expected_code = ErrorCode.Name(INVALID_PARAMETER_VALUE)
    with pytest.raises(
        MlflowException,
        match="Names cannot contain '/' or ':'",
        check=lambda exc: exc.error_code == expected_code,
    ):
        _validate_model_renaming(invalid_name)