import base64
import json
import re

import pytest

from mlflow.entities import (
    Dataset,
    DatasetInput,
    InputTag,
    LifecycleStage,
    Metric,
    Param,
    Run,
    RunData,
    RunInfo,
    RunInputs,
    RunStatus,
    RunTag,
    TraceState,
    trace_location,
)
from mlflow.entities.trace_info import TraceInfo
from mlflow.exceptions import MlflowException
from mlflow.utils.mlflow_tags import MLFLOW_DATASET_CONTEXT
from mlflow.utils.search_utils import SearchTraceUtils, SearchUtils


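# `SearchUtils.parse_search_filter` decomposes a filter string into a list of
# comparison dicts with "type", "key", "comparator", and "value" entries. The
# cases below cover each entity prefix (metric(s)/params/tags/attribute/run/
# dataset), bare vs. quoted vs. backtick-quoted keys, AND-joined clauses, and
# the IS NULL / IS NOT NULL comparators.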
@pytest.mark.parametrize(
    ("filter_string", "parsed_filter"),
    [
        (
            "metric.acc >= 0.94",
            [{"comparator": ">=", "key": "acc", "type": "metric", "value": "0.94"}],
        ),
        ("metric.acc>=100", [{"comparator": ">=", "key": "acc", "type": "metric", "value": "100"}]),
        ("params.m!='tf'", [{"comparator": "!=", "key": "m", "type": "parameter", "value": "tf"}]),
        (
            'params."m"!="tf"',
            [{"comparator": "!=", "key": "m", "type": "parameter", "value": "tf"}],
        ),
        (
            'metric."legit name" >= 0.243',
            [{"comparator": ">=", "key": "legit name", "type": "metric", "value": "0.243"}],
        ),
        ("metrics.XYZ = 3", [{"comparator": "=", "key": "XYZ", "type": "metric", "value": "3"}]),
        (
            'params."cat dog" = "pets"',
            [{"comparator": "=", "key": "cat dog", "type": "parameter", "value": "pets"}],
        ),
        (
            'metrics."X-Y-Z" = 3',
            [{"comparator": "=", "key": "X-Y-Z", "type": "metric", "value": "3"}],
        ),
        (
            'metrics."X//Y#$$@&Z" = 3',
            [{"comparator": "=", "key": "X//Y#$$@&Z", "type": "metric", "value": "3"}],
        ),
        (
            "params.model = 'LinearRegression'",
            [{"comparator": "=", "key": "model", "type": "parameter", "value": "LinearRegression"}],
        ),
        (
            "metrics.rmse < 1 and params.model_class = 'LR'",
            [
                {"comparator": "<", "key": "rmse", "type": "metric", "value": "1"},
                {"comparator": "=", "key": "model_class", "type": "parameter", "value": "LR"},
            ],
        ),
        ("", []),
        ("`metric`.a >= 0.1", [{"comparator": ">=", "key": "a", "type": "metric", "value": "0.1"}]),
        (
            "`params`.model >= 'LR'",
            [{"comparator": ">=", "key": "model", "type": "parameter", "value": "LR"}],
        ),
        (
            "tags.version = 'commit-hash'",
            [{"comparator": "=", "key": "version", "type": "tag", "value": "commit-hash"}],
        ),
        (
            "`tags`.source_name = 'a notebook'",
            [{"comparator": "=", "key": "source_name", "type": "tag", "value": "a notebook"}],
        ),
        (
            'metrics."accuracy.2.0" > 5',
            [{"comparator": ">", "key": "accuracy.2.0", "type": "metric", "value": "5"}],
        ),
        (
            "metrics.`spacey name` > 5",
            [{"comparator": ">", "key": "spacey name", "type": "metric", "value": "5"}],
        ),
        (
            'params."p.a.r.a.m" != "a"',
            [{"comparator": "!=", "key": "p.a.r.a.m", "type": "parameter", "value": "a"}],
        ),
        ('tags."t.a.g" = "a"', [{"comparator": "=", "key": "t.a.g", "type": "tag", "value": "a"}]),
        (
            "attribute.artifact_uri = '1/23/4'",
            [{"type": "attribute", "comparator": "=", "key": "artifact_uri", "value": "1/23/4"}],
        ),
        (
            "attribute.start_time >= 1234",
            [{"type": "attribute", "comparator": ">=", "key": "start_time", "value": "1234"}],
        ),
        (
            "run.status = 'RUNNING'",
            [{"type": "attribute", "comparator": "=", "key": "status", "value": "RUNNING"}],
        ),
        (
            "dataset.name = 'my_dataset'",
            [{"type": "dataset", "comparator": "=", "key": "name", "value": "my_dataset"}],
        ),
        (
            "tags.version IS NULL",
            [{"comparator": "IS NULL", "key": "version", "type": "tag", "value": None}],
        ),
        (
            "tags.version IS NOT NULL",
            [{"comparator": "IS NOT NULL", "key": "version", "type": "tag", "value": None}],
        ),
        (
            "params.lr IS NULL",
            [{"comparator": "IS NULL", "key": "lr", "type": "parameter", "value": None}],
        ),
        (
            "params.lr IS NOT NULL",
            [{"comparator": "IS NOT NULL", "key": "lr", "type": "parameter", "value": None}],
        ),
        (
            "tags.a IS NULL AND params.b = 'val'",
            [
                {"comparator": "IS NULL", "key": "a", "type": "tag", "value": None},
                {"comparator": "=", "key": "b", "type": "parameter", "value": "val"},
            ],
        ),
    ],
)
def test_filter(filter_string, parsed_filter):
    assert SearchUtils.parse_search_filter(filter_string) == parsed_filter


@pytest.mark.parametrize(
    ("filter_string", "parsed_filter"),
    [
        ("params.m = 'LR'", [{"type": "parameter", "comparator": "=", "key": "m", "value": "LR"}]),
        ('params.m = "LR"', [{"type": "parameter", "comparator": "=", "key": "m", "value": "LR"}]),
        (
            'params.m = "L\'Hosp"',
            [{"type": "parameter", "comparator": "=", "key": "m", "value": "L'Hosp"}],
        ),
    ],
)
def test_correct_quote_trimming(filter_string, parsed_filter):
    assert SearchUtils.parse_search_filter(filter_string) == parsed_filter


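# Parser error paths. The expected messages are matched with `re.escape`
# because several contain regex metacharacters (e.g. the parentheses in
# "Invalid clause(s)").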
@pytest.mark.parametrize(
    ("filter_string", "error_message"),
    [
        ("metric.acc >= 0.94; metrics.rmse < 1", "Search filter contained multiple expression"),
        ("m.acc >= 0.94", "Invalid entity type"),
        ("acc >= 0.94", "Invalid attribute key"),
        ("p.model >= 'LR'", "Invalid entity type"),
        ("attri.x != 1", "Invalid entity type"),
        ("a.x != 1", "Invalid entity type"),
        ("model >= 'LR'", "Invalid attribute key"),
        ("metrics.A > 0.1 OR params.B = 'LR'", "Invalid clause(s) in filter string"),
        ("metrics.A > 0.1 NAND params.B = 'LR'", "Invalid clause(s) in filter string"),
        ("metrics.A > 0.1 AND (params.B = 'LR')", "Invalid clause(s) in filter string"),
        ("`metrics.A > 0.1", "Invalid clause(s) in filter string"),
        ("param`.A > 0.1", "Invalid clause(s) in filter string"),
        ("`dummy.A > 0.1", "Invalid clause(s) in filter string"),
        ("dummy`.A > 0.1", "Invalid clause(s) in filter string"),
        ("attribute.start != 1", "Invalid attribute key"),
        ("attribute.experiment_id != 1", "Invalid attribute key"),
        ("attribute.lifecycle_stage = 'ACTIVE'", "Invalid attribute key"),
        ("attribute.name != 1", "Invalid attribute key"),
        ("attribute.time != 1", "Invalid attribute key"),
        ("attribute._status != 'RUNNING'", "Invalid attribute key"),
        ("attribute.status = true", "Invalid clause(s) in filter string"),
        ("dataset.status = 'true'", "Invalid dataset key"),
        ("dataset.profile = 'num_rows: 10'", "Invalid dataset key"),
        ("metrics.acc IS NULL", "IS NULL / IS NOT NULL is only supported for tags and params"),
        ("attribute.status IS NULL", "IS NULL / IS NOT NULL is only supported for tags and params"),
    ],
)
def test_error_filter(filter_string, error_message):
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        SearchUtils.parse_search_filter(filter_string)


@pytest.mark.parametrize(
    ("filter_string", "error_message"),
    [
        ("metric.model = 'LR'", "Expected numeric value type for metric"),
        ("metric.model = '5'", "Expected numeric value type for metric"),
        ("params.acc = 5", "Expected a quoted string value for param"),
        ("tags.acc = 5", "Expected a quoted string value for tag"),
        ("metrics.acc != metrics.acc", "Expected numeric value type for metric"),
        ("1.0 > metrics.acc", "Expected 'Identifier' found"),
        ("attribute.status = 1", "Expected a quoted string value for attributes"),
    ],
)
def test_error_comparison_clauses(filter_string, error_message):
    with pytest.raises(MlflowException, match=error_message):
        SearchUtils.parse_search_filter(filter_string)


@pytest.mark.parametrize(
    ("filter_string", "error_message"),
    [
        ("params.acc = LR", "value is either not quoted or unidentified quote types"),
        ("tags.acc = LR", "value is either not quoted or unidentified quote types"),
        ("params.acc = `LR`", "value is either not quoted or unidentified quote types"),
        ("params.'acc = LR", "Invalid clause(s) in filter string"),
        ("params.acc = 'LR", "Invalid clause(s) in filter string"),
        ("params.acc = LR'", "Invalid clause(s) in filter string"),
        ("params.acc = \"LR'", "Invalid clause(s) in filter string"),
        ("tags.acc = \"LR'", "Invalid clause(s) in filter string"),
        ("tags.acc = = 'LR'", "Invalid clause(s) in filter string"),
        ("attribute.status IS 'RUNNING'", "Invalid clause(s) in filter string"),
    ],
)
def test_bad_quotes(filter_string, error_message):
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        SearchUtils.parse_search_filter(filter_string)


@pytest.mark.parametrize(
    ("filter_string", "error_message"),
    [
        ("params.acc LR !=", "Invalid clause(s) in filter string"),
        ("params.acc LR", "Invalid clause(s) in filter string"),
        ("metric.acc !=", "Invalid clause(s) in filter string"),
        ("acc != 1.0", "Invalid attribute key"),
        ("foo is null", "Invalid attribute key"),
        ("1=1", "Expected 'Identifier' found"),
        ("1==2", "Expected 'Identifier' found"),
    ],
)
def test_invalid_clauses(filter_string, error_message):
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        SearchUtils.parse_search_filter(filter_string)


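# Per-entity comparator restrictions are enforced when the filter is applied,
# not when it is parsed, so a minimal Run is built and `SearchUtils.filter` is
# called with each disallowed comparator.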
@pytest.mark.parametrize(
    ("entity_type", "bad_comparators", "key", "entity_value"),
    [
        ("metrics", ["~", "~="], "abc", 1.0),
        ("params", [">", "<", ">=", "<=", "~"], "abc", "'my-param-value'"),
        ("tags", [">", "<", ">=", "<=", "~"], "abc", "'my-tag-value'"),
        ("attributes", [">", "<", ">=", "<=", "~"], "status", "'my-tag-value'"),
        ("attributes", ["LIKE", "ILIKE"], "start_time", 1234),
        ("datasets", [">", "<", ">=", "<=", "~"], "name", "'my-dataset-name'"),
    ],
)
def test_bad_comparators(entity_type, bad_comparators, key, entity_value):
    run = Run(
        run_info=RunInfo(
            run_id="hi",
            experiment_id=0,
            user_id="user-id",
            status=RunStatus.to_string(RunStatus.FAILED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        ),
        run_data=RunData(metrics=[], params=[], tags=[]),
    )
    for bad_comparator in bad_comparators:
        bad_filter = f"{entity_type}.{key} {bad_comparator} {entity_value}"
        with pytest.raises(MlflowException, match="Invalid comparator"):
            SearchUtils.filter([run], bad_filter)


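# End-to-end in-memory filtering: each case lists the indices of the runs
# (constructed below) that should survive the filter. `datasets.context`
# resolves through the MLFLOW_DATASET_CONTEXT input tag on the dataset input.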
@pytest.mark.parametrize(
    ("filter_string", "matching_runs"),
    [
        (None, [0, 1, 2]),
        ("", [0, 1, 2]),
        ("attributes.status = 'FAILED'", [0, 2]),
        ("metrics.key1 = 123", [1]),
        ("metrics.key1 != 123", [0, 2]),
        ("metrics.key1 >= 123", [1, 2]),
        ("params.my_param = 'A'", [0, 1]),
        ("tags.tag1 = 'D'", [2]),
        ("tags.tag1 != 'D'", [1]),
        ("params.my_param = 'A' AND attributes.status = 'FAILED'", [0]),
        ("datasets.name = 'name1'", [0, 1]),
        ("datasets.name IN ('name1', 'name2')", [0, 1, 2]),
        ("datasets.digest IN ('digest1', 'digest2')", [0, 1, 2]),
        ("datasets.name = 'name1' AND datasets.digest = 'digest2'", []),
        ("datasets.context = 'train'", [0]),
        ("datasets.name = 'name1' AND datasets.context = 'train'", [0]),
    ],
)
def test_correct_filtering(filter_string, matching_runs):
    runs = [
        Run(
            run_info=RunInfo(
                run_id="hi",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 121, 1, 0)], params=[Param("my_param", "A")], tags=[]
            ),
            run_inputs=RunInputs(
                dataset_inputs=[
                    DatasetInput(
                        dataset=Dataset(
                            name="name1",
                            digest="digest1",
                            source_type="my_source_type",
                            source="source",
                        ),
                        tags=[InputTag(MLFLOW_DATASET_CONTEXT, "train")],
                    )
                ]
            ),
        ),
        Run(
            run_info=RunInfo(
                run_id="hi2",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 123, 1, 0)],
                params=[Param("my_param", "A")],
                tags=[RunTag("tag1", "C")],
            ),
            run_inputs=RunInputs(
                dataset_inputs=[
                    DatasetInput(
                        dataset=Dataset(
                            name="name1",
                            digest="digest1",
                            source_type="my_source_type",
                            source="source",
                        ),
                        tags=[],
                    )
                ]
            ),
        ),
        Run(
            run_info=RunInfo(
                run_id="hi3",
                experiment_id=1,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 125, 1, 0)],
                params=[Param("my_param", "B")],
                tags=[RunTag("tag1", "D")],
            ),
            run_inputs=RunInputs(
                dataset_inputs=[
                    DatasetInput(
                        dataset=Dataset(
                            name="name2",
                            digest="digest2",
                            source_type="my_source_type",
                            source="source",
                        ),
                        tags=[],
                    )
                ]
            ),
        ),
    ]
    filtered_runs = SearchUtils.filter(runs, filter_string)
    assert set(filtered_runs) == {runs[i] for i in matching_runs}


def test_filter_runs_by_start_time():
    runs = [
        Run(
            run_info=RunInfo(
                run_id=run_id,
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=idx,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(),
        )
        for idx, run_id in enumerate(["a", "b", "c"])
    ]
    assert SearchUtils.filter(runs, "attribute.start_time >= 0") == runs
    assert SearchUtils.filter(runs, "attribute.start_time > 1") == runs[2:]
    assert SearchUtils.filter(runs, "attribute.start_time = 2") == runs[2:]


def test_filter_runs_by_user_id():
    runs = [
        Run(
            run_info=RunInfo(
                run_id="a",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=1,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(),
        ),
        Run(
            run_info=RunInfo(
                run_id="b",
                experiment_id=0,
                user_id="user-id2",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=1,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(),
        ),
    ]
    assert SearchUtils.filter(runs, "attribute.user_id = 'user-id2'")[0] == runs[1]


def test_filter_runs_by_end_time():
    runs = [
        Run(
            run_info=RunInfo(
                run_id=run_id,
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=idx,
                end_time=idx,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(),
        )
        for idx, run_id in enumerate(["a", "b", "c"])
    ]
    assert SearchUtils.filter(runs, "attribute.end_time >= 0") == runs
    assert SearchUtils.filter(runs, "attribute.end_time > 1") == runs[2:]
    assert SearchUtils.filter(runs, "attribute.end_time = 2") == runs[2:]


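# Sorting cases: keys may be bare, double-quoted, or backtick-quoted; ASC/DESC
# is case-insensitive and tolerates extra whitespace, with ascending as the
# default. Ordering by a key no run has (tags.noSuchTag) falls back to the
# same order as passing no order_by at all.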
@pytest.mark.parametrize(
    ("order_bys", "matching_runs"),
    [
        (None, [2, 1, 0]),
        ([], [2, 1, 0]),
        (["tags.noSuchTag"], [2, 1, 0]),
        (["attributes.status"], [2, 0, 1]),
        (["attributes.start_time"], [0, 2, 1]),
        (["metrics.key1 asc"], [0, 1, 2]),
        (['metrics."key1"  desc'], [2, 1, 0]),
        (["params.my_param"], [1, 0, 2]),
        (["params.my_param aSc", "attributes.status  ASC"], [0, 1, 2]),
        (["params.my_param", "attributes.status DESC"], [1, 0, 2]),
        (["params.my_param DESC", "attributes.status   DESC"], [2, 1, 0]),
        (["params.`my_param` DESC", "attributes.status DESC"], [2, 1, 0]),
        (["tags.tag1"], [1, 2, 0]),
        (["tags.tag1    DESC"], [2, 1, 0]),
    ],
)
def test_correct_sorting(order_bys, matching_runs):
    runs = [
        Run(
            run_info=RunInfo(
                run_id="9",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 121, 1, 0)], params=[Param("my_param", "A")], tags=[]
            ),
        ),
        Run(
            run_info=RunInfo(
                run_id="8",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=1,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 123, 1, 0)],
                params=[Param("my_param", "A")],
                tags=[RunTag("tag1", "C")],
            ),
        ),
        Run(
            run_info=RunInfo(
                run_id="7",
                experiment_id=1,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=1,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(
                metrics=[Metric("key1", 125, 1, 0)],
                params=[Param("my_param", "B")],
                tags=[RunTag("tag1", "D")],
            ),
        ),
    ]
    sorted_runs = SearchUtils.sort(runs, order_bys)
    sorted_run_indices = []
    for run in sorted_runs:
        for i, r in enumerate(runs):
            if r == run:
                sorted_run_indices.append(i)
                break
    assert sorted_run_indices == matching_runs


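# NaN and missing (None) metric values sort after all finite values and
# infinities in both ascending and descending order.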
def test_order_by_metric_with_nans_infs_nones():
    metric_vals_str = ["nan", "inf", "-inf", "-1000", "0", "1000", "None"]
    runs = [
        Run(
            run_info=RunInfo(
                run_id=x,
                experiment_id=0,
                user_id="user",
                status=RunStatus.to_string(RunStatus.FINISHED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData(metrics=[Metric("x", None if x == "None" else float(x), 1, 0)]),
        )
        for x in metric_vals_str
    ]
    sorted_runs_asc = [x.info.run_id for x in SearchUtils.sort(runs, ["metrics.x asc"])]
    sorted_runs_desc = [x.info.run_id for x in SearchUtils.sort(runs, ["metrics.x desc"])]
    # asc
    assert sorted_runs_asc == ["-inf", "-1000", "0", "1000", "inf", "nan", "None"]
    # desc
    assert sorted_runs_desc == ["inf", "1000", "0", "-1000", "-inf", "nan", "None"]


@pytest.mark.parametrize(
    ("order_by", "error_message"),
    [
        ("m.acc", "Invalid entity type"),
        ("acc", "Invalid attribute key"),
        ("attri.x", "Invalid entity type"),
        ("`metrics.A", "Invalid order_by clause"),
        ("`metrics.A`", "Invalid entity type"),
        ("attribute.start", "Invalid attribute key"),
        ("attribute.experiment_id", "Invalid attribute key"),
        ("metrics.A != 1", "Invalid order_by clause"),
        ("params.my_param ", "Invalid order_by clause"),
        ("attribute.run_id ACS", "Invalid ordering key"),
        ("attribute.run_id decs", "Invalid ordering key"),
    ],
)
def test_invalid_order_by_search_runs(order_by, error_message):
    with pytest.raises(MlflowException, match=error_message):
        SearchUtils.parse_order_by_for_search_runs(order_by)


@pytest.mark.parametrize(
    ("order_by", "ascending_expected"),
    [
        ("metrics.`Mean Square Error`", True),
        ("metrics.`Mean Square Error` ASC", True),
        ("metrics.`Mean Square Error` DESC", False),
    ],
)
def test_space_order_by_search_runs(order_by, ascending_expected):
    identifier_type, identifier_name, ascending = SearchUtils.parse_order_by_for_search_runs(
        order_by
    )
    assert identifier_type == "metric"
    assert identifier_name == "Mean Square Error"
    assert ascending == ascending_expected


@pytest.mark.parametrize(
    ("order_by", "error_message"),
    [
        ("creation_timestamp DESC", "Invalid order by key"),
        ("last_updated_timestamp DESC blah", "Invalid order_by clause"),
        ("", "Invalid order_by clause"),
        ("timestamp somerandomstuff ASC", "Invalid order_by clause"),
        ("timestamp somerandomstuff", "Invalid order_by clause"),
        ("timestamp decs", "Invalid order_by clause"),
        ("timestamp ACS", "Invalid order_by clause"),
        ("name aCs", "Invalid ordering key"),
    ],
)
def test_invalid_order_by_search_registered_models(order_by, error_message):
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        SearchUtils.parse_order_by_for_search_registered_models(order_by)


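# Page tokens are opaque to callers, but `SearchUtils.paginate` encodes them
# as base64-encoded JSON of the form {"offset": <int>}; a next-page token is
# only returned when results remain beyond the current page.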
@pytest.mark.parametrize(
    ("page_token", "max_results", "matching_runs", "expected_next_page_token"),
    [
        (None, 1, [0], {"offset": 1}),
        (None, 2, [0, 1], {"offset": 2}),
        (None, 3, [0, 1, 2], None),
        (None, 5, [0, 1, 2], None),
        ({"offset": 1}, 1, [1], {"offset": 2}),
        ({"offset": 1}, 2, [1, 2], None),
        ({"offset": 1}, 3, [1, 2], None),
        ({"offset": 2}, 1, [2], None),
        ({"offset": 2}, 2, [2], None),
        ({"offset": 2}, 0, [], {"offset": 2}),
        ({"offset": 3}, 1, [], None),
    ],
)
def test_pagination(page_token, max_results, matching_runs, expected_next_page_token):
    runs = [
        Run(
            run_info=RunInfo(
                run_id="0",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData([], [], []),
        ),
        Run(
            run_info=RunInfo(
                run_id="1",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData([], [], []),
        ),
        Run(
            run_info=RunInfo(
                run_id="2",
                experiment_id=0,
                user_id="user-id",
                status=RunStatus.to_string(RunStatus.FAILED),
                start_time=0,
                end_time=1,
                lifecycle_stage=LifecycleStage.ACTIVE,
            ),
            run_data=RunData([], [], []),
        ),
    ]
    encoded_page_token = None
    if page_token:
        encoded_page_token = base64.b64encode(json.dumps(page_token).encode("utf-8"))
    paginated_runs, next_page_token = SearchUtils.paginate(runs, encoded_page_token, max_results)

    paginated_run_indices = []
    for run in paginated_runs:
        for i, r in enumerate(runs):
            if r == run:
                paginated_run_indices.append(i)
                break
    assert paginated_run_indices == matching_runs

    decoded_next_page_token = None
    if next_page_token:
        decoded_next_page_token = json.loads(base64.b64decode(next_page_token))
    assert decoded_next_page_token == expected_next_page_token


@pytest.mark.parametrize(
    ("page_token", "error_message"),
    [
        (base64.b64encode(json.dumps({}).encode("utf-8")), "Invalid page token"),
        (base64.b64encode(json.dumps({"offset": "a"}).encode("utf-8")), "Invalid page token"),
        (base64.b64encode(json.dumps({"offsoot": 7}).encode("utf-8")), "Invalid page token"),
        (base64.b64encode(b"not json"), "Invalid page token"),
        ("not base64", "Invalid page token"),
    ],
)
def test_invalid_page_tokens(page_token, error_message):
    with pytest.raises(MlflowException, match=error_message):
        SearchUtils.paginate([], page_token, 1)


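# "+" is a regex metacharacter, so this guards against LIKE patterns being
# compiled as raw regexes: the experiment name must match literally.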
def test_like_pattern_with_plus_character():
    import mlflow

    name = "jamie-foo C+W bar"
    mlflow.create_experiment(name)

    exps = mlflow.search_experiments(filter_string=f'name LIKE "{name}"')
    assert len(exps) == 1

    exps = mlflow.search_experiments(filter_string='name LIKE "jamie-foo C+%"')
    assert len(exps) == 1


def test_filter_runs_by_tag_and_param_is_null():
    run_with_tag = Run(
        run_info=RunInfo(
            run_id="run1",
            experiment_id=0,
            user_id="user",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        ),
        run_data=RunData(tags=[RunTag("env", "prod")], params=[], metrics=[]),
    )
    run_with_param = Run(
        run_info=RunInfo(
            run_id="run2",
            experiment_id=0,
            user_id="user",
            status=RunStatus.to_string(RunStatus.FINISHED),
            start_time=0,
            end_time=1,
            lifecycle_stage=LifecycleStage.ACTIVE,
        ),
        run_data=RunData(tags=[], params=[Param("lr", "0.01")], metrics=[]),
    )
    runs = [run_with_tag, run_with_param]

    assert [r.info.run_id for r in SearchUtils.filter(runs, "tags.env IS NOT NULL")] == ["run1"]
    assert [r.info.run_id for r in SearchUtils.filter(runs, "tags.env IS NULL")] == ["run2"]
    assert [r.info.run_id for r in SearchUtils.filter(runs, "params.lr IS NOT NULL")] == ["run2"]
    assert [r.info.run_id for r in SearchUtils.filter(runs, "params.lr IS NULL")] == ["run1"]


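# `SearchTraceUtils` applies the same filter grammar to TraceInfo objects via
# the `tag.` and `metadata.` prefixes; IS NULL treats an absent key the same
# as a null value.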
def test_search_trace_utils_filter_tag_is_null():
    loc = trace_location.TraceLocation.from_experiment_id("0")
    trace1 = TraceInfo(
        trace_id="t1",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        tags={"env": "prod", "region": "us"},
    )
    trace2 = TraceInfo(
        trace_id="t2",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        tags={"env": "staging"},
    )
    trace3 = TraceInfo(
        trace_id="t3",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        tags={},
    )
    traces = [trace1, trace2, trace3]

    result = SearchTraceUtils.filter(traces, "tag.region IS NULL")
    assert {t.trace_id for t in result} == {"t2", "t3"}

    result = SearchTraceUtils.filter(traces, "tag.region IS NOT NULL")
    assert {t.trace_id for t in result} == {"t1"}

    result = SearchTraceUtils.filter(traces, "tag.env IS NULL")
    assert {t.trace_id for t in result} == {"t3"}

    result = SearchTraceUtils.filter(traces, "tag.env IS NOT NULL")
    assert {t.trace_id for t in result} == {"t1", "t2"}

    result = SearchTraceUtils.filter(traces, 'tag.region IS NULL AND tag.env = "staging"')
    assert {t.trace_id for t in result} == {"t2"}


def test_search_trace_utils_filter_metadata_is_null():
    loc = trace_location.TraceLocation.from_experiment_id("0")
    trace1 = TraceInfo(
        trace_id="t1",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        trace_metadata={"user": "alice", "session": "s1"},
    )
    trace2 = TraceInfo(
        trace_id="t2",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        trace_metadata={"user": "bob"},
    )
    trace3 = TraceInfo(
        trace_id="t3",
        trace_location=loc,
        request_time=0,
        state=TraceState.OK,
        trace_metadata={},
    )
    traces = [trace1, trace2, trace3]

    result = SearchTraceUtils.filter(traces, "metadata.session IS NULL")
    assert {t.trace_id for t in result} == {"t2", "t3"}

    result = SearchTraceUtils.filter(traces, "metadata.session IS NOT NULL")
    assert {t.trace_id for t in result} == {"t1"}