test_search_utils.py
1 import base64 2 import json 3 import re 4 5 import pytest 6 7 from mlflow.entities import ( 8 Dataset, 9 DatasetInput, 10 InputTag, 11 LifecycleStage, 12 Metric, 13 Param, 14 Run, 15 RunData, 16 RunInfo, 17 RunInputs, 18 RunStatus, 19 RunTag, 20 TraceState, 21 trace_location, 22 ) 23 from mlflow.entities.trace_info import TraceInfo 24 from mlflow.exceptions import MlflowException 25 from mlflow.utils.mlflow_tags import MLFLOW_DATASET_CONTEXT 26 from mlflow.utils.search_utils import SearchTraceUtils, SearchUtils 27 28 29 @pytest.mark.parametrize( 30 ("filter_string", "parsed_filter"), 31 [ 32 ( 33 "metric.acc >= 0.94", 34 [{"comparator": ">=", "key": "acc", "type": "metric", "value": "0.94"}], 35 ), 36 ("metric.acc>=100", [{"comparator": ">=", "key": "acc", "type": "metric", "value": "100"}]), 37 ("params.m!='tf'", [{"comparator": "!=", "key": "m", "type": "parameter", "value": "tf"}]), 38 ( 39 'params."m"!="tf"', 40 [{"comparator": "!=", "key": "m", "type": "parameter", "value": "tf"}], 41 ), 42 ( 43 'metric."legit name" >= 0.243', 44 [{"comparator": ">=", "key": "legit name", "type": "metric", "value": "0.243"}], 45 ), 46 ("metrics.XYZ = 3", [{"comparator": "=", "key": "XYZ", "type": "metric", "value": "3"}]), 47 ( 48 'params."cat dog" = "pets"', 49 [{"comparator": "=", "key": "cat dog", "type": "parameter", "value": "pets"}], 50 ), 51 ( 52 'metrics."X-Y-Z" = 3', 53 [{"comparator": "=", "key": "X-Y-Z", "type": "metric", "value": "3"}], 54 ), 55 ( 56 'metrics."X//Y#$$@&Z" = 3', 57 [{"comparator": "=", "key": "X//Y#$$@&Z", "type": "metric", "value": "3"}], 58 ), 59 ( 60 "params.model = 'LinearRegression'", 61 [{"comparator": "=", "key": "model", "type": "parameter", "value": "LinearRegression"}], 62 ), 63 ( 64 "metrics.rmse < 1 and params.model_class = 'LR'", 65 [ 66 {"comparator": "<", "key": "rmse", "type": "metric", "value": "1"}, 67 {"comparator": "=", "key": "model_class", "type": "parameter", "value": "LR"}, 68 ], 69 ), 70 ("", []), 71 ("`metric`.a >= 0.1", [{"comparator": ">=", "key": "a", "type": "metric", "value": "0.1"}]), 72 ( 73 "`params`.model >= 'LR'", 74 [{"comparator": ">=", "key": "model", "type": "parameter", "value": "LR"}], 75 ), 76 ( 77 "tags.version = 'commit-hash'", 78 [{"comparator": "=", "key": "version", "type": "tag", "value": "commit-hash"}], 79 ), 80 ( 81 "`tags`.source_name = 'a notebook'", 82 [{"comparator": "=", "key": "source_name", "type": "tag", "value": "a notebook"}], 83 ), 84 ( 85 'metrics."accuracy.2.0" > 5', 86 [{"comparator": ">", "key": "accuracy.2.0", "type": "metric", "value": "5"}], 87 ), 88 ( 89 "metrics.`spacey name` > 5", 90 [{"comparator": ">", "key": "spacey name", "type": "metric", "value": "5"}], 91 ), 92 ( 93 'params."p.a.r.a.m" != "a"', 94 [{"comparator": "!=", "key": "p.a.r.a.m", "type": "parameter", "value": "a"}], 95 ), 96 ('tags."t.a.g" = "a"', [{"comparator": "=", "key": "t.a.g", "type": "tag", "value": "a"}]), 97 ( 98 "attribute.artifact_uri = '1/23/4'", 99 [{"type": "attribute", "comparator": "=", "key": "artifact_uri", "value": "1/23/4"}], 100 ), 101 ( 102 "attribute.start_time >= 1234", 103 [{"type": "attribute", "comparator": ">=", "key": "start_time", "value": "1234"}], 104 ), 105 ( 106 "run.status = 'RUNNING'", 107 [{"type": "attribute", "comparator": "=", "key": "status", "value": "RUNNING"}], 108 ), 109 ( 110 "dataset.name = 'my_dataset'", 111 [{"type": "dataset", "comparator": "=", "key": "name", "value": "my_dataset"}], 112 ), 113 ( 114 "tags.version IS NULL", 115 [{"comparator": "IS NULL", "key": "version", "type": "tag", "value": None}], 116 ), 117 ( 118 "tags.version IS NOT NULL", 119 [{"comparator": "IS NOT NULL", "key": "version", "type": "tag", "value": None}], 120 ), 121 ( 122 "params.lr IS NULL", 123 [{"comparator": "IS NULL", "key": "lr", "type": "parameter", "value": None}], 124 ), 125 ( 126 "params.lr IS NOT NULL", 127 [{"comparator": "IS NOT NULL", "key": "lr", "type": "parameter", "value": None}], 128 ), 129 ( 130 "tags.a IS NULL AND params.b = 'val'", 131 [ 132 {"comparator": "IS NULL", "key": "a", "type": "tag", "value": None}, 133 {"comparator": "=", "key": "b", "type": "parameter", "value": "val"}, 134 ], 135 ), 136 ], 137 ) 138 def test_filter(filter_string, parsed_filter): 139 assert SearchUtils.parse_search_filter(filter_string) == parsed_filter 140 141 142 @pytest.mark.parametrize( 143 ("filter_string", "parsed_filter"), 144 [ 145 ("params.m = 'LR'", [{"type": "parameter", "comparator": "=", "key": "m", "value": "LR"}]), 146 ('params.m = "LR"', [{"type": "parameter", "comparator": "=", "key": "m", "value": "LR"}]), 147 ( 148 'params.m = "L\'Hosp"', 149 [{"type": "parameter", "comparator": "=", "key": "m", "value": "L'Hosp"}], 150 ), 151 ], 152 ) 153 def test_correct_quote_trimming(filter_string, parsed_filter): 154 assert SearchUtils.parse_search_filter(filter_string) == parsed_filter 155 156 157 @pytest.mark.parametrize( 158 ("filter_string", "error_message"), 159 [ 160 ("metric.acc >= 0.94; metrics.rmse < 1", "Search filter contained multiple expression"), 161 ("m.acc >= 0.94", "Invalid entity type"), 162 ("acc >= 0.94", "Invalid attribute key"), 163 ("p.model >= 'LR'", "Invalid entity type"), 164 ("attri.x != 1", "Invalid entity type"), 165 ("a.x != 1", "Invalid entity type"), 166 ("model >= 'LR'", "Invalid attribute key"), 167 ("metrics.A > 0.1 OR params.B = 'LR'", "Invalid clause(s) in filter string"), 168 ("metrics.A > 0.1 NAND params.B = 'LR'", "Invalid clause(s) in filter string"), 169 ("metrics.A > 0.1 AND (params.B = 'LR')", "Invalid clause(s) in filter string"), 170 ("`metrics.A > 0.1", "Invalid clause(s) in filter string"), 171 ("param`.A > 0.1", "Invalid clause(s) in filter string"), 172 ("`dummy.A > 0.1", "Invalid clause(s) in filter string"), 173 ("dummy`.A > 0.1", "Invalid clause(s) in filter string"), 174 ("attribute.start != 1", "Invalid attribute key"), 175 ("attribute.experiment_id != 1", "Invalid attribute key"), 176 ("attribute.lifecycle_stage = 'ACTIVE'", "Invalid attribute key"), 177 ("attribute.name != 1", "Invalid attribute key"), 178 ("attribute.time != 1", "Invalid attribute key"), 179 ("attribute._status != 'RUNNING'", "Invalid attribute key"), 180 ("attribute.status = true", "Invalid clause(s) in filter string"), 181 ("dataset.status = 'true'", "Invalid dataset key"), 182 ("dataset.profile = 'num_rows: 10'", "Invalid dataset key"), 183 ("metrics.acc IS NULL", "IS NULL / IS NOT NULL is only supported for tags and params"), 184 ("attribute.status IS NULL", "IS NULL / IS NOT NULL is only supported for tags and params"), 185 ], 186 ) 187 def test_error_filter(filter_string, error_message): 188 with pytest.raises(MlflowException, match=re.escape(error_message)): 189 SearchUtils.parse_search_filter(filter_string) 190 191 192 @pytest.mark.parametrize( 193 ("filter_string", "error_message"), 194 [ 195 ("metric.model = 'LR'", "Expected numeric value type for metric"), 196 ("metric.model = '5'", "Expected numeric value type for metric"), 197 ("params.acc = 5", "Expected a quoted string value for param"), 198 ("tags.acc = 5", "Expected a quoted string value for tag"), 199 ("metrics.acc != metrics.acc", "Expected numeric value type for metric"), 200 ("1.0 > metrics.acc", "Expected 'Identifier' found"), 201 ("attribute.status = 1", "Expected a quoted string value for attributes"), 202 ], 203 ) 204 def test_error_comparison_clauses(filter_string, error_message): 205 with pytest.raises(MlflowException, match=error_message): 206 SearchUtils.parse_search_filter(filter_string) 207 208 209 @pytest.mark.parametrize( 210 ("filter_string", "error_message"), 211 [ 212 ("params.acc = LR", "value is either not quoted or unidentified quote types"), 213 ("tags.acc = LR", "value is either not quoted or unidentified quote types"), 214 ("params.acc = `LR`", "value is either not quoted or unidentified quote types"), 215 ("params.'acc = LR", "Invalid clause(s) in filter string"), 216 ("params.acc = 'LR", "Invalid clause(s) in filter string"), 217 ("params.acc = LR'", "Invalid clause(s) in filter string"), 218 ("params.acc = \"LR'", "Invalid clause(s) in filter string"), 219 ("tags.acc = \"LR'", "Invalid clause(s) in filter string"), 220 ("tags.acc = = 'LR'", "Invalid clause(s) in filter string"), 221 ("attribute.status IS 'RUNNING'", "Invalid clause(s) in filter string"), 222 ], 223 ) 224 def test_bad_quotes(filter_string, error_message): 225 with pytest.raises(MlflowException, match=re.escape(error_message)): 226 SearchUtils.parse_search_filter(filter_string) 227 228 229 @pytest.mark.parametrize( 230 ("filter_string", "error_message"), 231 [ 232 ("params.acc LR !=", "Invalid clause(s) in filter string"), 233 ("params.acc LR", "Invalid clause(s) in filter string"), 234 ("metric.acc !=", "Invalid clause(s) in filter string"), 235 ("acc != 1.0", "Invalid attribute key"), 236 ("foo is null", "Invalid attribute key"), 237 ("1=1", "Expected 'Identifier' found"), 238 ("1==2", "Expected 'Identifier' found"), 239 ], 240 ) 241 def test_invalid_clauses(filter_string, error_message): 242 with pytest.raises(MlflowException, match=re.escape(error_message)): 243 SearchUtils.parse_search_filter(filter_string) 244 245 246 @pytest.mark.parametrize( 247 ("entity_type", "bad_comparators", "key", "entity_value"), 248 [ 249 ("metrics", ["~", "~="], "abc", 1.0), 250 ("params", [">", "<", ">=", "<=", "~"], "abc", "'my-param-value'"), 251 ("tags", [">", "<", ">=", "<=", "~"], "abc", "'my-tag-value'"), 252 ("attributes", [">", "<", ">=", "<=", "~"], "status", "'my-tag-value'"), 253 ("attributes", ["LIKE", "ILIKE"], "start_time", 1234), 254 ("datasets", [">", "<", ">=", "<=", "~"], "name", "'my-dataset-name'"), 255 ], 256 ) 257 def test_bad_comparators(entity_type, bad_comparators, key, entity_value): 258 run = Run( 259 run_info=RunInfo( 260 run_id="hi", 261 experiment_id=0, 262 user_id="user-id", 263 status=RunStatus.to_string(RunStatus.FAILED), 264 start_time=0, 265 end_time=1, 266 lifecycle_stage=LifecycleStage.ACTIVE, 267 ), 268 run_data=RunData(metrics=[], params=[], tags=[]), 269 ) 270 for bad_comparator in bad_comparators: 271 bad_filter = f"{entity_type}.{key} {bad_comparator} {entity_value}" 272 with pytest.raises(MlflowException, match="Invalid comparator"): 273 SearchUtils.filter([run], bad_filter) 274 275 276 @pytest.mark.parametrize( 277 ("filter_string", "matching_runs"), 278 [ 279 (None, [0, 1, 2]), 280 ("", [0, 1, 2]), 281 ("attributes.status = 'FAILED'", [0, 2]), 282 ("metrics.key1 = 123", [1]), 283 ("metrics.key1 != 123", [0, 2]), 284 ("metrics.key1 >= 123", [1, 2]), 285 ("params.my_param = 'A'", [0, 1]), 286 ("tags.tag1 = 'D'", [2]), 287 ("tags.tag1 != 'D'", [1]), 288 ("params.my_param = 'A' AND attributes.status = 'FAILED'", [0]), 289 ("datasets.name = 'name1'", [0, 1]), 290 ("datasets.name IN ('name1', 'name2')", [0, 1, 2]), 291 ("datasets.digest IN ('digest1', 'digest2')", [0, 1, 2]), 292 ("datasets.name = 'name1' AND datasets.digest = 'digest2'", []), 293 ("datasets.context = 'train'", [0]), 294 ("datasets.name = 'name1' AND datasets.context = 'train'", [0]), 295 ], 296 ) 297 def test_correct_filtering(filter_string, matching_runs): 298 runs = [ 299 Run( 300 run_info=RunInfo( 301 run_id="hi", 302 experiment_id=0, 303 user_id="user-id", 304 status=RunStatus.to_string(RunStatus.FAILED), 305 start_time=0, 306 end_time=1, 307 lifecycle_stage=LifecycleStage.ACTIVE, 308 ), 309 run_data=RunData( 310 metrics=[Metric("key1", 121, 1, 0)], params=[Param("my_param", "A")], tags=[] 311 ), 312 run_inputs=RunInputs( 313 dataset_inputs=[ 314 DatasetInput( 315 dataset=Dataset( 316 name="name1", 317 digest="digest1", 318 source_type="my_source_type", 319 source="source", 320 ), 321 tags=[InputTag(MLFLOW_DATASET_CONTEXT, "train")], 322 ) 323 ] 324 ), 325 ), 326 Run( 327 run_info=RunInfo( 328 run_id="hi2", 329 experiment_id=0, 330 user_id="user-id", 331 status=RunStatus.to_string(RunStatus.FINISHED), 332 start_time=0, 333 end_time=1, 334 lifecycle_stage=LifecycleStage.ACTIVE, 335 ), 336 run_data=RunData( 337 metrics=[Metric("key1", 123, 1, 0)], 338 params=[Param("my_param", "A")], 339 tags=[RunTag("tag1", "C")], 340 ), 341 run_inputs=RunInputs( 342 dataset_inputs=[ 343 DatasetInput( 344 dataset=Dataset( 345 name="name1", 346 digest="digest1", 347 source_type="my_source_type", 348 source="source", 349 ), 350 tags=[], 351 ) 352 ] 353 ), 354 ), 355 Run( 356 run_info=RunInfo( 357 run_id="hi3", 358 experiment_id=1, 359 user_id="user-id", 360 status=RunStatus.to_string(RunStatus.FAILED), 361 start_time=0, 362 end_time=1, 363 lifecycle_stage=LifecycleStage.ACTIVE, 364 ), 365 run_data=RunData( 366 metrics=[Metric("key1", 125, 1, 0)], 367 params=[Param("my_param", "B")], 368 tags=[RunTag("tag1", "D")], 369 ), 370 run_inputs=RunInputs( 371 dataset_inputs=[ 372 DatasetInput( 373 dataset=Dataset( 374 name="name2", 375 digest="digest2", 376 source_type="my_source_type", 377 source="source", 378 ), 379 tags=[], 380 ) 381 ] 382 ), 383 ), 384 ] 385 filtered_runs = SearchUtils.filter(runs, filter_string) 386 assert set(filtered_runs) == {runs[i] for i in matching_runs} 387 388 389 def test_filter_runs_by_start_time(): 390 runs = [ 391 Run( 392 run_info=RunInfo( 393 run_id=run_id, 394 experiment_id=0, 395 user_id="user-id", 396 status=RunStatus.to_string(RunStatus.FINISHED), 397 start_time=idx, 398 end_time=1, 399 lifecycle_stage=LifecycleStage.ACTIVE, 400 ), 401 run_data=RunData(), 402 ) 403 for idx, run_id in enumerate(["a", "b", "c"]) 404 ] 405 assert SearchUtils.filter(runs, "attribute.start_time >= 0") == runs 406 assert SearchUtils.filter(runs, "attribute.start_time > 1") == runs[2:] 407 assert SearchUtils.filter(runs, "attribute.start_time = 2") == runs[2:] 408 409 410 def test_filter_runs_by_user_id(): 411 runs = [ 412 Run( 413 run_info=RunInfo( 414 run_id="a", 415 experiment_id=0, 416 user_id="user-id", 417 status=RunStatus.to_string(RunStatus.FINISHED), 418 start_time=1, 419 end_time=1, 420 lifecycle_stage=LifecycleStage.ACTIVE, 421 ), 422 run_data=RunData(), 423 ), 424 Run( 425 run_info=RunInfo( 426 run_id="b", 427 experiment_id=0, 428 user_id="user-id2", 429 status=RunStatus.to_string(RunStatus.FINISHED), 430 start_time=1, 431 end_time=1, 432 lifecycle_stage=LifecycleStage.ACTIVE, 433 ), 434 run_data=RunData(), 435 ), 436 ] 437 assert SearchUtils.filter(runs, "attribute.user_id = 'user-id2'")[0] == runs[1] 438 439 440 def test_filter_runs_by_end_time(): 441 runs = [ 442 Run( 443 run_info=RunInfo( 444 run_id=run_id, 445 experiment_id=0, 446 user_id="user-id", 447 status=RunStatus.to_string(RunStatus.FINISHED), 448 start_time=idx, 449 end_time=idx, 450 lifecycle_stage=LifecycleStage.ACTIVE, 451 ), 452 run_data=RunData(), 453 ) 454 for idx, run_id in enumerate(["a", "b", "c"]) 455 ] 456 assert SearchUtils.filter(runs, "attribute.end_time >= 0") == runs 457 assert SearchUtils.filter(runs, "attribute.end_time > 1") == runs[2:] 458 assert SearchUtils.filter(runs, "attribute.end_time = 2") == runs[2:] 459 460 461 @pytest.mark.parametrize( 462 ("order_bys", "matching_runs"), 463 [ 464 (None, [2, 1, 0]), 465 ([], [2, 1, 0]), 466 (["tags.noSuchTag"], [2, 1, 0]), 467 (["attributes.status"], [2, 0, 1]), 468 (["attributes.start_time"], [0, 2, 1]), 469 (["metrics.key1 asc"], [0, 1, 2]), 470 (['metrics."key1" desc'], [2, 1, 0]), 471 (["params.my_param"], [1, 0, 2]), 472 (["params.my_param aSc", "attributes.status ASC"], [0, 1, 2]), 473 (["params.my_param", "attributes.status DESC"], [1, 0, 2]), 474 (["params.my_param DESC", "attributes.status DESC"], [2, 1, 0]), 475 (["params.`my_param` DESC", "attributes.status DESC"], [2, 1, 0]), 476 (["tags.tag1"], [1, 2, 0]), 477 (["tags.tag1 DESC"], [2, 1, 0]), 478 ], 479 ) 480 def test_correct_sorting(order_bys, matching_runs): 481 runs = [ 482 Run( 483 run_info=RunInfo( 484 run_id="9", 485 experiment_id=0, 486 user_id="user-id", 487 status=RunStatus.to_string(RunStatus.FAILED), 488 start_time=0, 489 end_time=1, 490 lifecycle_stage=LifecycleStage.ACTIVE, 491 ), 492 run_data=RunData( 493 metrics=[Metric("key1", 121, 1, 0)], params=[Param("my_param", "A")], tags=[] 494 ), 495 ), 496 Run( 497 run_info=RunInfo( 498 run_id="8", 499 experiment_id=0, 500 user_id="user-id", 501 status=RunStatus.to_string(RunStatus.FINISHED), 502 start_time=1, 503 end_time=1, 504 lifecycle_stage=LifecycleStage.ACTIVE, 505 ), 506 run_data=RunData( 507 metrics=[Metric("key1", 123, 1, 0)], 508 params=[Param("my_param", "A")], 509 tags=[RunTag("tag1", "C")], 510 ), 511 ), 512 Run( 513 run_info=RunInfo( 514 run_id="7", 515 experiment_id=1, 516 user_id="user-id", 517 status=RunStatus.to_string(RunStatus.FAILED), 518 start_time=1, 519 end_time=1, 520 lifecycle_stage=LifecycleStage.ACTIVE, 521 ), 522 run_data=RunData( 523 metrics=[Metric("key1", 125, 1, 0)], 524 params=[Param("my_param", "B")], 525 tags=[RunTag("tag1", "D")], 526 ), 527 ), 528 ] 529 sorted_runs = SearchUtils.sort(runs, order_bys) 530 sorted_run_indices = [] 531 for run in sorted_runs: 532 for i, r in enumerate(runs): 533 if r == run: 534 sorted_run_indices.append(i) 535 break 536 assert sorted_run_indices == matching_runs 537 538 539 def test_order_by_metric_with_nans_infs_nones(): 540 metric_vals_str = ["nan", "inf", "-inf", "-1000", "0", "1000", "None"] 541 runs = [ 542 Run( 543 run_info=RunInfo( 544 run_id=x, 545 experiment_id=0, 546 user_id="user", 547 status=RunStatus.to_string(RunStatus.FINISHED), 548 start_time=0, 549 end_time=1, 550 lifecycle_stage=LifecycleStage.ACTIVE, 551 ), 552 run_data=RunData(metrics=[Metric("x", None if x == "None" else float(x), 1, 0)]), 553 ) 554 for x in metric_vals_str 555 ] 556 sorted_runs_asc = [x.info.run_id for x in SearchUtils.sort(runs, ["metrics.x asc"])] 557 sorted_runs_desc = [x.info.run_id for x in SearchUtils.sort(runs, ["metrics.x desc"])] 558 # asc 559 assert sorted_runs_asc == ["-inf", "-1000", "0", "1000", "inf", "nan", "None"] 560 # desc 561 assert sorted_runs_desc == ["inf", "1000", "0", "-1000", "-inf", "nan", "None"] 562 563 564 @pytest.mark.parametrize( 565 ("order_by", "error_message"), 566 [ 567 ("m.acc", "Invalid entity type"), 568 ("acc", "Invalid attribute key"), 569 ("attri.x", "Invalid entity type"), 570 ("`metrics.A", "Invalid order_by clause"), 571 ("`metrics.A`", "Invalid entity type"), 572 ("attribute.start", "Invalid attribute key"), 573 ("attribute.experiment_id", "Invalid attribute key"), 574 ("metrics.A != 1", "Invalid order_by clause"), 575 ("params.my_param ", "Invalid order_by clause"), 576 ("attribute.run_id ACS", "Invalid ordering key"), 577 ("attribute.run_id decs", "Invalid ordering key"), 578 ], 579 ) 580 def test_invalid_order_by_search_runs(order_by, error_message): 581 with pytest.raises(MlflowException, match=error_message): 582 SearchUtils.parse_order_by_for_search_runs(order_by) 583 584 585 @pytest.mark.parametrize( 586 ("order_by", "ascending_expected"), 587 [ 588 ("metrics.`Mean Square Error`", True), 589 ("metrics.`Mean Square Error` ASC", True), 590 ("metrics.`Mean Square Error` DESC", False), 591 ], 592 ) 593 def test_space_order_by_search_runs(order_by, ascending_expected): 594 identifier_type, identifier_name, ascending = SearchUtils.parse_order_by_for_search_runs( 595 order_by 596 ) 597 assert identifier_type == "metric" 598 assert identifier_name == "Mean Square Error" 599 assert ascending == ascending_expected 600 601 602 @pytest.mark.parametrize( 603 ("order_by", "error_message"), 604 [ 605 ("creation_timestamp DESC", "Invalid order by key"), 606 ("last_updated_timestamp DESC blah", "Invalid order_by clause"), 607 ("", "Invalid order_by clause"), 608 ("timestamp somerandomstuff ASC", "Invalid order_by clause"), 609 ("timestamp somerandomstuff", "Invalid order_by clause"), 610 ("timestamp decs", "Invalid order_by clause"), 611 ("timestamp ACS", "Invalid order_by clause"), 612 ("name aCs", "Invalid ordering key"), 613 ], 614 ) 615 def test_invalid_order_by_search_registered_models(order_by, error_message): 616 with pytest.raises(MlflowException, match=re.escape(error_message)): 617 SearchUtils.parse_order_by_for_search_registered_models(order_by) 618 619 620 @pytest.mark.parametrize( 621 ("page_token", "max_results", "matching_runs", "expected_next_page_token"), 622 [ 623 (None, 1, [0], {"offset": 1}), 624 (None, 2, [0, 1], {"offset": 2}), 625 (None, 3, [0, 1, 2], None), 626 (None, 5, [0, 1, 2], None), 627 ({"offset": 1}, 1, [1], {"offset": 2}), 628 ({"offset": 1}, 2, [1, 2], None), 629 ({"offset": 1}, 3, [1, 2], None), 630 ({"offset": 2}, 1, [2], None), 631 ({"offset": 2}, 2, [2], None), 632 ({"offset": 2}, 0, [], {"offset": 2}), 633 ({"offset": 3}, 1, [], None), 634 ], 635 ) 636 def test_pagination(page_token, max_results, matching_runs, expected_next_page_token): 637 runs = [ 638 Run( 639 run_info=RunInfo( 640 run_id="0", 641 experiment_id=0, 642 user_id="user-id", 643 status=RunStatus.to_string(RunStatus.FAILED), 644 start_time=0, 645 end_time=1, 646 lifecycle_stage=LifecycleStage.ACTIVE, 647 ), 648 run_data=RunData([], [], []), 649 ), 650 Run( 651 run_info=RunInfo( 652 run_id="1", 653 experiment_id=0, 654 user_id="user-id", 655 status=RunStatus.to_string(RunStatus.FAILED), 656 start_time=0, 657 end_time=1, 658 lifecycle_stage=LifecycleStage.ACTIVE, 659 ), 660 run_data=RunData([], [], []), 661 ), 662 Run( 663 run_info=RunInfo( 664 run_id="2", 665 experiment_id=0, 666 user_id="user-id", 667 status=RunStatus.to_string(RunStatus.FAILED), 668 start_time=0, 669 end_time=1, 670 lifecycle_stage=LifecycleStage.ACTIVE, 671 ), 672 run_data=RunData([], [], []), 673 ), 674 ] 675 encoded_page_token = None 676 if page_token: 677 encoded_page_token = base64.b64encode(json.dumps(page_token).encode("utf-8")) 678 paginated_runs, next_page_token = SearchUtils.paginate(runs, encoded_page_token, max_results) 679 680 paginated_run_indices = [] 681 for run in paginated_runs: 682 for i, r in enumerate(runs): 683 if r == run: 684 paginated_run_indices.append(i) 685 break 686 assert paginated_run_indices == matching_runs 687 688 decoded_next_page_token = None 689 if next_page_token: 690 decoded_next_page_token = json.loads(base64.b64decode(next_page_token)) 691 assert decoded_next_page_token == expected_next_page_token 692 693 694 @pytest.mark.parametrize( 695 ("page_token", "error_message"), 696 [ 697 (base64.b64encode(json.dumps({}).encode("utf-8")), "Invalid page token"), 698 (base64.b64encode(json.dumps({"offset": "a"}).encode("utf-8")), "Invalid page token"), 699 (base64.b64encode(json.dumps({"offsoot": 7}).encode("utf-8")), "Invalid page token"), 700 (base64.b64encode(b"not json"), "Invalid page token"), 701 ("not base64", "Invalid page token"), 702 ], 703 ) 704 def test_invalid_page_tokens(page_token, error_message): 705 with pytest.raises(MlflowException, match=error_message): 706 SearchUtils.paginate([], page_token, 1) 707 708 709 def test_like_pattern_with_plus_character(): 710 import mlflow 711 712 name = "jamie-foo C+W bar" 713 mlflow.create_experiment(name) 714 715 exps = mlflow.search_experiments(filter_string=f'name LIKE "{name}"') 716 assert len(exps) == 1 717 718 exps = mlflow.search_experiments(filter_string='name LIKE "jamie-foo C+%"') 719 assert len(exps) == 1 720 721 722 def test_filter_runs_by_tag_and_param_is_null(): 723 run_with_tag = Run( 724 run_info=RunInfo( 725 run_id="run1", 726 experiment_id=0, 727 user_id="user", 728 status=RunStatus.to_string(RunStatus.FINISHED), 729 start_time=0, 730 end_time=1, 731 lifecycle_stage=LifecycleStage.ACTIVE, 732 ), 733 run_data=RunData(tags=[RunTag("env", "prod")], params=[], metrics=[]), 734 ) 735 run_with_param = Run( 736 run_info=RunInfo( 737 run_id="run2", 738 experiment_id=0, 739 user_id="user", 740 status=RunStatus.to_string(RunStatus.FINISHED), 741 start_time=0, 742 end_time=1, 743 lifecycle_stage=LifecycleStage.ACTIVE, 744 ), 745 run_data=RunData(tags=[], params=[Param("lr", "0.01")], metrics=[]), 746 ) 747 runs = [run_with_tag, run_with_param] 748 749 assert [r.info.run_id for r in SearchUtils.filter(runs, "tags.env IS NOT NULL")] == ["run1"] 750 assert [r.info.run_id for r in SearchUtils.filter(runs, "tags.env IS NULL")] == ["run2"] 751 assert [r.info.run_id for r in SearchUtils.filter(runs, "params.lr IS NOT NULL")] == ["run2"] 752 assert [r.info.run_id for r in SearchUtils.filter(runs, "params.lr IS NULL")] == ["run1"] 753 754 755 def test_search_trace_utils_filter_tag_is_null(): 756 loc = trace_location.TraceLocation.from_experiment_id("0") 757 trace1 = TraceInfo( 758 trace_id="t1", 759 trace_location=loc, 760 request_time=0, 761 state=TraceState.OK, 762 tags={"env": "prod", "region": "us"}, 763 ) 764 trace2 = TraceInfo( 765 trace_id="t2", 766 trace_location=loc, 767 request_time=0, 768 state=TraceState.OK, 769 tags={"env": "staging"}, 770 ) 771 trace3 = TraceInfo( 772 trace_id="t3", 773 trace_location=loc, 774 request_time=0, 775 state=TraceState.OK, 776 tags={}, 777 ) 778 traces = [trace1, trace2, trace3] 779 780 result = SearchTraceUtils.filter(traces, "tag.region IS NULL") 781 assert {t.trace_id for t in result} == {"t2", "t3"} 782 783 result = SearchTraceUtils.filter(traces, "tag.region IS NOT NULL") 784 assert {t.trace_id for t in result} == {"t1"} 785 786 result = SearchTraceUtils.filter(traces, "tag.env IS NULL") 787 assert {t.trace_id for t in result} == {"t3"} 788 789 result = SearchTraceUtils.filter(traces, "tag.env IS NOT NULL") 790 assert {t.trace_id for t in result} == {"t1", "t2"} 791 792 result = SearchTraceUtils.filter(traces, 'tag.region IS NULL AND tag.env = "staging"') 793 assert {t.trace_id for t in result} == {"t2"} 794 795 796 def test_search_trace_utils_filter_metadata_is_null(): 797 loc = trace_location.TraceLocation.from_experiment_id("0") 798 trace1 = TraceInfo( 799 trace_id="t1", 800 trace_location=loc, 801 request_time=0, 802 state=TraceState.OK, 803 trace_metadata={"user": "alice", "session": "s1"}, 804 ) 805 trace2 = TraceInfo( 806 trace_id="t2", 807 trace_location=loc, 808 request_time=0, 809 state=TraceState.OK, 810 trace_metadata={"user": "bob"}, 811 ) 812 trace3 = TraceInfo( 813 trace_id="t3", 814 trace_location=loc, 815 request_time=0, 816 state=TraceState.OK, 817 trace_metadata={}, 818 ) 819 traces = [trace1, trace2, trace3] 820 821 result = SearchTraceUtils.filter(traces, "metadata.session IS NULL") 822 assert {t.trace_id for t in result} == {"t2", "t3"} 823 824 result = SearchTraceUtils.filter(traces, "metadata.session IS NOT NULL") 825 assert {t.trace_id for t in result} == {"t1"}