search_utils.py
1 import ast 2 import base64 3 import json 4 import math 5 import operator 6 import re 7 import shlex 8 from dataclasses import asdict, dataclass 9 from typing import TYPE_CHECKING, Any, Callable 10 11 import sqlparse 12 from packaging.version import Version 13 from sqlparse.sql import ( 14 Comparison, 15 Identifier, 16 Parenthesis, 17 Statement, 18 Token, 19 TokenList, 20 ) 21 from sqlparse.tokens import Token as TokenType 22 23 from mlflow.entities import LoggedModel, Metric, RunInfo 24 from mlflow.entities.model_registry.model_version_stages import STAGE_DELETED_INTERNAL 25 from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY 26 from mlflow.exceptions import MlflowException 27 from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE 28 from mlflow.store.db.db_types import MSSQL, MYSQL, POSTGRES, SQLITE 29 from mlflow.tracing.constant import ( 30 AssessmentMetricSearchKey, 31 SpanMetricSearchKey, 32 TraceMetadataKey, 33 TraceMetricSearchKey, 34 TraceTagKey, 35 ) 36 from mlflow.utils.mlflow_tags import ( 37 MLFLOW_DATASET_CONTEXT, 38 ) 39 40 if TYPE_CHECKING: 41 from sqlalchemy.sql.elements import ClauseElement, ColumnElement 42 43 # MSSQL collation for case-sensitive string comparisons 44 _MSSQL_CASE_SENSITIVE_COLLATION = "Japanese_Bushu_Kakusu_100_CS_AS_KS_WS" 45 46 47 def _convert_like_pattern_to_regex(pattern: str, flags: int = 0): 48 regex = re.escape(pattern) 49 regex = regex.replace("%", ".*").replace("_", ".") 50 51 if not pattern.startswith("%"): 52 regex = "^" + regex 53 if not pattern.endswith("%"): 54 regex = regex + "$" 55 56 return re.compile(regex, flags) 57 58 59 def _like(string, pattern): 60 return _convert_like_pattern_to_regex(pattern).match(string) is not None 61 62 63 def _ilike(string, pattern): 64 return _convert_like_pattern_to_regex(pattern, flags=re.IGNORECASE).match(string) is not None 65 66 67 def _join_in_comparison_tokens(tokens, search_traces=False): 68 """ 69 Find a sequence of tokens that matches the pattern of an IN comparison or a NOT IN comparison, 70 join the tokens into a single Comparison token. Otherwise, return the original list of tokens. 71 """ 72 if Version(sqlparse.__version__) < Version("0.4.4"): 73 # In sqlparse < 0.4.4, IN is treated as a comparison, we don't need to join tokens 74 return tokens 75 76 non_whitespace_tokens = [t for t in tokens if not t.is_whitespace] 77 joined_tokens = [] 78 num_tokens = len(non_whitespace_tokens) 79 iterator = enumerate(non_whitespace_tokens) 80 while elem := next(iterator, None): 81 index, first = elem 82 # We need at least 3 tokens to form an IN comparison or a NOT IN comparison 83 if num_tokens - index < 3: 84 joined_tokens.extend(non_whitespace_tokens[index:]) 85 break 86 87 if search_traces: 88 # timestamp 89 if first.match(ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]): 90 (_, second) = next(iterator, (None, None)) 91 (_, third) = next(iterator, (None, None)) 92 if any(x is None for x in [second, third]): 93 raise MlflowException( 94 f"Invalid comparison clause with token `{first}, {second}, {third}`, " 95 "expected 3 tokens", 96 error_code=INVALID_PARAMETER_VALUE, 97 ) 98 if ( 99 second.match( 100 ttype=TokenType.Operator.Comparison, 101 values=SearchTraceUtils.VALID_NUMERIC_ATTRIBUTE_COMPARATORS, 102 ) 103 and third.ttype == TokenType.Literal.Number.Integer 104 ): 105 joined_tokens.append(Comparison(TokenList([first, second, third]))) 106 continue 107 else: 108 joined_tokens.extend([first, second, third]) 109 110 # Wait until we encounter an identifier token 111 if not isinstance(first, Identifier): 112 joined_tokens.append(first) 113 continue 114 115 (_, second) = next(iterator) 116 (_, third) = next(iterator) 117 118 # IN 119 if ( 120 isinstance(first, Identifier) 121 and second.match(ttype=TokenType.Keyword, values=["IN"]) 122 and isinstance(third, Parenthesis) 123 ): 124 joined_tokens.append(Comparison(TokenList([first, second, third]))) 125 continue 126 127 # IS NULL 128 if ( 129 isinstance(first, Identifier) 130 and second.match(ttype=TokenType.Keyword, values=["IS"]) 131 and third.match(ttype=TokenType.Keyword, values=["NULL"]) 132 ): 133 joined_tokens.append( 134 Comparison(TokenList([first, Token(TokenType.Keyword, "IS NULL")])) 135 ) 136 continue 137 138 # IS NOT NULL 139 if ( 140 isinstance(first, Identifier) 141 and second.match(ttype=TokenType.Keyword, values=["IS"]) 142 and third.ttype == TokenType.Keyword 143 and third.value.upper() == "NOT NULL" 144 ): 145 joined_tokens.append( 146 Comparison(TokenList([first, Token(TokenType.Keyword, "IS NOT NULL")])) 147 ) 148 continue 149 150 (_, fourth) = next(iterator, (None, None)) 151 if fourth is None: 152 joined_tokens.extend([first, second, third]) 153 break 154 155 # NOT IN 156 if ( 157 isinstance(first, Identifier) 158 and second.match(ttype=TokenType.Keyword, values=["NOT"]) 159 and third.match(ttype=TokenType.Keyword, values=["IN"]) 160 and isinstance(fourth, Parenthesis) 161 ): 162 joined_tokens.append( 163 Comparison(TokenList([first, Token(TokenType.Keyword, "NOT IN"), fourth])) 164 ) 165 continue 166 167 joined_tokens.extend([first, second, third, fourth]) 168 169 return joined_tokens 170 171 172 class SearchUtils: 173 LIKE_OPERATOR = "LIKE" 174 ILIKE_OPERATOR = "ILIKE" 175 ASC_OPERATOR = "asc" 176 DESC_OPERATOR = "desc" 177 VALID_ORDER_BY_TAGS = [ASC_OPERATOR, DESC_OPERATOR] 178 VALID_METRIC_COMPARATORS = {">", ">=", "!=", "=", "<", "<="} 179 VALID_PARAM_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IS NULL", "IS NOT NULL"} 180 VALID_TAG_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IS NULL", "IS NOT NULL"} 181 VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"} 182 VALID_NUMERIC_ATTRIBUTE_COMPARATORS = VALID_METRIC_COMPARATORS 183 VALID_DATASET_COMPARATORS = {"!=", "=", LIKE_OPERATOR, ILIKE_OPERATOR, "IN", "NOT IN"} 184 _BUILTIN_NUMERIC_ATTRIBUTES = {"start_time", "end_time"} 185 _ALTERNATE_NUMERIC_ATTRIBUTES = {"created", "Created"} 186 _ALTERNATE_STRING_ATTRIBUTES = {"run name", "Run name", "Run Name"} 187 NUMERIC_ATTRIBUTES = set( 188 list(_BUILTIN_NUMERIC_ATTRIBUTES) + list(_ALTERNATE_NUMERIC_ATTRIBUTES) 189 ) 190 DATASET_ATTRIBUTES = {"name", "digest", "context"} 191 VALID_SEARCH_ATTRIBUTE_KEYS = set( 192 RunInfo.get_searchable_attributes() 193 + list(_ALTERNATE_NUMERIC_ATTRIBUTES) 194 + list(_ALTERNATE_STRING_ATTRIBUTES) 195 ) 196 VALID_ORDER_BY_ATTRIBUTE_KEYS = set( 197 RunInfo.get_orderable_attributes() + list(_ALTERNATE_NUMERIC_ATTRIBUTES) 198 ) 199 _METRIC_IDENTIFIER = "metric" 200 _ALTERNATE_METRIC_IDENTIFIERS = {"metrics"} 201 _PARAM_IDENTIFIER = "parameter" 202 _ALTERNATE_PARAM_IDENTIFIERS = {"parameters", "param", "params"} 203 _TAG_IDENTIFIER = "tag" 204 _ALTERNATE_TAG_IDENTIFIERS = {"tags"} 205 _ATTRIBUTE_IDENTIFIER = "attribute" 206 _ALTERNATE_ATTRIBUTE_IDENTIFIERS = {"attr", "attributes", "run"} 207 _DATASET_IDENTIFIER = "dataset" 208 _ALTERNATE_DATASET_IDENTIFIERS = {"datasets"} 209 _IDENTIFIERS = [ 210 _METRIC_IDENTIFIER, 211 _PARAM_IDENTIFIER, 212 _TAG_IDENTIFIER, 213 _ATTRIBUTE_IDENTIFIER, 214 _DATASET_IDENTIFIER, 215 ] 216 _VALID_IDENTIFIERS = set( 217 _IDENTIFIERS 218 + list(_ALTERNATE_METRIC_IDENTIFIERS) 219 + list(_ALTERNATE_PARAM_IDENTIFIERS) 220 + list(_ALTERNATE_TAG_IDENTIFIERS) 221 + list(_ALTERNATE_ATTRIBUTE_IDENTIFIERS) 222 + list(_ALTERNATE_DATASET_IDENTIFIERS) 223 ) 224 STRING_VALUE_TYPES = {TokenType.Literal.String.Single} 225 DELIMITER_VALUE_TYPES = {TokenType.Punctuation} 226 WHITESPACE_VALUE_TYPE = TokenType.Text.Whitespace 227 NUMERIC_VALUE_TYPES = {TokenType.Literal.Number.Integer, TokenType.Literal.Number.Float} 228 # Registered Models Constants 229 ORDER_BY_KEY_TIMESTAMP = "timestamp" 230 ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" 231 ORDER_BY_KEY_MODEL_NAME = "name" 232 VALID_ORDER_BY_KEYS_REGISTERED_MODELS = { 233 ORDER_BY_KEY_TIMESTAMP, 234 ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP, 235 ORDER_BY_KEY_MODEL_NAME, 236 } 237 VALID_TIMESTAMP_ORDER_BY_KEYS = {ORDER_BY_KEY_TIMESTAMP, ORDER_BY_KEY_LAST_UPDATED_TIMESTAMP} 238 # We encourage users to use timestamp for order-by 239 RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS = {ORDER_BY_KEY_MODEL_NAME, ORDER_BY_KEY_TIMESTAMP} 240 241 @staticmethod 242 def get_comparison_func(comparator): 243 return { 244 ">": operator.gt, 245 ">=": operator.ge, 246 "=": operator.eq, 247 "!=": operator.ne, 248 "<=": operator.le, 249 "<": operator.lt, 250 "LIKE": _like, 251 "ILIKE": _ilike, 252 "IN": lambda x, y: x in y, 253 "NOT IN": lambda x, y: x not in y, 254 }[comparator] 255 256 @staticmethod 257 def get_sql_comparison_func(comparator, dialect): 258 import sqlalchemy as sa 259 260 def comparison_func(column, value): 261 if comparator == "LIKE": 262 return column.like(value) 263 elif comparator == "ILIKE": 264 return column.ilike(value) 265 elif comparator == "IN": 266 return column.in_(value) 267 elif comparator == "NOT IN": 268 return ~column.in_(value) 269 return SearchUtils.get_comparison_func(comparator)(column, value) 270 271 def mssql_comparison_func(column, value): 272 if comparator == "RLIKE": 273 raise MlflowException( 274 "RLIKE operator is not supported for MSSQL database dialect. " 275 "Consider using LIKE or ILIKE operators instead.", 276 error_code=INVALID_PARAMETER_VALUE, 277 ) 278 if not isinstance(column.type, sa.types.String): 279 return comparison_func(column, value) 280 281 collated = column.collate(_MSSQL_CASE_SENSITIVE_COLLATION) 282 return comparison_func(collated, value) 283 284 def mysql_comparison_func(column, value): 285 if not isinstance(column.type, sa.types.String): 286 return comparison_func(column, value) 287 288 # MySQL is case insensitive by default, so we need to use the binary operator to 289 # perform case sensitive comparisons. 290 templates = { 291 # Use non-binary ahead of binary comparison for runtime performance 292 "=": "({column} = :value AND BINARY {column} = :value)", 293 "!=": "({column} != :value OR BINARY {column} != :value)", 294 "LIKE": "({column} LIKE :value AND BINARY {column} LIKE :value)", 295 # we need to cast the column to binary to perform a case sensitive comparison 296 # to avoid error like: `Character set 'utf8mb4_0900_ai_ci' cannot be used in 297 # conjunction with 'binary' in call to regexp_like` 298 "RLIKE": "(CAST({column} AS BINARY) REGEXP BINARY :value)", 299 } 300 if comparator in templates: 301 column = f"{column.class_.__tablename__}.{column.key}" 302 return sa.text(templates[comparator].format(column=column)).bindparams( 303 sa.bindparam("value", value=value, unique=True) 304 ) 305 306 return comparison_func(column, value) 307 308 def sqlite_comparison_func(column, value): 309 if comparator == "RLIKE": 310 # SQLite requires a custom regexp function to be registered 311 # Use the built-in function if available 312 return column.op("REGEXP")(value) 313 return comparison_func(column, value) 314 315 def postgres_comparison_func(column, value): 316 if comparator == "RLIKE": 317 return column.op("~")(value) 318 return comparison_func(column, value) 319 320 return { 321 POSTGRES: postgres_comparison_func, 322 SQLITE: sqlite_comparison_func, 323 MSSQL: mssql_comparison_func, 324 MYSQL: mysql_comparison_func, 325 }[dialect] 326 327 @staticmethod 328 def translate_key_alias(key): 329 if key in ["created", "Created"]: 330 return "start_time" 331 if key in ["run name", "Run name", "Run Name"]: 332 return "run_name" 333 return key 334 335 @classmethod 336 def _trim_ends(cls, string_value): 337 return string_value[1:-1] 338 339 @classmethod 340 def _is_quoted(cls, value, pattern): 341 return len(value) >= 2 and value.startswith(pattern) and value.endswith(pattern) 342 343 @classmethod 344 def _trim_backticks(cls, entity_type): 345 """Remove backticks from identifier like `param`, if they exist.""" 346 if cls._is_quoted(entity_type, "`"): 347 return cls._trim_ends(entity_type) 348 return entity_type 349 350 @classmethod 351 def _strip_quotes(cls, value, expect_quoted_value=False): 352 """ 353 Remove quotes for input string. 354 Values of type strings are expected to have quotes. 355 Keys containing special characters are also expected to be enclose in quotes. 356 """ 357 if cls._is_quoted(value, "'") or cls._is_quoted(value, '"'): 358 return cls._trim_ends(value) 359 elif expect_quoted_value: 360 raise MlflowException( 361 "Parameter value is either not quoted or unidentified quote " 362 f"types used for string value {value}. Use either single or double " 363 "quotes.", 364 error_code=INVALID_PARAMETER_VALUE, 365 ) 366 else: 367 return value 368 369 @classmethod 370 def _valid_entity_type(cls, entity_type): 371 entity_type = cls._trim_backticks(entity_type) 372 if entity_type not in cls._VALID_IDENTIFIERS: 373 raise MlflowException( 374 f"Invalid entity type '{entity_type}'. Valid values are {cls._IDENTIFIERS}", 375 error_code=INVALID_PARAMETER_VALUE, 376 ) 377 378 if entity_type in cls._ALTERNATE_PARAM_IDENTIFIERS: 379 return cls._PARAM_IDENTIFIER 380 elif entity_type in cls._ALTERNATE_METRIC_IDENTIFIERS: 381 return cls._METRIC_IDENTIFIER 382 elif entity_type in cls._ALTERNATE_TAG_IDENTIFIERS: 383 return cls._TAG_IDENTIFIER 384 elif entity_type in cls._ALTERNATE_ATTRIBUTE_IDENTIFIERS: 385 return cls._ATTRIBUTE_IDENTIFIER 386 elif entity_type in cls._ALTERNATE_DATASET_IDENTIFIERS: 387 return cls._DATASET_IDENTIFIER 388 else: 389 # one of ("metric", "parameter", "tag", or "attribute") since it a valid type 390 return entity_type 391 392 @classmethod 393 def _get_identifier(cls, identifier, valid_attributes): 394 try: 395 tokens = identifier.split(".", 1) 396 if len(tokens) == 1: 397 key = tokens[0] 398 entity_type = cls._ATTRIBUTE_IDENTIFIER 399 else: 400 entity_type, key = tokens 401 except ValueError: 402 raise MlflowException( 403 f"Invalid identifier {identifier!r}. Columns should be specified as " 404 "'attribute.<key>', 'metric.<key>', 'tag.<key>', 'dataset.<key>', or " 405 "'param.'.", 406 error_code=INVALID_PARAMETER_VALUE, 407 ) 408 identifier = cls._valid_entity_type(entity_type) 409 key = cls._trim_backticks(cls._strip_quotes(key)) 410 if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes: 411 raise MlflowException.invalid_parameter_value( 412 f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'" 413 ) 414 elif identifier == cls._DATASET_IDENTIFIER and key not in cls.DATASET_ATTRIBUTES: 415 raise MlflowException.invalid_parameter_value( 416 f"Invalid dataset key '{key}' specified. Valid keys are '{cls.DATASET_ATTRIBUTES}'" 417 ) 418 return {"type": identifier, "key": key} 419 420 @classmethod 421 def validate_list_supported(cls, key: str) -> None: 422 if key != "run_id": 423 raise MlflowException( 424 "Only the 'run_id' attribute supports comparison with a list of quoted " 425 "string values.", 426 error_code=INVALID_PARAMETER_VALUE, 427 ) 428 429 @classmethod 430 def _get_value(cls, identifier_type, key, token): 431 if identifier_type == cls._METRIC_IDENTIFIER: 432 if token.ttype not in cls.NUMERIC_VALUE_TYPES: 433 raise MlflowException( 434 f"Expected numeric value type for metric. Found {token.value}", 435 error_code=INVALID_PARAMETER_VALUE, 436 ) 437 return token.value 438 elif identifier_type in (cls._PARAM_IDENTIFIER, cls._TAG_IDENTIFIER): 439 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 440 return cls._strip_quotes(token.value, expect_quoted_value=True) 441 raise MlflowException( 442 "Expected a quoted string value for " 443 f"{identifier_type} (e.g. 'my-value'). Got value " 444 f"{token.value}", 445 error_code=INVALID_PARAMETER_VALUE, 446 ) 447 elif identifier_type == cls._ATTRIBUTE_IDENTIFIER: 448 if key in cls.NUMERIC_ATTRIBUTES: 449 if token.ttype not in cls.NUMERIC_VALUE_TYPES: 450 raise MlflowException( 451 f"Expected numeric value type for numeric attribute: {key}. " 452 f"Found {token.value}", 453 error_code=INVALID_PARAMETER_VALUE, 454 ) 455 return token.value 456 elif token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 457 return cls._strip_quotes(token.value, expect_quoted_value=True) 458 elif isinstance(token, Parenthesis): 459 cls.validate_list_supported(key) 460 return cls._parse_run_ids(token) 461 else: 462 raise MlflowException( 463 f"Expected a quoted string value for attributes. Got value {token.value}", 464 error_code=INVALID_PARAMETER_VALUE, 465 ) 466 elif identifier_type == cls._DATASET_IDENTIFIER: 467 if key in cls.DATASET_ATTRIBUTES and ( 468 token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier) 469 ): 470 return cls._strip_quotes(token.value, expect_quoted_value=True) 471 elif isinstance(token, Parenthesis): 472 if key not in ("name", "digest", "context"): 473 raise MlflowException( 474 "Only the dataset 'name' and 'digest' supports comparison with a list of " 475 "quoted string values.", 476 error_code=INVALID_PARAMETER_VALUE, 477 ) 478 return cls._parse_run_ids(token) 479 else: 480 raise MlflowException( 481 "Expected a quoted string value for dataset attributes. " 482 f"Got value {token.value}", 483 error_code=INVALID_PARAMETER_VALUE, 484 ) 485 else: 486 # Expected to be either "param" or "metric". 487 raise MlflowException( 488 "Invalid identifier type. Expected one of " 489 f"{[cls._METRIC_IDENTIFIER, cls._PARAM_IDENTIFIER]}." 490 ) 491 492 @classmethod 493 def _validate_comparison(cls, tokens): 494 base_error_string = "Invalid comparison clause" 495 if len(tokens) == 2: 496 comparator = tokens[1].value.upper() 497 if comparator in ("IS NULL", "IS NOT NULL"): 498 if not isinstance(tokens[0], Identifier): 499 raise MlflowException( 500 f"{base_error_string}. Expected 'Identifier' found '{tokens[0]}'", 501 error_code=INVALID_PARAMETER_VALUE, 502 ) 503 return 504 if len(tokens) != 3: 505 raise MlflowException( 506 f"{base_error_string}. Expected 3 tokens found {len(tokens)}", 507 error_code=INVALID_PARAMETER_VALUE, 508 ) 509 if not isinstance(tokens[0], Identifier): 510 raise MlflowException( 511 f"{base_error_string}. Expected 'Identifier' found '{tokens[0]}'", 512 error_code=INVALID_PARAMETER_VALUE, 513 ) 514 if not isinstance(tokens[1], Token) and tokens[1].ttype != TokenType.Operator.Comparison: 515 raise MlflowException( 516 f"{base_error_string}. Expected comparison found '{tokens[1]}'", 517 error_code=INVALID_PARAMETER_VALUE, 518 ) 519 if not isinstance(tokens[2], Token) and ( 520 tokens[2].ttype not in cls.STRING_VALUE_TYPES.union(cls.NUMERIC_VALUE_TYPES) 521 or isinstance(tokens[2], Identifier) 522 ): 523 raise MlflowException( 524 f"{base_error_string}. Expected value token found '{tokens[2]}'", 525 error_code=INVALID_PARAMETER_VALUE, 526 ) 527 528 @classmethod 529 def _get_comparison(cls, comparison): 530 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 531 cls._validate_comparison(stripped_comparison) 532 533 # Handle IS NULL / IS NOT NULL (2 tokens: identifier + comparator, no value) 534 if len(stripped_comparison) == 2: 535 comparator = stripped_comparison[1].value.upper() 536 comp = cls._get_identifier( 537 stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS 538 ) 539 if comp["type"] not in (cls._TAG_IDENTIFIER, cls._PARAM_IDENTIFIER): 540 raise MlflowException( 541 "IS NULL / IS NOT NULL is only supported for tags and params, " 542 f"not for '{comp['type']}' '{comp['key']}'", 543 error_code=INVALID_PARAMETER_VALUE, 544 ) 545 comp["comparator"] = comparator 546 comp["value"] = None 547 return comp 548 549 comp = cls._get_identifier(stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS) 550 comp["comparator"] = stripped_comparison[1].value 551 comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), stripped_comparison[2]) 552 return comp 553 554 @classmethod 555 def _invalid_statement_token_search_runs(cls, token): 556 if ( 557 isinstance(token, Comparison) 558 or token.is_whitespace 559 or token.match(ttype=TokenType.Keyword, values=["AND"]) 560 ): 561 return False 562 return True 563 564 @classmethod 565 def _process_statement(cls, statement): 566 # check validity 567 tokens = _join_in_comparison_tokens(statement.tokens) 568 invalids = list(filter(cls._invalid_statement_token_search_runs, tokens)) 569 if len(invalids) > 0: 570 invalid_clauses = ", ".join(f"'{token}'" for token in invalids) 571 raise MlflowException( 572 f"Invalid clause(s) in filter string: {invalid_clauses}", 573 error_code=INVALID_PARAMETER_VALUE, 574 ) 575 return [cls._get_comparison(si) for si in tokens if isinstance(si, Comparison)] 576 577 @classmethod 578 def parse_search_filter(cls, filter_string): 579 if not filter_string: 580 return [] 581 try: 582 parsed = sqlparse.parse(filter_string) 583 except Exception: 584 raise MlflowException( 585 f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE 586 ) 587 if len(parsed) == 0 or not isinstance(parsed[0], Statement): 588 raise MlflowException( 589 f"Invalid filter '{filter_string}'. Could not be parsed.", 590 error_code=INVALID_PARAMETER_VALUE, 591 ) 592 elif len(parsed) > 1: 593 raise MlflowException( 594 f"Search filter contained multiple expression {filter_string!r}. " 595 "Provide AND-ed expression list.", 596 error_code=INVALID_PARAMETER_VALUE, 597 ) 598 return cls._process_statement(parsed[0]) 599 600 @classmethod 601 def is_metric(cls, key_type, comparator): 602 if key_type == cls._METRIC_IDENTIFIER: 603 if comparator not in cls.VALID_METRIC_COMPARATORS: 604 raise MlflowException( 605 f"Invalid comparator '{comparator}' not one of '{cls.VALID_METRIC_COMPARATORS}", 606 error_code=INVALID_PARAMETER_VALUE, 607 ) 608 return True 609 return False 610 611 @classmethod 612 def is_param(cls, key_type, comparator): 613 if key_type == cls._PARAM_IDENTIFIER: 614 if comparator not in cls.VALID_PARAM_COMPARATORS: 615 raise MlflowException( 616 f"Invalid comparator '{comparator}' not one of '{cls.VALID_PARAM_COMPARATORS}'", 617 error_code=INVALID_PARAMETER_VALUE, 618 ) 619 return True 620 return False 621 622 @classmethod 623 def is_tag(cls, key_type, comparator): 624 if key_type == cls._TAG_IDENTIFIER: 625 if comparator not in cls.VALID_TAG_COMPARATORS: 626 raise MlflowException( 627 f"Invalid comparator '{comparator}' not one of '{cls.VALID_TAG_COMPARATORS}", 628 error_code=INVALID_PARAMETER_VALUE, 629 ) 630 return True 631 return False 632 633 @classmethod 634 def is_attribute(cls, key_type, key_name, comparator): 635 return cls.is_string_attribute(key_type, key_name, comparator) or cls.is_numeric_attribute( 636 key_type, key_name, comparator 637 ) 638 639 @classmethod 640 def is_string_attribute(cls, key_type, key_name, comparator): 641 if key_type == cls._ATTRIBUTE_IDENTIFIER and key_name not in cls.NUMERIC_ATTRIBUTES: 642 if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS: 643 raise MlflowException( 644 f"Invalid comparator '{comparator}' not one of " 645 f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'", 646 error_code=INVALID_PARAMETER_VALUE, 647 ) 648 return True 649 return False 650 651 @classmethod 652 def is_numeric_attribute(cls, key_type, key_name, comparator): 653 if key_type == cls._ATTRIBUTE_IDENTIFIER and key_name in cls.NUMERIC_ATTRIBUTES: 654 if comparator not in cls.VALID_NUMERIC_ATTRIBUTE_COMPARATORS: 655 raise MlflowException( 656 f"Invalid comparator '{comparator}' not one of " 657 f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}", 658 error_code=INVALID_PARAMETER_VALUE, 659 ) 660 return True 661 return False 662 663 @classmethod 664 def is_dataset(cls, key_type, comparator): 665 if key_type == cls._DATASET_IDENTIFIER: 666 if comparator not in cls.VALID_DATASET_COMPARATORS: 667 raise MlflowException( 668 f"Invalid comparator '{comparator}' " 669 f"not one of '{cls.VALID_DATASET_COMPARATORS}", 670 error_code=INVALID_PARAMETER_VALUE, 671 ) 672 return True 673 return False 674 675 @classmethod 676 def _is_metric_on_dataset(cls, metric: Metric, dataset: dict[str, Any]) -> bool: 677 return metric.dataset_name == dataset.get("dataset_name") and ( 678 dataset.get("dataset_digest") is None 679 or dataset.get("dataset_digest") == metric.dataset_digest 680 ) 681 682 @classmethod 683 def _does_run_match_clause(cls, run, sed): 684 key_type = sed.get("type") 685 key = sed.get("key") 686 value = sed.get("value") 687 comparator = sed.get("comparator").upper() 688 689 key = SearchUtils.translate_key_alias(key) 690 691 if cls.is_metric(key_type, comparator): 692 lhs = run.data.metrics.get(key, None) 693 value = float(value) 694 elif cls.is_param(key_type, comparator): 695 lhs = run.data.params.get(key, None) 696 elif cls.is_tag(key_type, comparator): 697 lhs = run.data.tags.get(key, None) 698 elif cls.is_string_attribute(key_type, key, comparator): 699 lhs = getattr(run.info, key) 700 elif cls.is_numeric_attribute(key_type, key, comparator): 701 lhs = getattr(run.info, key) 702 value = int(value) 703 elif cls.is_dataset(key_type, comparator): 704 if key == "context": 705 return any( 706 SearchUtils.get_comparison_func(comparator)(tag.value if tag else None, value) 707 for dataset_input in run.inputs.dataset_inputs 708 for tag in dataset_input.tags 709 if tag.key == MLFLOW_DATASET_CONTEXT 710 ) 711 else: 712 return any( 713 SearchUtils.get_comparison_func(comparator)( 714 getattr(dataset_input.dataset, key), value 715 ) 716 for dataset_input in run.inputs.dataset_inputs 717 ) 718 else: 719 raise MlflowException( 720 f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE 721 ) 722 if comparator in ("IS NULL", "IS NOT NULL"): 723 return (lhs is None) if comparator == "IS NULL" else (lhs is not None) 724 725 if lhs is None: 726 return False 727 728 return SearchUtils.get_comparison_func(comparator)(lhs, value) 729 730 @classmethod 731 def _does_model_match_clause(cls, model, sed): 732 key_type = sed.get("type") 733 key = sed.get("key") 734 value = sed.get("value") 735 comparator = sed.get("comparator").upper() 736 737 key = SearchUtils.translate_key_alias(key) 738 739 if cls.is_metric(key_type, comparator): 740 matching_metrics = [metric for metric in model.metrics if metric.key == key] 741 lhs = matching_metrics[0].value if matching_metrics else None 742 value = float(value) 743 elif cls.is_param(key_type, comparator): 744 lhs = model.params.get(key, None) 745 elif cls.is_tag(key_type, comparator): 746 lhs = model.tags.get(key, None) 747 elif cls.is_string_attribute(key_type, key, comparator): 748 lhs = getattr(model.info, key) 749 elif cls.is_numeric_attribute(key_type, key, comparator): 750 lhs = getattr(model.info, key) 751 value = int(value) 752 else: 753 raise MlflowException( 754 f"Invalid model search expression type '{key_type}'", 755 error_code=INVALID_PARAMETER_VALUE, 756 ) 757 if lhs is None: 758 return False 759 760 return SearchUtils.get_comparison_func(comparator)(lhs, value) 761 762 @classmethod 763 def filter(cls, runs, filter_string): 764 """Filters a set of runs based on a search filter string.""" 765 if not filter_string: 766 return runs 767 parsed = cls.parse_search_filter(filter_string) 768 769 def run_matches(run): 770 return all(cls._does_run_match_clause(run, s) for s in parsed) 771 772 return [run for run in runs if run_matches(run)] 773 774 @classmethod 775 def _validate_order_by_and_generate_token(cls, order_by): 776 try: 777 parsed = sqlparse.parse(order_by) 778 except Exception: 779 raise MlflowException( 780 f"Error on parsing order_by clause '{order_by}'", 781 error_code=INVALID_PARAMETER_VALUE, 782 ) 783 if len(parsed) != 1 or not isinstance(parsed[0], Statement): 784 raise MlflowException( 785 f"Invalid order_by clause '{order_by}'. Could not be parsed.", 786 error_code=INVALID_PARAMETER_VALUE, 787 ) 788 statement = parsed[0] 789 ttype_for_timestamp = ( 790 TokenType.Name.Builtin 791 if Version(sqlparse.__version__) >= Version("0.4.3") 792 else TokenType.Keyword 793 ) 794 795 if len(statement.tokens) == 1 and isinstance(statement[0], Identifier): 796 token_value = statement.tokens[0].value 797 elif len(statement.tokens) == 1 and statement.tokens[0].match( 798 ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP] 799 ): 800 token_value = cls.ORDER_BY_KEY_TIMESTAMP 801 elif ( 802 statement.tokens[0].match( 803 ttype=ttype_for_timestamp, values=[cls.ORDER_BY_KEY_TIMESTAMP] 804 ) 805 and all(token.is_whitespace for token in statement.tokens[1:-1]) 806 and statement.tokens[-1].ttype == TokenType.Keyword.Order 807 ): 808 token_value = cls.ORDER_BY_KEY_TIMESTAMP + " " + statement.tokens[-1].value 809 else: 810 raise MlflowException( 811 f"Invalid order_by clause '{order_by}'. Could not be parsed.", 812 error_code=INVALID_PARAMETER_VALUE, 813 ) 814 return token_value 815 816 @classmethod 817 def _parse_order_by_string(cls, order_by): 818 token_value = cls._validate_order_by_and_generate_token(order_by) 819 is_ascending = True 820 tokens = shlex.split(token_value.replace("`", '"')) 821 if len(tokens) > 2: 822 raise MlflowException( 823 f"Invalid order_by clause '{order_by}'. Could not be parsed.", 824 error_code=INVALID_PARAMETER_VALUE, 825 ) 826 elif len(tokens) == 2: 827 order_token = tokens[1].lower() 828 if order_token not in cls.VALID_ORDER_BY_TAGS: 829 raise MlflowException( 830 f"Invalid ordering key in order_by clause '{order_by}'.", 831 error_code=INVALID_PARAMETER_VALUE, 832 ) 833 is_ascending = order_token == cls.ASC_OPERATOR 834 token_value = tokens[0] 835 return token_value, is_ascending 836 837 @classmethod 838 def parse_order_by_for_search_runs(cls, order_by): 839 token_value, is_ascending = cls._parse_order_by_string(order_by) 840 identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS) 841 return identifier["type"], identifier["key"], is_ascending 842 843 @classmethod 844 def parse_order_by_for_search_registered_models(cls, order_by): 845 token_value, is_ascending = cls._parse_order_by_string(order_by) 846 token_value = token_value.strip() 847 if token_value not in cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS: 848 raise MlflowException( 849 f"Invalid order by key '{token_value}' specified. Valid keys " 850 f"are '{cls.RECOMMENDED_ORDER_BY_KEYS_REGISTERED_MODELS}'", 851 error_code=INVALID_PARAMETER_VALUE, 852 ) 853 return token_value, is_ascending 854 855 @classmethod 856 def _get_value_for_sort(cls, run, key_type, key, ascending): 857 """Returns a tuple suitable to be used as a sort key for runs.""" 858 sort_value = None 859 key = SearchUtils.translate_key_alias(key) 860 if key_type == cls._METRIC_IDENTIFIER: 861 sort_value = run.data.metrics.get(key) 862 elif key_type == cls._PARAM_IDENTIFIER: 863 sort_value = run.data.params.get(key) 864 elif key_type == cls._TAG_IDENTIFIER: 865 sort_value = run.data.tags.get(key) 866 elif key_type == cls._ATTRIBUTE_IDENTIFIER: 867 sort_value = getattr(run.info, key) 868 else: 869 raise MlflowException( 870 f"Invalid order_by entity type '{key_type}'", error_code=INVALID_PARAMETER_VALUE 871 ) 872 873 # Return a key such that None values are always at the end. 874 is_none = sort_value is None 875 is_nan = isinstance(sort_value, float) and math.isnan(sort_value) 876 fill_value = (1 if ascending else -1) * math.inf 877 878 if is_none: 879 sort_value = fill_value 880 elif is_nan: 881 sort_value = -fill_value 882 883 is_none_or_nan = is_none or is_nan 884 885 return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value) 886 887 @classmethod 888 def _get_model_value_for_sort(cls, model, key_type, key, ascending): 889 """Returns a tuple suitable to be used as a sort key for models.""" 890 sort_value = None 891 key = SearchUtils.translate_key_alias(key) 892 if key_type == cls._METRIC_IDENTIFIER: 893 matching_metrics = [metric for metric in model.metrics if metric.key == key] 894 sort_value = float(matching_metrics[0].value) if matching_metrics else None 895 elif key_type == cls._PARAM_IDENTIFIER: 896 sort_value = model.params.get(key) 897 elif key_type == cls._TAG_IDENTIFIER: 898 sort_value = model.tags.get(key) 899 elif key_type == cls._ATTRIBUTE_IDENTIFIER: 900 sort_value = getattr(model, key) 901 else: 902 raise MlflowException( 903 f"Invalid models order_by entity type '{key_type}'", 904 error_code=INVALID_PARAMETER_VALUE, 905 ) 906 907 # Return a key such that None values are always at the end. 908 is_none = sort_value is None 909 is_nan = isinstance(sort_value, float) and math.isnan(sort_value) 910 fill_value = (1 if ascending else -1) * math.inf 911 912 if is_none: 913 sort_value = fill_value 914 elif is_nan: 915 sort_value = -fill_value 916 917 is_none_or_nan = is_none or is_nan 918 919 return (is_none_or_nan, sort_value) if ascending else (not is_none_or_nan, sort_value) 920 921 @classmethod 922 def sort(cls, runs, order_by_list): 923 """Sorts a set of runs based on their natural ordering and an overriding set of order_bys. 924 Runs are naturally ordered first by start time descending, then by run id for tie-breaking. 925 """ 926 runs = sorted(runs, key=lambda run: (-run.info.start_time, run.info.run_id)) 927 if not order_by_list: 928 return runs 929 # NB: We rely on the stability of Python's sort function, so that we can apply 930 # the ordering conditions in reverse order. 931 for order_by_clause in reversed(order_by_list): 932 (key_type, key, ascending) = cls.parse_order_by_for_search_runs(order_by_clause) 933 934 runs = sorted( 935 runs, 936 key=lambda run: cls._get_value_for_sort(run, key_type, key, ascending), 937 reverse=not ascending, 938 ) 939 return runs 940 941 @classmethod 942 def parse_start_offset_from_page_token(cls, page_token): 943 # Note: the page_token is expected to be a base64-encoded JSON that looks like 944 # { "offset": xxx }. However, this format is not stable, so it should not be 945 # relied upon outside of this method. 946 if not page_token: 947 return 0 948 949 try: 950 decoded_token = base64.b64decode(page_token) 951 except TypeError: 952 raise MlflowException( 953 "Invalid page token, could not base64-decode", error_code=INVALID_PARAMETER_VALUE 954 ) 955 except base64.binascii.Error: 956 raise MlflowException( 957 "Invalid page token, could not base64-decode", error_code=INVALID_PARAMETER_VALUE 958 ) 959 960 try: 961 parsed_token = json.loads(decoded_token) 962 except ValueError: 963 raise MlflowException( 964 f"Invalid page token, decoded value={decoded_token}", 965 error_code=INVALID_PARAMETER_VALUE, 966 ) 967 968 offset_str = parsed_token.get("offset") 969 if not offset_str: 970 raise MlflowException( 971 f"Invalid page token, parsed value={parsed_token}", 972 error_code=INVALID_PARAMETER_VALUE, 973 ) 974 975 try: 976 offset = int(offset_str) 977 except ValueError: 978 raise MlflowException( 979 f"Invalid page token, not stringable {offset_str}", 980 error_code=INVALID_PARAMETER_VALUE, 981 ) 982 983 return offset 984 985 @classmethod 986 def create_page_token(cls, offset): 987 return base64.b64encode(json.dumps({"offset": offset}).encode("utf-8")) 988 989 @classmethod 990 def paginate(cls, runs, page_token, max_results): 991 """Paginates a set of runs based on an offset encoded into the page_token and a max 992 results limit. Returns a pair containing the set of paginated runs, followed by 993 an optional next_page_token if there are further results that need to be returned. 994 """ 995 start_offset = cls.parse_start_offset_from_page_token(page_token) 996 final_offset = start_offset + max_results 997 998 paginated_runs = runs[start_offset:final_offset] 999 next_page_token = None 1000 if final_offset < len(runs): 1001 next_page_token = cls.create_page_token(final_offset) 1002 return (paginated_runs, next_page_token) 1003 1004 # Model Registry specific parser 1005 # TODO: Tech debt. Refactor search code into common utils, tracking server, and model 1006 # registry specific code. 1007 1008 VALID_SEARCH_KEYS_FOR_MODEL_VERSIONS = {"name", "run_id", "source_path"} 1009 VALID_SEARCH_KEYS_FOR_REGISTERED_MODELS = {"name"} 1010 1011 @classmethod 1012 def _check_valid_identifier_list(cls, tup: tuple[Any, ...]) -> None: 1013 """ 1014 Validate that `tup` is a non-empty tuple of strings. 1015 """ 1016 if len(tup) == 0: 1017 raise MlflowException( 1018 "While parsing a list in the query," 1019 " expected a non-empty list of string values, but got empty list", 1020 error_code=INVALID_PARAMETER_VALUE, 1021 ) 1022 1023 if not all(isinstance(x, str) for x in tup): 1024 raise MlflowException( 1025 "While parsing a list in the query, expected string value, punctuation, " 1026 f"or whitespace, but got different type in list: {tup}", 1027 error_code=INVALID_PARAMETER_VALUE, 1028 ) 1029 1030 @classmethod 1031 def _parse_list_from_sql_token(cls, token): 1032 try: 1033 parsed = ast.literal_eval(token.value) 1034 except SyntaxError as e: 1035 raise MlflowException( 1036 "While parsing a list in the query," 1037 " expected a non-empty list of string values, but got ill-formed list.", 1038 error_code=INVALID_PARAMETER_VALUE, 1039 ) from e 1040 1041 parsed = parsed if isinstance(parsed, tuple) else (parsed,) 1042 cls._check_valid_identifier_list(parsed) 1043 return parsed 1044 1045 @classmethod 1046 def _parse_run_ids(cls, token): 1047 run_id_list = cls._parse_list_from_sql_token(token) 1048 # Because MySQL IN clause is case-insensitive, but all run_ids only contain lower 1049 # case letters, so that we filter out run_ids containing upper case letters here. 1050 return [run_id for run_id in run_id_list if run_id.islower()] 1051 1052 1053 class SearchExperimentsUtils(SearchUtils): 1054 VALID_SEARCH_ATTRIBUTE_KEYS = {"name", "creation_time", "last_update_time"} 1055 VALID_ORDER_BY_ATTRIBUTE_KEYS = {"name", "experiment_id", "creation_time", "last_update_time"} 1056 NUMERIC_ATTRIBUTES = {"creation_time", "last_update_time"} 1057 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"} 1058 1059 @classmethod 1060 def _invalid_statement_token_search_experiments(cls, token): 1061 if ( 1062 isinstance(token, Comparison) 1063 or token.is_whitespace 1064 or token.match(ttype=TokenType.Keyword, values=["AND"]) 1065 ): 1066 return False 1067 return True 1068 1069 @classmethod 1070 def _process_statement(cls, statement): 1071 tokens = _join_in_comparison_tokens(statement.tokens) 1072 invalids = list(filter(cls._invalid_statement_token_search_experiments, tokens)) 1073 if len(invalids) > 0: 1074 invalid_clauses = ", ".join(map(str, invalids)) 1075 raise MlflowException.invalid_parameter_value( 1076 f"Invalid clause(s) in filter string: {invalid_clauses}" 1077 ) 1078 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 1079 1080 @classmethod 1081 def _get_identifier(cls, identifier, valid_attributes): 1082 tokens = identifier.split(".", maxsplit=1) 1083 if len(tokens) == 1: 1084 key = tokens[0] 1085 identifier = cls._ATTRIBUTE_IDENTIFIER 1086 else: 1087 entity_type, key = tokens 1088 valid_entity_types = ("attribute", "tag", "tags") 1089 if entity_type not in valid_entity_types: 1090 raise MlflowException.invalid_parameter_value( 1091 f"Invalid entity type '{entity_type}'. " 1092 f"Valid entity types are {valid_entity_types}" 1093 ) 1094 identifier = cls._valid_entity_type(entity_type) 1095 1096 key = cls._trim_backticks(cls._strip_quotes(key)) 1097 if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes: 1098 raise MlflowException.invalid_parameter_value( 1099 f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'" 1100 ) 1101 return {"type": identifier, "key": key} 1102 1103 @classmethod 1104 def _validate_comparison(cls, tokens): 1105 # Allow 2-token IS NULL / IS NOT NULL comparisons for tags 1106 if len(tokens) == 2: 1107 comparator = tokens[1].value.upper() 1108 if comparator in ("IS NULL", "IS NOT NULL"): 1109 if not isinstance(tokens[0], Identifier): 1110 raise MlflowException( 1111 f"Invalid comparison clause. Expected 'Identifier' found '{tokens[0]}'", 1112 error_code=INVALID_PARAMETER_VALUE, 1113 ) 1114 return 1115 super()._validate_comparison(tokens) 1116 1117 @classmethod 1118 def _get_comparison(cls, comparison): 1119 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 1120 cls._validate_comparison(stripped_comparison) 1121 1122 # Handle IS NULL / IS NOT NULL (2 tokens: identifier + comparator, no value) 1123 if len(stripped_comparison) == 2: 1124 comparator = stripped_comparison[1].value.upper() 1125 comp = cls._get_identifier( 1126 stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS 1127 ) 1128 if comp["type"] != cls._TAG_IDENTIFIER: 1129 raise MlflowException.invalid_parameter_value( 1130 f"IS NULL / IS NOT NULL is only supported for tags, " 1131 f"not for attribute '{comp['key']}'" 1132 ) 1133 comp["comparator"] = comparator 1134 comp["value"] = None 1135 return comp 1136 1137 left, comparator, right = stripped_comparison 1138 comp = cls._get_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS) 1139 comp["comparator"] = comparator.value 1140 comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right) 1141 return comp 1142 1143 @classmethod 1144 def parse_order_by_for_search_experiments(cls, order_by): 1145 token_value, is_ascending = cls._parse_order_by_string(order_by) 1146 identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS) 1147 return identifier["type"], identifier["key"], is_ascending 1148 1149 @classmethod 1150 def is_attribute(cls, key_type, comparator): 1151 if key_type == cls._ATTRIBUTE_IDENTIFIER: 1152 if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS: 1153 raise MlflowException( 1154 f"Invalid comparator '{comparator}' not one of " 1155 f"'{cls.VALID_STRING_ATTRIBUTE_COMPARATORS}'" 1156 ) 1157 return True 1158 return False 1159 1160 @classmethod 1161 def _does_experiment_match_clause(cls, experiment, sed): 1162 key_type = sed.get("type") 1163 key = sed.get("key") 1164 value = sed.get("value") 1165 comparator = sed.get("comparator").upper() 1166 1167 if cls.is_string_attribute(key_type, key, comparator): 1168 lhs = getattr(experiment, key) 1169 elif cls.is_numeric_attribute(key_type, key, comparator): 1170 lhs = getattr(experiment, key) 1171 value = float(value) 1172 elif cls.is_tag(key_type, comparator): 1173 if comparator == "IS NULL": 1174 return key not in experiment.tags 1175 elif comparator == "IS NOT NULL": 1176 return key in experiment.tags 1177 if key not in experiment.tags: 1178 return False 1179 lhs = experiment.tags.get(key, None) 1180 if lhs is None: 1181 return experiment 1182 else: 1183 raise MlflowException( 1184 f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE 1185 ) 1186 1187 return SearchUtils.get_comparison_func(comparator)(lhs, value) 1188 1189 @classmethod 1190 def filter(cls, experiments, filter_string): 1191 if not filter_string: 1192 return experiments 1193 parsed = cls.parse_search_filter(filter_string) 1194 1195 def experiment_matches(experiment): 1196 return all(cls._does_experiment_match_clause(experiment, s) for s in parsed) 1197 1198 return list(filter(experiment_matches, experiments)) 1199 1200 @classmethod 1201 def _get_sort_key(cls, order_by_list): 1202 order_by = [] 1203 parsed_order_by = map(cls.parse_order_by_for_search_experiments, order_by_list) 1204 for type_, key, ascending in parsed_order_by: 1205 if type_ == "attribute": 1206 order_by.append((key, ascending)) 1207 else: 1208 raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}") 1209 1210 # Add a tie-breaker 1211 if not any(key == "experiment_id" for key, _ in order_by): 1212 order_by.append(("experiment_id", False)) 1213 1214 # https://stackoverflow.com/a/56842689 1215 class _Sorter: 1216 def __init__(self, obj, ascending): 1217 self.obj = obj 1218 self.ascending = ascending 1219 1220 # Only need < and == are needed for use as a key parameter in the sorted function 1221 def __eq__(self, other): 1222 return other.obj == self.obj 1223 1224 def __lt__(self, other): 1225 if self.obj is None: 1226 return False 1227 elif other.obj is None: 1228 return True 1229 elif self.ascending: 1230 return self.obj < other.obj 1231 else: 1232 return other.obj < self.obj 1233 1234 def _apply_sorter(experiment, key, ascending): 1235 attr = getattr(experiment, key) 1236 return _Sorter(attr, ascending) 1237 1238 return lambda experiment: tuple(_apply_sorter(experiment, k, asc) for (k, asc) in order_by) 1239 1240 @classmethod 1241 def sort(cls, experiments, order_by_list): 1242 return sorted(experiments, key=cls._get_sort_key(order_by_list)) 1243 1244 1245 # https://stackoverflow.com/a/56842689 1246 class _Reversor: 1247 def __init__(self, obj): 1248 self.obj = obj 1249 1250 # Only need < and == are needed for use as a key parameter in the sorted function 1251 def __eq__(self, other): 1252 return other.obj == self.obj 1253 1254 def __lt__(self, other): 1255 if self.obj is None: 1256 return False 1257 if other.obj is None: 1258 return True 1259 return other.obj < self.obj 1260 1261 1262 def _apply_reversor(model, key, ascending): 1263 attr = getattr(model, key) 1264 return attr if ascending else _Reversor(attr) 1265 1266 1267 class SearchModelUtils(SearchUtils): 1268 NUMERIC_ATTRIBUTES = {"creation_timestamp", "last_updated_timestamp"} 1269 VALID_SEARCH_ATTRIBUTE_KEYS = {"name"} 1270 VALID_ORDER_BY_KEYS_REGISTERED_MODELS = {"name", "creation_timestamp", "last_updated_timestamp"} 1271 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE"} 1272 1273 @classmethod 1274 def _does_registered_model_match_clauses(cls, model, sed): 1275 key_type = sed.get("type") 1276 key = sed.get("key") 1277 value = sed.get("value") 1278 comparator = sed.get("comparator").upper() 1279 1280 # what comparators do we support here? 1281 if cls.is_string_attribute(key_type, key, comparator): 1282 lhs = getattr(model, key) 1283 elif cls.is_numeric_attribute(key_type, key, comparator): 1284 lhs = getattr(model, key) 1285 value = int(value) 1286 elif cls.is_tag(key_type, comparator): 1287 # NB: We should use the private attribute `_tags` instead of the `tags` property 1288 # to consider all tags including reserved ones. 1289 lhs = model._tags.get(key, None) 1290 else: 1291 raise MlflowException( 1292 f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE 1293 ) 1294 1295 # NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for 1296 # distinguishing between prompt models and normal models. For example, we want to 1297 # search for models only by the following filter string: 1298 # 1299 # tags.`mlflow.prompt.is_prompt` != 'true' 1300 # tags.`mlflow.prompt.is_prompt` = 'false' 1301 # 1302 # However, models do not have this tag, so lhs is None in this case. Instead of returning 1303 # False like normal tag filter, we need to return True here. 1304 if key == IS_PROMPT_TAG_KEY and lhs is None: 1305 return (comparator == "=" and value == "false") or ( 1306 comparator == "!=" and value == "true" 1307 ) 1308 1309 if lhs is None: 1310 return False 1311 1312 return SearchUtils.get_comparison_func(comparator)(lhs, value) 1313 1314 @classmethod 1315 def filter(cls, registered_models, filter_string): 1316 """Filters a set of registered models based on a search filter string.""" 1317 if not filter_string: 1318 return registered_models 1319 parsed = cls.parse_search_filter(filter_string) 1320 1321 def registered_model_matches(model): 1322 return all(cls._does_registered_model_match_clauses(model, s) for s in parsed) 1323 1324 return [ 1325 registered_model 1326 for registered_model in registered_models 1327 if registered_model_matches(registered_model) 1328 ] 1329 1330 @classmethod 1331 def parse_order_by_for_search_registered_models(cls, order_by): 1332 token_value, is_ascending = cls._parse_order_by_string(order_by) 1333 identifier = SearchExperimentsUtils._get_identifier( 1334 token_value.strip(), cls.VALID_ORDER_BY_KEYS_REGISTERED_MODELS 1335 ) 1336 return identifier["type"], identifier["key"], is_ascending 1337 1338 @classmethod 1339 def _get_sort_key(cls, order_by_list): 1340 order_by = [] 1341 parsed_order_by = map(cls.parse_order_by_for_search_registered_models, order_by_list or []) 1342 for type_, key, ascending in parsed_order_by: 1343 if type_ == "attribute": 1344 order_by.append((key, ascending)) 1345 else: 1346 raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}") 1347 1348 # Add a tie-breaker 1349 if not any(key == "name" for key, _ in order_by): 1350 order_by.append(("name", True)) 1351 1352 return lambda model: tuple(_apply_reversor(model, k, asc) for (k, asc) in order_by) 1353 1354 @classmethod 1355 def sort(cls, models, order_by_list): 1356 return sorted(models, key=cls._get_sort_key(order_by_list)) 1357 1358 @classmethod 1359 def _process_statement(cls, statement): 1360 tokens = _join_in_comparison_tokens(statement.tokens) 1361 invalids = list(filter(cls._invalid_statement_token_search_model_registry, tokens)) 1362 if len(invalids) > 0: 1363 invalid_clauses = ", ".join(map(str, invalids)) 1364 raise MlflowException.invalid_parameter_value( 1365 f"Invalid clause(s) in filter string: {invalid_clauses}" 1366 ) 1367 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 1368 1369 @classmethod 1370 def _get_model_search_identifier(cls, identifier, valid_attributes): 1371 tokens = identifier.split(".", maxsplit=1) 1372 if len(tokens) == 1: 1373 key = tokens[0] 1374 identifier = cls._ATTRIBUTE_IDENTIFIER 1375 else: 1376 entity_type, key = tokens 1377 valid_entity_types = ("attribute", "tag", "tags") 1378 if entity_type not in valid_entity_types: 1379 raise MlflowException.invalid_parameter_value( 1380 f"Invalid entity type '{entity_type}'. " 1381 f"Valid entity types are {valid_entity_types}" 1382 ) 1383 identifier = ( 1384 cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER 1385 ) 1386 1387 if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes: 1388 raise MlflowException.invalid_parameter_value( 1389 f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'" 1390 ) 1391 1392 key = cls._trim_backticks(cls._strip_quotes(key)) 1393 return {"type": identifier, "key": key} 1394 1395 @classmethod 1396 def _get_comparison(cls, comparison): 1397 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 1398 cls._validate_comparison(stripped_comparison) 1399 left, comparator, right = stripped_comparison 1400 comp = cls._get_model_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS) 1401 comp["comparator"] = comparator.value.upper() 1402 comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right) 1403 return comp 1404 1405 @classmethod 1406 def _get_value(cls, identifier_type, key, token): 1407 if identifier_type == cls._TAG_IDENTIFIER: 1408 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 1409 return cls._strip_quotes(token.value, expect_quoted_value=True) 1410 raise MlflowException( 1411 "Expected a quoted string value for " 1412 f"{identifier_type} (e.g. 'my-value'). Got value " 1413 f"{token.value}", 1414 error_code=INVALID_PARAMETER_VALUE, 1415 ) 1416 elif identifier_type == cls._ATTRIBUTE_IDENTIFIER: 1417 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 1418 return cls._strip_quotes(token.value, expect_quoted_value=True) 1419 elif isinstance(token, Parenthesis): 1420 if key != "run_id": 1421 raise MlflowException( 1422 "Only the 'run_id' attribute supports comparison with a list of quoted " 1423 "string values.", 1424 error_code=INVALID_PARAMETER_VALUE, 1425 ) 1426 return cls._parse_run_ids(token) 1427 else: 1428 raise MlflowException( 1429 "Expected a quoted string value or a list of quoted string values for " 1430 f"attributes. Got value {token.value}", 1431 error_code=INVALID_PARAMETER_VALUE, 1432 ) 1433 else: 1434 # Expected to be either "param" or "metric". 1435 raise MlflowException( 1436 "Invalid identifier type. Expected one of " 1437 f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.", 1438 error_code=INVALID_PARAMETER_VALUE, 1439 ) 1440 1441 @classmethod 1442 def _invalid_statement_token_search_model_registry(cls, token): 1443 if ( 1444 isinstance(token, Comparison) 1445 or token.is_whitespace 1446 or token.match(ttype=TokenType.Keyword, values=["AND"]) 1447 ): 1448 return False 1449 return True 1450 1451 1452 class SearchModelVersionUtils(SearchUtils): 1453 NUMERIC_ATTRIBUTES = {"version_number", "creation_timestamp", "last_updated_timestamp"} 1454 VALID_SEARCH_ATTRIBUTE_KEYS = { 1455 "name", 1456 "version_number", 1457 "run_id", 1458 "source_path", 1459 } 1460 VALID_ORDER_BY_ATTRIBUTE_KEYS = { 1461 "name", 1462 "version_number", 1463 "creation_timestamp", 1464 "last_updated_timestamp", 1465 } 1466 VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "IN"} 1467 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE"} 1468 1469 @classmethod 1470 def _does_model_version_match_clauses(cls, mv, sed): 1471 key_type = sed.get("type") 1472 key = sed.get("key") 1473 value = sed.get("value") 1474 comparator = sed.get("comparator").upper() 1475 1476 if cls.is_string_attribute(key_type, key, comparator): 1477 lhs = getattr(mv, "source" if key == "source_path" else key) 1478 elif cls.is_numeric_attribute(key_type, key, comparator): 1479 if key == "version_number": 1480 key = "version" 1481 lhs = getattr(mv, key) 1482 value = int(value) 1483 elif cls.is_tag(key_type, comparator): 1484 lhs = mv.tags.get(key, None) 1485 else: 1486 raise MlflowException( 1487 f"Invalid search expression type '{key_type}'", error_code=INVALID_PARAMETER_VALUE 1488 ) 1489 1490 # NB: Handling the special `mlflow.prompt.is_prompt` tag. This tag is used for 1491 # distinguishing between prompt models and normal models. For example, we want to 1492 # search for models only by the following filter string: 1493 # 1494 # tags.`mlflow.prompt.is_prompt` != 'true' 1495 # tags.`mlflow.prompt.is_prompt` = 'false' 1496 # 1497 # However, models do not have this tag, so lhs is None in this case. Instead of returning 1498 # False like normal tag filter, we need to return True here. 1499 if key == IS_PROMPT_TAG_KEY and lhs is None: 1500 return (comparator == "=" and value == "false") or ( 1501 comparator == "!=" and value == "true" 1502 ) 1503 1504 if lhs is None: 1505 return False 1506 1507 if comparator == "IN" and isinstance(value, (set, list)): 1508 return lhs in set(value) 1509 1510 return SearchUtils.get_comparison_func(comparator)(lhs, value) 1511 1512 @classmethod 1513 def filter(cls, model_versions, filter_string): 1514 """Filters a set of model versions based on a search filter string.""" 1515 model_versions = [mv for mv in model_versions if mv.current_stage != STAGE_DELETED_INTERNAL] 1516 if not filter_string: 1517 return model_versions 1518 parsed = cls.parse_search_filter(filter_string) 1519 1520 def model_version_matches(mv): 1521 return all(cls._does_model_version_match_clauses(mv, s) for s in parsed) 1522 1523 return [mv for mv in model_versions if model_version_matches(mv)] 1524 1525 @classmethod 1526 def parse_order_by_for_search_model_versions(cls, order_by): 1527 token_value, is_ascending = cls._parse_order_by_string(order_by) 1528 identifier = SearchExperimentsUtils._get_identifier( 1529 token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS 1530 ) 1531 return identifier["type"], identifier["key"], is_ascending 1532 1533 @classmethod 1534 def _get_sort_key(cls, order_by_list): 1535 order_by = [] 1536 parsed_order_by = map(cls.parse_order_by_for_search_model_versions, order_by_list or []) 1537 for type_, key, ascending in parsed_order_by: 1538 if type_ == "attribute": 1539 # Need to add this mapping because version is a keyword in sql 1540 if key == "version_number": 1541 key = "version" 1542 order_by.append((key, ascending)) 1543 else: 1544 raise MlflowException.invalid_parameter_value(f"Invalid order_by entity: {type_}") 1545 1546 # Add a tie-breaker 1547 if not any(key == "name" for key, _ in order_by): 1548 order_by.append(("name", True)) 1549 if not any(key == "version_number" for key, _ in order_by): 1550 order_by.append(("version", False)) 1551 1552 return lambda model_version: tuple( 1553 _apply_reversor(model_version, k, asc) for (k, asc) in order_by 1554 ) 1555 1556 @classmethod 1557 def sort(cls, model_versions, order_by_list): 1558 return sorted(model_versions, key=cls._get_sort_key(order_by_list)) 1559 1560 @classmethod 1561 def _get_model_version_search_identifier(cls, identifier, valid_attributes): 1562 tokens = identifier.split(".", maxsplit=1) 1563 if len(tokens) == 1: 1564 key = tokens[0] 1565 identifier = cls._ATTRIBUTE_IDENTIFIER 1566 else: 1567 entity_type, key = tokens 1568 valid_entity_types = ("attribute", "tag", "tags") 1569 if entity_type not in valid_entity_types: 1570 raise MlflowException.invalid_parameter_value( 1571 f"Invalid entity type '{entity_type}'. " 1572 f"Valid entity types are {valid_entity_types}" 1573 ) 1574 identifier = ( 1575 cls._TAG_IDENTIFIER if entity_type in ("tag", "tags") else cls._ATTRIBUTE_IDENTIFIER 1576 ) 1577 1578 if identifier == cls._ATTRIBUTE_IDENTIFIER and key not in valid_attributes: 1579 raise MlflowException.invalid_parameter_value( 1580 f"Invalid attribute key '{key}' specified. Valid keys are '{valid_attributes}'" 1581 ) 1582 1583 key = cls._trim_backticks(cls._strip_quotes(key)) 1584 return {"type": identifier, "key": key} 1585 1586 @classmethod 1587 def _get_comparison(cls, comparison): 1588 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 1589 cls._validate_comparison(stripped_comparison) 1590 left, comparator, right = stripped_comparison 1591 comp = cls._get_model_version_search_identifier(left.value, cls.VALID_SEARCH_ATTRIBUTE_KEYS) 1592 comp["comparator"] = comparator.value.upper() 1593 comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), right) 1594 return comp 1595 1596 @classmethod 1597 def _get_value(cls, identifier_type, key, token): 1598 if identifier_type == cls._TAG_IDENTIFIER: 1599 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 1600 return cls._strip_quotes(token.value, expect_quoted_value=True) 1601 raise MlflowException( 1602 "Expected a quoted string value for " 1603 f"{identifier_type} (e.g. 'my-value'). Got value " 1604 f"{token.value}", 1605 error_code=INVALID_PARAMETER_VALUE, 1606 ) 1607 elif identifier_type == cls._ATTRIBUTE_IDENTIFIER: 1608 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 1609 return cls._strip_quotes(token.value, expect_quoted_value=True) 1610 elif isinstance(token, Parenthesis): 1611 if key != "run_id": 1612 raise MlflowException( 1613 "Only the 'run_id' attribute supports comparison with a list of quoted " 1614 "string values.", 1615 error_code=INVALID_PARAMETER_VALUE, 1616 ) 1617 return cls._parse_run_ids(token) 1618 elif token.ttype in cls.NUMERIC_VALUE_TYPES: 1619 if key not in cls.NUMERIC_ATTRIBUTES: 1620 raise MlflowException( 1621 f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with " 1622 "numeric values.", 1623 error_code=INVALID_PARAMETER_VALUE, 1624 ) 1625 if token.ttype == TokenType.Literal.Number.Integer: 1626 return int(token.value) 1627 elif token.ttype == TokenType.Literal.Number.Float: 1628 return float(token.value) 1629 else: 1630 raise MlflowException( 1631 "Expected a quoted string value or a list of quoted string values for " 1632 f"attributes. Got value {token.value}", 1633 error_code=INVALID_PARAMETER_VALUE, 1634 ) 1635 else: 1636 # Expected to be either "param" or "metric". 1637 raise MlflowException( 1638 "Invalid identifier type. Expected one of " 1639 f"{[cls._ATTRIBUTE_IDENTIFIER, cls._TAG_IDENTIFIER]}.", 1640 error_code=INVALID_PARAMETER_VALUE, 1641 ) 1642 1643 @classmethod 1644 def _process_statement(cls, statement): 1645 tokens = _join_in_comparison_tokens(statement.tokens) 1646 invalids = list(filter(cls._invalid_statement_token_search_model_version, tokens)) 1647 if len(invalids) > 0: 1648 invalid_clauses = ", ".join(map(str, invalids)) 1649 raise MlflowException.invalid_parameter_value( 1650 f"Invalid clause(s) in filter string: {invalid_clauses}" 1651 ) 1652 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 1653 1654 @classmethod 1655 def _invalid_statement_token_search_model_version(cls, token): 1656 if ( 1657 isinstance(token, Comparison) 1658 or token.is_whitespace 1659 or token.match(ttype=TokenType.Keyword, values=["AND"]) 1660 ): 1661 return False 1662 return True 1663 1664 @classmethod 1665 def parse_search_filter(cls, filter_string): 1666 if not filter_string: 1667 return [] 1668 try: 1669 parsed = sqlparse.parse(filter_string) 1670 except Exception: 1671 raise MlflowException( 1672 f"Error on parsing filter '{filter_string}'", error_code=INVALID_PARAMETER_VALUE 1673 ) 1674 if len(parsed) == 0 or not isinstance(parsed[0], Statement): 1675 raise MlflowException( 1676 f"Invalid filter '{filter_string}'. Could not be parsed.", 1677 error_code=INVALID_PARAMETER_VALUE, 1678 ) 1679 elif len(parsed) > 1: 1680 raise MlflowException( 1681 f"Search filter contained multiple expression {filter_string!r}. " 1682 "Provide AND-ed expression list.", 1683 error_code=INVALID_PARAMETER_VALUE, 1684 ) 1685 return cls._process_statement(parsed[0]) 1686 1687 1688 class SearchTraceUtils(SearchUtils): 1689 """ 1690 Utility class for searching traces. 1691 """ 1692 1693 VALID_SEARCH_ATTRIBUTE_KEYS = { 1694 "request_id", 1695 "timestamp", 1696 "timestamp_ms", 1697 "execution_time", 1698 "execution_time_ms", 1699 "end_time", 1700 "end_time_ms", 1701 "status", 1702 "client_request_id", 1703 # The following keys are mapped to tags or metadata 1704 "name", 1705 "run_id", 1706 "prompt", 1707 # The following key is mapped to span attributes 1708 "text", 1709 } 1710 VALID_ORDER_BY_ATTRIBUTE_KEYS = { 1711 "experiment_id", 1712 "timestamp", 1713 "timestamp_ms", 1714 "execution_time", 1715 "execution_time_ms", 1716 "end_time", 1717 "end_time_ms", 1718 "status", 1719 "request_id", 1720 # The following keys are mapped to tags or metadata 1721 "name", 1722 "run_id", 1723 } 1724 1725 NUMERIC_ATTRIBUTES = { 1726 "timestamp_ms", 1727 "timestamp", 1728 "execution_time_ms", 1729 "execution_time", 1730 "end_time_ms", 1731 "end_time", 1732 } 1733 1734 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "RLIKE", "IS NULL", "IS NOT NULL"} 1735 VALID_STRING_ATTRIBUTE_COMPARATORS = {"!=", "=", "IN", "NOT IN", "LIKE", "ILIKE", "RLIKE"} 1736 VALID_SPAN_ATTRIBUTE_COMPARATORS = {"!=", "=", "IN", "NOT IN", "LIKE", "ILIKE", "RLIKE"} 1737 VALID_METADATA_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "RLIKE", "IS NULL", "IS NOT NULL"} 1738 VALID_ASSESSMENT_COMPARATORS = {"!=", "=", "LIKE", "ILIKE", "RLIKE", "IS NULL", "IS NOT NULL"} 1739 1740 _REQUEST_METADATA_IDENTIFIER = "request_metadata" 1741 _TAG_IDENTIFIER = "tag" 1742 _ATTRIBUTE_IDENTIFIER = "attribute" 1743 _SPAN_IDENTIFIER = "span" 1744 _FEEDBACK_IDENTIFIER = "feedback" 1745 _EXPECTATION_IDENTIFIER = "expectation" 1746 _ISSUE_IDENTIFIER = "issue" 1747 1748 # These are aliases for the base identifiers 1749 # e.g. trace.status is equivalent to attribute.status 1750 _ALTERNATE_IDENTIFIERS = { 1751 "tags": _TAG_IDENTIFIER, 1752 "attributes": _ATTRIBUTE_IDENTIFIER, 1753 "trace": _ATTRIBUTE_IDENTIFIER, 1754 "metadata": _REQUEST_METADATA_IDENTIFIER, 1755 } 1756 _IDENTIFIERS = { 1757 _TAG_IDENTIFIER, 1758 _REQUEST_METADATA_IDENTIFIER, 1759 _ATTRIBUTE_IDENTIFIER, 1760 _SPAN_IDENTIFIER, 1761 _FEEDBACK_IDENTIFIER, 1762 _EXPECTATION_IDENTIFIER, 1763 _ISSUE_IDENTIFIER, 1764 } 1765 _VALID_IDENTIFIERS = _IDENTIFIERS | set(_ALTERNATE_IDENTIFIERS.keys()) 1766 1767 # Supported span attributes 1768 _SUPPORTED_SPAN_ATTRIBUTES = {"name", "type", "status"} 1769 _SPAN_CONTENT_KEY = "content" 1770 VALID_SPAN_CONTENT_COMPARATORS = {"LIKE", "ILIKE"} 1771 1772 # Supported issue attributes 1773 _SUPPORTED_ISSUE_ATTRIBUTES = {"id"} 1774 VALID_ISSUE_COMPARATORS = {"="} 1775 1776 SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS = { 1777 "name", 1778 "status", 1779 "request_id", 1780 "run_id", 1781 "client_request_id", 1782 } 1783 1784 # Some search keys are defined differently in the DB models. 1785 # E.g. "name" is mapped to TraceTagKey.TRACE_NAME 1786 SEARCH_KEY_TO_TAG = { 1787 "name": TraceTagKey.TRACE_NAME, 1788 "prompt": TraceTagKey.LINKED_PROMPTS, 1789 } 1790 SEARCH_KEY_TO_METADATA = { 1791 "run_id": TraceMetadataKey.SOURCE_RUN, 1792 } 1793 # Alias for attribute keys 1794 SEARCH_KEY_TO_ATTRIBUTE = { 1795 "timestamp": "timestamp_ms", 1796 "execution_time": "execution_time_ms", 1797 "end_time": "end_time_ms", 1798 } 1799 # Map trace search keys to span attributes for full text search 1800 SEARCH_KEY_TO_SPAN = { 1801 "text": _SPAN_CONTENT_KEY, 1802 } 1803 1804 @classmethod 1805 def filter(cls, traces, filter_string): 1806 """Filters a set of traces based on a search filter string.""" 1807 if not filter_string: 1808 return traces 1809 parsed = cls.parse_search_filter_for_search_traces(filter_string) 1810 1811 def trace_matches(trace): 1812 return all(cls._does_trace_match_clause(trace, s) for s in parsed) 1813 1814 return list(filter(trace_matches, traces)) 1815 1816 @classmethod 1817 def _does_trace_match_clause(cls, trace, sed): 1818 type_ = sed.get("type") 1819 key = sed.get("key") 1820 value = sed.get("value") 1821 comparator = sed.get("comparator").upper() 1822 1823 if cls.is_tag(type_, comparator): 1824 if comparator == "IS NULL": 1825 return key not in trace.tags 1826 elif comparator == "IS NOT NULL": 1827 return key in trace.tags 1828 lhs = trace.tags.get(key) 1829 elif cls.is_request_metadata(type_, comparator): 1830 if comparator == "IS NULL": 1831 return key not in trace.request_metadata 1832 elif comparator == "IS NOT NULL": 1833 return key in trace.request_metadata 1834 lhs = trace.request_metadata.get(key) 1835 elif cls.is_attribute(type_, key, comparator): 1836 lhs = getattr(trace, key) 1837 elif cls.is_span(type_, key, comparator): 1838 raise MlflowException( 1839 "Span filtering requires database support and cannot be performed " 1840 "on in-memory trace data.", 1841 error_code=INVALID_PARAMETER_VALUE, 1842 ) 1843 elif cls.is_assessment(type_, key, comparator): 1844 raise MlflowException( 1845 "Assessment filtering requires database support and cannot be performed " 1846 "on in-memory trace data.", 1847 error_code=INVALID_PARAMETER_VALUE, 1848 ) 1849 elif sed.get("type") == cls._TAG_IDENTIFIER: 1850 lhs = trace.tags.get(key) 1851 else: 1852 raise MlflowException( 1853 f"Invalid search key '{key}', supported are {cls.VALID_SEARCH_ATTRIBUTE_KEYS}", 1854 error_code=INVALID_PARAMETER_VALUE, 1855 ) 1856 if lhs is None: 1857 return False 1858 1859 return SearchUtils.get_comparison_func(comparator)(lhs, value) 1860 1861 @classmethod 1862 def sort(cls, traces, order_by_list): 1863 return sorted(traces, key=cls._get_sort_key(order_by_list)) 1864 1865 @classmethod 1866 def parse_order_by_for_search_traces(cls, order_by): 1867 token_value, is_ascending = cls._parse_order_by_string(order_by) 1868 identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS) 1869 identifier = cls._replace_key_to_tag_or_metadata(identifier) 1870 return identifier["type"], identifier["key"], is_ascending 1871 1872 @classmethod 1873 def parse_search_filter_for_search_traces(cls, filter_string): 1874 parsed = cls.parse_search_filter(filter_string) 1875 return [cls._replace_key_to_tag_or_metadata(p) for p in parsed] 1876 1877 @classmethod 1878 def _replace_key_to_tag_or_metadata(cls, parsed: dict[str, Any]): 1879 """ 1880 Replace search key to tag or metadata key if it is in the mapping. 1881 """ 1882 # Don't replace keys for span filters - they have their own namespace 1883 if parsed.get("type") == cls._SPAN_IDENTIFIER: 1884 return parsed 1885 1886 key = parsed.get("key").lower() 1887 if key in cls.SEARCH_KEY_TO_TAG: 1888 parsed["type"] = cls._TAG_IDENTIFIER 1889 parsed["key"] = cls.SEARCH_KEY_TO_TAG[key] 1890 elif key in cls.SEARCH_KEY_TO_METADATA: 1891 parsed["type"] = cls._REQUEST_METADATA_IDENTIFIER 1892 parsed["key"] = cls.SEARCH_KEY_TO_METADATA[key] 1893 elif key in cls.SEARCH_KEY_TO_SPAN: 1894 parsed["type"] = cls._SPAN_IDENTIFIER 1895 parsed["key"] = cls.SEARCH_KEY_TO_SPAN[key] 1896 elif key in cls.SEARCH_KEY_TO_ATTRIBUTE: 1897 parsed["key"] = cls.SEARCH_KEY_TO_ATTRIBUTE[key] 1898 return parsed 1899 1900 @classmethod 1901 def is_request_metadata(cls, key_type, comparator): 1902 if key_type == cls._REQUEST_METADATA_IDENTIFIER: 1903 if comparator not in cls.VALID_METADATA_COMPARATORS: 1904 raise MlflowException( 1905 f"Invalid comparator '{comparator}' not one of " 1906 f"'{cls.VALID_METADATA_COMPARATORS}'", 1907 error_code=INVALID_PARAMETER_VALUE, 1908 ) 1909 return True 1910 return False 1911 1912 @classmethod 1913 def is_span(cls, key_type, key_name, comparator): 1914 if key_type == cls._SPAN_IDENTIFIER: 1915 # Support span.attributes.<attribute> format 1916 if key_name.startswith("attributes."): 1917 # Extract the actual attribute name after "attributes." 1918 attr_name = key_name[len("attributes.") :] 1919 if not attr_name: 1920 raise MlflowException( 1921 "Span attribute name cannot be empty after 'attributes.'", 1922 error_code=INVALID_PARAMETER_VALUE, 1923 ) 1924 if comparator not in cls.VALID_SPAN_ATTRIBUTE_COMPARATORS: 1925 raise MlflowException( 1926 f"span.{key_name} comparator '{comparator}' not one of " 1927 f"'{cls.VALID_SPAN_ATTRIBUTE_COMPARATORS}'", 1928 error_code=INVALID_PARAMETER_VALUE, 1929 ) 1930 elif key_name in cls._SUPPORTED_SPAN_ATTRIBUTES: 1931 if comparator not in cls.VALID_SPAN_ATTRIBUTE_COMPARATORS: 1932 raise MlflowException( 1933 f"span.{key_name} comparator '{comparator}' not one of " 1934 f"'{cls.VALID_SPAN_ATTRIBUTE_COMPARATORS}'", 1935 error_code=INVALID_PARAMETER_VALUE, 1936 ) 1937 elif key_name == cls._SPAN_CONTENT_KEY: 1938 if comparator not in cls.VALID_SPAN_CONTENT_COMPARATORS: 1939 raise MlflowException( 1940 f"span.{key_name} comparator '{comparator}' not one of " 1941 f"'{cls.VALID_SPAN_CONTENT_COMPARATORS}'", 1942 error_code=INVALID_PARAMETER_VALUE, 1943 ) 1944 else: 1945 supported_attrs = ", ".join(sorted(cls._SUPPORTED_SPAN_ATTRIBUTES)) 1946 raise MlflowException( 1947 f"Invalid span attribute '{key_name}'. " 1948 f"Supported attributes: {supported_attrs}, attributes.<attribute_name>.", 1949 error_code=INVALID_PARAMETER_VALUE, 1950 ) 1951 return True 1952 return False 1953 1954 @classmethod 1955 def is_assessment(cls, key_type, key_name, comparator): 1956 if key_type in (cls._FEEDBACK_IDENTIFIER, cls._EXPECTATION_IDENTIFIER): 1957 if not key_name: 1958 raise MlflowException( 1959 "Assessment field name cannot be empty", 1960 error_code=INVALID_PARAMETER_VALUE, 1961 ) 1962 if comparator not in cls.VALID_ASSESSMENT_COMPARATORS: 1963 raise MlflowException( 1964 f"assessment.{key_name} comparator '{comparator}' not one of " 1965 f"'{cls.VALID_ASSESSMENT_COMPARATORS}'", 1966 error_code=INVALID_PARAMETER_VALUE, 1967 ) 1968 return True 1969 return False 1970 1971 @classmethod 1972 def is_issue(cls, key_type, key_name, comparator): 1973 if key_type == cls._ISSUE_IDENTIFIER: 1974 if key_name not in cls._SUPPORTED_ISSUE_ATTRIBUTES: 1975 supported_attrs = ", ".join(sorted(cls._SUPPORTED_ISSUE_ATTRIBUTES)) 1976 raise MlflowException( 1977 f"Invalid issue attribute '{key_name}'. " 1978 f"Supported attributes: {supported_attrs}", 1979 error_code=INVALID_PARAMETER_VALUE, 1980 ) 1981 if comparator not in cls.VALID_ISSUE_COMPARATORS: 1982 raise MlflowException( 1983 f"issue.{key_name} comparator '{comparator}' not one of " 1984 f"'{cls.VALID_ISSUE_COMPARATORS}'", 1985 error_code=INVALID_PARAMETER_VALUE, 1986 ) 1987 return True 1988 return False 1989 1990 @staticmethod 1991 def _get_sql_json_comparison_func( 1992 comparator: str, dialect: str 1993 ) -> Callable[["ColumnElement", str], "ClauseElement"]: 1994 """ 1995 Returns a comparison function for JSON-serialized values. 1996 1997 Assessment values are stored as JSON primitives in the database: 1998 - Boolean False -> false (no quotes in JSON) 1999 - Numeric value 5 -> 5 (no quotes in JSON) 2000 - String "yes" -> '"yes"' (WITH quotes in JSON) 2001 2002 For equality comparisons, we match either the raw JSON primitive value 2003 (for booleans and numeric values) or the JSON-serialized value (for strings). 2004 """ 2005 import sqlalchemy as sa 2006 2007 def mysql_json_equality_inequality_comparison( 2008 column: "ColumnElement", value: str 2009 ) -> "ClauseElement": 2010 # MySQL is case insensitive by default, so we need to use the BINARY operator 2011 # for case sensitive comparisons. We check both the raw value (for booleans/numbers) 2012 # and the JSON-serialized value (for strings). 2013 json_string_value = json.dumps(value) 2014 col_ref = f"{column.class_.__tablename__}.{column.key}" 2015 template = ( 2016 f"(({col_ref} = :value1 AND BINARY {col_ref} = :value1) OR " 2017 f"({col_ref} = :value2 AND BINARY {col_ref} = :value2))" 2018 ) 2019 if comparator == "!=": 2020 template = f"NOT {template}" 2021 return sa.text(template).bindparams( 2022 sa.bindparam("value1", value=value, unique=True), 2023 sa.bindparam("value2", value=json_string_value, unique=True), 2024 ) 2025 2026 def json_equality_inequality_comparison( 2027 column: "ColumnElement", value: str 2028 ) -> "ClauseElement": 2029 # MSSQL uses collation for case-sensitive comparisons on String columns 2030 if dialect == MSSQL: 2031 column = column.collate(_MSSQL_CASE_SENSITIVE_COLLATION) 2032 2033 json_string_value = json.dumps(value) 2034 clause = sa.or_(column == value, column == json_string_value) 2035 if comparator == "!=": 2036 clause = sa.not_(clause) 2037 return clause 2038 2039 if comparator not in ("=", "!="): 2040 return SearchTraceUtils.get_sql_comparison_func(comparator, dialect) 2041 elif dialect == MYSQL: 2042 return mysql_json_equality_inequality_comparison 2043 else: 2044 return json_equality_inequality_comparison 2045 2046 @classmethod 2047 def _valid_entity_type(cls, entity_type): 2048 entity_type = cls._trim_backticks(entity_type) 2049 if entity_type not in cls._VALID_IDENTIFIERS: 2050 raise MlflowException( 2051 f"Invalid entity type '{entity_type}'. Valid values are {cls._VALID_IDENTIFIERS}", 2052 error_code=INVALID_PARAMETER_VALUE, 2053 ) 2054 elif entity_type in cls._ALTERNATE_IDENTIFIERS: 2055 return cls._ALTERNATE_IDENTIFIERS[entity_type] 2056 else: 2057 return entity_type 2058 2059 @classmethod 2060 def _get_sort_key(cls, order_by_list): 2061 order_by = [] 2062 parsed_order_by = map(cls.parse_order_by_for_search_traces, order_by_list or []) 2063 for type_, key, ascending in parsed_order_by: 2064 if type_ == "attribute": 2065 order_by.append((key, ascending)) 2066 else: 2067 raise MlflowException.invalid_parameter_value( 2068 f"Invalid order_by entity `{type_}` with key `{key}`" 2069 ) 2070 2071 # Add a tie-breaker 2072 if not any(key == "timestamp_ms" for key, _ in order_by): 2073 order_by.append(("timestamp_ms", False)) 2074 if not any(key == "request_id" for key, _ in order_by): 2075 order_by.append(("request_id", True)) 2076 2077 return lambda trace: tuple(_apply_reversor(trace, k, asc) for (k, asc) in order_by) 2078 2079 @classmethod 2080 def _get_value(cls, identifier_type, key, token): 2081 if identifier_type == cls._TAG_IDENTIFIER: 2082 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2083 return cls._strip_quotes(token.value, expect_quoted_value=True) 2084 elif isinstance(token, Parenthesis): 2085 return cls._parse_attribute_lists(token) 2086 raise MlflowException( 2087 "Expected a quoted string value for " 2088 f"{identifier_type} (e.g. 'my-value'). Got value " 2089 f"{token.value}", 2090 error_code=INVALID_PARAMETER_VALUE, 2091 ) 2092 elif identifier_type == cls._ATTRIBUTE_IDENTIFIER: 2093 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2094 return cls._strip_quotes(token.value, expect_quoted_value=True) 2095 elif isinstance(token, Parenthesis): 2096 if key not in cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS: 2097 raise MlflowException( 2098 f"Only attributes in {cls.SUPPORT_IN_COMPARISON_ATTRIBUTE_KEYS} " 2099 "supports comparison with a list of quoted string values.", 2100 error_code=INVALID_PARAMETER_VALUE, 2101 ) 2102 return cls._parse_attribute_lists(token) 2103 elif token.ttype in cls.NUMERIC_VALUE_TYPES: 2104 if key not in cls.NUMERIC_ATTRIBUTES: 2105 raise MlflowException( 2106 f"Only the '{cls.NUMERIC_ATTRIBUTES}' attributes support comparison with " 2107 "numeric values.", 2108 error_code=INVALID_PARAMETER_VALUE, 2109 ) 2110 if token.ttype == TokenType.Literal.Number.Integer: 2111 return int(token.value) 2112 elif token.ttype == TokenType.Literal.Number.Float: 2113 return float(token.value) 2114 else: 2115 raise MlflowException( 2116 "Expected a quoted string value or a list of quoted string values for " 2117 f"attributes. Got value {token.value}", 2118 error_code=INVALID_PARAMETER_VALUE, 2119 ) 2120 elif identifier_type == cls._REQUEST_METADATA_IDENTIFIER: 2121 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2122 return cls._strip_quotes(token.value, expect_quoted_value=True) 2123 else: 2124 raise MlflowException( 2125 "Expected a quoted string value for " 2126 f"{identifier_type} (e.g. 'my-value'). Got value " 2127 f"{token.value}", 2128 error_code=INVALID_PARAMETER_VALUE, 2129 ) 2130 elif identifier_type == cls._SPAN_IDENTIFIER: 2131 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2132 return cls._strip_quotes(token.value, expect_quoted_value=True) 2133 elif isinstance(token, Parenthesis): 2134 return cls._parse_attribute_lists(token) 2135 else: 2136 raise MlflowException( 2137 "Expected a quoted string value for " 2138 f"{identifier_type} (e.g. 'my-value'). Got value " 2139 f"{token.value}", 2140 error_code=INVALID_PARAMETER_VALUE, 2141 ) 2142 elif identifier_type in ( 2143 cls._FEEDBACK_IDENTIFIER, 2144 cls._EXPECTATION_IDENTIFIER, 2145 cls._ISSUE_IDENTIFIER, 2146 ): 2147 # Feedback and expectation values are stored as JSON, so we expect string values 2148 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2149 return cls._strip_quotes(token.value, expect_quoted_value=True) 2150 else: 2151 raise MlflowException( 2152 "Expected a quoted string value for " 2153 f"{identifier_type} (e.g. 'my-value'). Got value " 2154 f"{token.value}", 2155 error_code=INVALID_PARAMETER_VALUE, 2156 ) 2157 else: 2158 # Expected to be either "param" or "metric". 2159 raise MlflowException( 2160 f"Invalid identifier type: {identifier_type}. " 2161 f"Expected one of {cls._VALID_IDENTIFIERS}.", 2162 error_code=INVALID_PARAMETER_VALUE, 2163 ) 2164 2165 @classmethod 2166 def _parse_attribute_lists(cls, token): 2167 return cls._parse_list_from_sql_token(token) 2168 2169 @classmethod 2170 def _process_statement(cls, statement): 2171 # check validity 2172 tokens = _join_in_comparison_tokens(statement.tokens, search_traces=True) 2173 invalids = list(filter(cls._invalid_statement_token_search_traces, tokens)) 2174 if len(invalids) > 0: 2175 invalid_clauses = ", ".join(f"'{token}'" for token in invalids) 2176 raise MlflowException( 2177 f"Invalid clause(s) in filter string: {invalid_clauses}", 2178 error_code=INVALID_PARAMETER_VALUE, 2179 ) 2180 return [cls._get_comparison(si) for si in tokens if isinstance(si, Comparison)] 2181 2182 @classmethod 2183 def _invalid_statement_token_search_traces(cls, token): 2184 if ( 2185 isinstance(token, Comparison) 2186 or token.is_whitespace 2187 or token.match(ttype=TokenType.Keyword, values=["AND"]) 2188 ): 2189 return False 2190 return True 2191 2192 @classmethod 2193 def _validate_comparison(cls, tokens): 2194 # Allow 2-token IS NULL / IS NOT NULL comparisons 2195 if len(tokens) == 2: 2196 comparator = tokens[1].value.upper() 2197 if comparator in ("IS NULL", "IS NOT NULL"): 2198 if not isinstance(tokens[0], Identifier): 2199 raise MlflowException( 2200 f"Invalid comparison clause. Expected 'Identifier' found '{tokens[0]}'", 2201 error_code=INVALID_PARAMETER_VALUE, 2202 ) 2203 return 2204 # Allow timestamp/timestamp_ms as the first token for trace search 2205 if ( 2206 len(tokens) == 3 2207 and not isinstance(tokens[0], Identifier) 2208 and tokens[0].match(ttype=TokenType.Name.Builtin, values=["timestamp", "timestamp_ms"]) 2209 ): 2210 return 2211 super()._validate_comparison(tokens) 2212 2213 @classmethod 2214 def _get_comparison(cls, comparison): 2215 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 2216 cls._validate_comparison(stripped_comparison) 2217 2218 # Handle IS NULL / IS NOT NULL (2 tokens: identifier + comparator, no value) 2219 if len(stripped_comparison) == 2: 2220 comparator = stripped_comparison[1].value.upper() 2221 comp = cls._get_identifier( 2222 stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS 2223 ) 2224 comp["comparator"] = comparator 2225 comp["value"] = None 2226 return comp 2227 2228 comp = cls._get_identifier(stripped_comparison[0].value, cls.VALID_SEARCH_ATTRIBUTE_KEYS) 2229 comp["comparator"] = stripped_comparison[1].value 2230 comp["value"] = cls._get_value(comp.get("type"), comp.get("key"), stripped_comparison[2]) 2231 2232 if comp.get("type") == cls._SPAN_IDENTIFIER: 2233 cls.is_span(comp["type"], comp["key"], comp["comparator"]) 2234 2235 return comp 2236 2237 2238 @dataclass 2239 class TraceMetricsFilter: 2240 view_type: str 2241 entity: str 2242 key: str | None 2243 comparator: str 2244 value: Any 2245 2246 2247 class SearchTraceMetricsUtils(SearchTraceUtils): 2248 _VALID_VIEW_TYPES_TO_ENTITIES = { 2249 TraceMetricSearchKey.VIEW_TYPE: TraceMetricSearchKey.entity_to_key_requirement(), 2250 SpanMetricSearchKey.VIEW_TYPE: SpanMetricSearchKey.entity_to_key_requirement(), 2251 AssessmentMetricSearchKey.VIEW_TYPE: AssessmentMetricSearchKey.entity_to_key_requirement(), 2252 } 2253 2254 @classmethod 2255 def parse_search_filter(cls, filter_string: str) -> TraceMetricsFilter: 2256 parsed = super().parse_search_filter(filter_string) 2257 if len(parsed) != 1: 2258 raise MlflowException.invalid_parameter_value( 2259 f"Invalid filter: '{filter_string}'. Expected one filter clause." 2260 ) 2261 return parsed[0] 2262 2263 @classmethod 2264 def _process_statement(cls, statement: Statement) -> list[TraceMetricsFilter]: 2265 tokens = statement.tokens 2266 invalids = list(filter(cls._invalid_statement_token, tokens)) 2267 if len(invalids) > 0: 2268 invalid_clauses = ", ".join(map(str, invalids)) 2269 raise MlflowException.invalid_parameter_value( 2270 f"Invalid clause(s) in filter string: {invalid_clauses}" 2271 ) 2272 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 2273 2274 @classmethod 2275 def _invalid_statement_token(cls, token: Token) -> bool: 2276 if isinstance(token, Comparison) or token.is_whitespace: 2277 return False 2278 return True 2279 2280 @classmethod 2281 def _get_identifier(cls, identifier) -> dict[str, Any]: 2282 error_message = ( 2283 f"Invalid identifier {identifier!r}. Columns should be specified as " 2284 f"'trace.<key>', 'span.<key>', 'assessment.<key>'." 2285 ) 2286 try: 2287 tokens = identifier.split(".", 2) 2288 match tokens: 2289 case [view_type, entity]: 2290 return cls._validate_metrics_fields(view_type, entity) 2291 case [view_type, entity, key]: 2292 return cls._validate_metrics_fields(view_type, entity, key) 2293 case _: 2294 raise MlflowException.invalid_parameter_value(error_message) 2295 except ValueError: 2296 raise MlflowException.invalid_parameter_value(error_message) 2297 2298 @classmethod 2299 def _validate_metrics_fields(cls, view_type, entity, key=None) -> dict[str, Any]: 2300 view_type = cls._trim_backticks(view_type) 2301 if view_type not in cls._VALID_VIEW_TYPES_TO_ENTITIES: 2302 raise MlflowException.invalid_parameter_value( 2303 f"Invalid view type '{view_type}'. " 2304 f"Valid values are {cls._VALID_VIEW_TYPES_TO_ENTITIES.keys()}" 2305 ) 2306 valid_entities = cls._VALID_VIEW_TYPES_TO_ENTITIES[view_type] 2307 if entity not in valid_entities: 2308 raise MlflowException.invalid_parameter_value( 2309 f"Invalid entity '{entity}' specified for view type '{view_type}'. " 2310 f"Valid entities are {list(valid_entities.keys())}" 2311 ) 2312 key_is_required = valid_entities[entity] 2313 if key_is_required and key is None: 2314 raise MlflowException.invalid_parameter_value( 2315 f"Filtering by {entity} requires a key, e.g. '{view_type}.{entity}.<key> = <value>'" 2316 ) 2317 elif not key_is_required and key is not None: 2318 raise MlflowException.invalid_parameter_value( 2319 f"Filtering by {entity} does not require a key, use '{view_type}.{entity}' instead" 2320 ) 2321 key = cls._trim_backticks(cls._strip_quotes(key)) if key else None 2322 return {"view_type": view_type, "entity": entity, "key": key} 2323 2324 @classmethod 2325 def _get_value(cls, entity, key, token): 2326 if token.ttype in cls.STRING_VALUE_TYPES or isinstance(token, Identifier): 2327 return cls._strip_quotes(token.value, expect_quoted_value=True) 2328 else: 2329 raise MlflowException.invalid_parameter_value( 2330 f"Expected a quoted string value for {entity} value (e.g. 'my-value'). " 2331 f"Got value {token.value}", 2332 ) 2333 2334 @classmethod 2335 def _get_comparison(cls, comparison: Comparison) -> TraceMetricsFilter: 2336 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 2337 cls._validate_comparison(stripped_comparison) 2338 comp = cls._get_identifier(stripped_comparison[0].value) 2339 comparator = stripped_comparison[1].value 2340 if comparator != "=": 2341 raise MlflowException.invalid_parameter_value( 2342 f"Invalid comparator: '{comparator}', only '=' operator is supported" 2343 ) 2344 value = cls._get_value(comp["entity"], comp["key"], stripped_comparison[2]) 2345 2346 return TraceMetricsFilter( 2347 view_type=comp["view_type"], 2348 entity=comp["entity"], 2349 key=comp["key"], 2350 comparator=comparator, 2351 value=value, 2352 ) 2353 2354 2355 class SearchEvaluationDatasetsUtils(SearchUtils): 2356 """ 2357 Utility class for searching evaluation datasets. 2358 """ 2359 2360 VALID_SEARCH_ATTRIBUTE_KEYS = { 2361 "name", 2362 "created_time", 2363 "last_update_time", 2364 "created_by", 2365 "last_updated_by", 2366 } 2367 VALID_ORDER_BY_ATTRIBUTE_KEYS = {"name", "created_time", "last_update_time"} 2368 NUMERIC_ATTRIBUTES = {"created_time", "last_update_time"} 2369 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE"} 2370 2371 @classmethod 2372 def _invalid_statement_token(cls, token): 2373 if ( 2374 isinstance(token, Comparison) 2375 or token.is_whitespace 2376 or token.match(ttype=TokenType.Keyword, values=["AND"]) 2377 ): 2378 return False 2379 return True 2380 2381 @classmethod 2382 def _process_statement(cls, statement): 2383 tokens = _join_in_comparison_tokens(statement.tokens) 2384 invalids = list(filter(cls._invalid_statement_token, tokens)) 2385 if len(invalids) > 0: 2386 invalid_clauses = ", ".join(map(str, invalids)) 2387 raise MlflowException.invalid_parameter_value( 2388 f"Invalid clause(s) in filter string: {invalid_clauses}" 2389 ) 2390 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 2391 2392 @classmethod 2393 def _get_identifier(cls, identifier, valid_attributes): 2394 tokens = identifier.split(".", maxsplit=1) 2395 if len(tokens) == 1: 2396 key = tokens[0] 2397 if key not in valid_attributes: 2398 raise MlflowException.invalid_parameter_value( 2399 f"Invalid attribute key '{key}' specified. Valid keys are: {valid_attributes}" 2400 ) 2401 return {"type": "attribute", "key": key} 2402 else: 2403 if tokens[0] == "tags": 2404 key = tokens[1] 2405 return {"type": "tag", "key": key} 2406 else: 2407 raise MlflowException.invalid_parameter_value( 2408 f"Invalid identifier token '{tokens[0]}' specified" 2409 ) 2410 2411 @classmethod 2412 def parse_order_by_for_search_evaluation_datasets(cls, order_by): 2413 token_value, is_ascending = cls._parse_order_by_string(order_by) 2414 identifier = cls._get_identifier(token_value.strip(), cls.VALID_ORDER_BY_ATTRIBUTE_KEYS) 2415 return identifier["type"], identifier["key"], is_ascending 2416 2417 @classmethod 2418 def is_string_attribute(cls, type_, key, comparator): 2419 return ( 2420 type_ == "attribute" 2421 and key not in cls.NUMERIC_ATTRIBUTES 2422 and comparator in cls.VALID_STRING_ATTRIBUTE_COMPARATORS 2423 ) 2424 2425 @classmethod 2426 def is_numeric_attribute(cls, type_, key, comparator): 2427 return ( 2428 type_ == "attribute" 2429 and key in cls.NUMERIC_ATTRIBUTES 2430 and comparator in cls.VALID_NUMERIC_ATTRIBUTE_COMPARATORS 2431 ) 2432 2433 2434 class SearchLoggedModelsUtils(SearchUtils): 2435 NUMERIC_ATTRIBUTES = { 2436 "creation_timestamp", 2437 "creation_time", 2438 "last_updated_timestamp", 2439 "last_updated_time", 2440 } 2441 VALID_SEARCH_ATTRIBUTE_KEYS = { 2442 "name", 2443 "model_id", 2444 "model_type", 2445 "status", 2446 "source_run_id", 2447 } | NUMERIC_ATTRIBUTES 2448 VALID_TAG_COMPARATORS = {"!=", "=", "LIKE", "ILIKE"} 2449 VALID_PARAM_COMPARATORS = {"!=", "=", "LIKE", "ILIKE"} 2450 VALID_ORDER_BY_ATTRIBUTE_KEYS = VALID_SEARCH_ATTRIBUTE_KEYS 2451 2452 @classmethod 2453 def _does_logged_model_match_clause( 2454 cls, 2455 model: LoggedModel, 2456 condition: dict[str, Any], 2457 datasets: list[dict[str, Any]] | None = None, 2458 ): 2459 key_type = condition.get("type") 2460 key = condition.get("key") 2461 value = condition.get("value") 2462 comparator = condition.get("comparator").upper() 2463 2464 key = SearchUtils.translate_key_alias(key) 2465 2466 if cls.is_metric(key_type, comparator): 2467 matching_metrics = [metric for metric in model.metrics if metric.key == key] 2468 if datasets: 2469 matching_metrics = [ 2470 metric 2471 for metric in matching_metrics 2472 if any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets) 2473 ] 2474 lhs = matching_metrics[0].value if matching_metrics else None 2475 value = float(value) 2476 elif cls.is_param(key_type, comparator): 2477 lhs = model.params.get(key, None) 2478 elif cls.is_tag(key_type, comparator): 2479 lhs = model.tags.get(key, None) 2480 elif cls.is_numeric_attribute(key_type, key, comparator): 2481 lhs = getattr(model, key) 2482 value = int(value) 2483 elif hasattr(model, key): 2484 lhs = getattr(model, key) 2485 else: 2486 raise MlflowException.invalid_parameter_value( 2487 f"Invalid logged model search key '{key}'", 2488 ) 2489 if lhs is None: 2490 return False 2491 2492 return SearchUtils.get_comparison_func(comparator)(lhs, value) 2493 2494 @classmethod 2495 def validate_list_supported(cls, key: str) -> None: 2496 """ 2497 Override to allow logged model attributes to be used with IN/NOT IN. 2498 """ 2499 2500 @classmethod 2501 def filter_logged_models( 2502 cls, 2503 models: list[LoggedModel], 2504 filter_string: str | None = None, 2505 datasets: list[dict[str, Any]] | None = None, 2506 ): 2507 """Filters a set of runs based on a search filter string and list of dataset filters.""" 2508 if not filter_string and not datasets: 2509 return models 2510 2511 parsed = cls.parse_search_filter(filter_string) 2512 2513 # If there are dataset filters but no metric filters in the filter string, 2514 # filter for models that have any metrics on the datasets 2515 if datasets and not any( 2516 cls.is_metric(s.get("type"), s.get("comparator").upper()) for s in parsed 2517 ): 2518 2519 def model_has_metrics_on_datasets(model): 2520 return any( 2521 any(cls._is_metric_on_dataset(metric, dataset) for dataset in datasets) 2522 for metric in model.metrics 2523 ) 2524 2525 models = [model for model in models if model_has_metrics_on_datasets(model)] 2526 2527 def model_matches(model): 2528 return all(cls._does_logged_model_match_clause(model, s, datasets) for s in parsed) 2529 2530 return [model for model in models if model_matches(model)] 2531 2532 @dataclass 2533 class OrderBy: 2534 field_name: str 2535 ascending: bool = True 2536 dataset_name: str | None = None 2537 dataset_digest: str | None = None 2538 2539 @classmethod 2540 def parse_order_by_for_logged_models(cls, order_by: dict[str, Any]) -> OrderBy: 2541 if not isinstance(order_by, dict): 2542 raise MlflowException.invalid_parameter_value( 2543 "`order_by` must be a list of dictionaries." 2544 ) 2545 field_name = order_by.get("field_name") 2546 if field_name is None: 2547 raise MlflowException.invalid_parameter_value( 2548 "`field_name` in the `order_by` clause must be specified." 2549 ) 2550 if "." in field_name: 2551 entity = field_name.split(".", 1)[0] 2552 if entity != "metrics": 2553 raise MlflowException.invalid_parameter_value( 2554 f"Invalid order by field name: {entity}, only `metrics.<name>` is allowed." 2555 ) 2556 else: 2557 field_name = field_name.strip() 2558 if field_name not in cls.VALID_ORDER_BY_ATTRIBUTE_KEYS: 2559 raise MlflowException.invalid_parameter_value( 2560 f"Invalid order by field name: {field_name}." 2561 ) 2562 ascending = order_by.get("ascending", True) 2563 if ascending not in [True, False]: 2564 raise MlflowException.invalid_parameter_value( 2565 "Value of `ascending` in the `order_by` clause must be a boolean, got " 2566 f"{type(ascending)} for field {field_name}." 2567 ) 2568 dataset_name = order_by.get("dataset_name") 2569 dataset_digest = order_by.get("dataset_digest") 2570 if dataset_digest and not dataset_name: 2571 raise MlflowException.invalid_parameter_value( 2572 "`dataset_digest` can only be specified if `dataset_name` is also specified." 2573 ) 2574 2575 aliases = { 2576 "creation_time": "creation_timestamp", 2577 } 2578 return cls.OrderBy( 2579 aliases.get(field_name, field_name), ascending, dataset_name, dataset_digest 2580 ) 2581 2582 @classmethod 2583 def _apply_reversor_for_logged_model( 2584 cls, 2585 model: LoggedModel, 2586 order_by: OrderBy, 2587 ): 2588 if "." in order_by.field_name: 2589 metric_key = order_by.field_name.split(".", 1)[1] 2590 filtered_metrics = sorted( 2591 [ 2592 m 2593 for m in model.metrics 2594 if m.key == metric_key 2595 and (not order_by.dataset_name or m.dataset_name == order_by.dataset_name) 2596 and (not order_by.dataset_digest or m.dataset_digest == order_by.dataset_digest) 2597 ], 2598 key=lambda metric: metric.timestamp, 2599 reverse=True, 2600 ) 2601 latest_metric_value = None if len(filtered_metrics) == 0 else filtered_metrics[0].value 2602 return ( 2603 _LoggedModelMetricComp(latest_metric_value) 2604 if order_by.ascending 2605 else _Reversor(latest_metric_value) 2606 ) 2607 else: 2608 value = getattr(model, order_by.field_name) 2609 return value if order_by.ascending else _Reversor(value) 2610 2611 @classmethod 2612 def _get_sort_key(cls, order_by_list: list[dict[str, Any]] | None): 2613 parsed_order_by = list(map(cls.parse_order_by_for_logged_models, order_by_list or [])) 2614 2615 # Add a tie-breaker 2616 if not any(order_by.field_name == "creation_timestamp" for order_by in parsed_order_by): 2617 parsed_order_by.append(cls.OrderBy("creation_timestamp", False)) 2618 if not any(order_by.field_name == "model_id" for order_by in parsed_order_by): 2619 parsed_order_by.append(cls.OrderBy("model_id")) 2620 2621 return lambda logged_model: tuple( 2622 cls._apply_reversor_for_logged_model(logged_model, order_by) 2623 for order_by in parsed_order_by 2624 ) 2625 2626 @classmethod 2627 def sort(cls, models, order_by_list): 2628 return sorted(models, key=cls._get_sort_key(order_by_list)) 2629 2630 2631 class _LoggedModelMetricComp: 2632 def __init__(self, obj): 2633 self.obj = obj 2634 2635 def __eq__(self, other): 2636 return other.obj == self.obj 2637 2638 def __lt__(self, other): 2639 if self.obj is None: 2640 return False 2641 if other.obj is None: 2642 return True 2643 return self.obj < other.obj 2644 2645 2646 @dataclass 2647 class SearchLoggedModelsPaginationToken: 2648 experiment_ids: list[str] 2649 filter_string: str | None = None 2650 order_by: list[dict[str, Any]] | None = None 2651 offset: int = 0 2652 2653 def to_json(self) -> str: 2654 return json.dumps(asdict(self)) 2655 2656 def encode(self) -> str: 2657 return base64.b64encode(self.to_json().encode("utf-8")).decode("utf-8") 2658 2659 @classmethod 2660 def decode(cls, token: str) -> "SearchLoggedModelsPaginationToken": 2661 try: 2662 token = json.loads(base64.b64decode(token.encode("utf-8")).decode("utf-8")) 2663 except json.JSONDecodeError as e: 2664 raise MlflowException.invalid_parameter_value(f"Invalid page token: {token}. {e}") 2665 2666 return cls( 2667 experiment_ids=token.get("experiment_ids"), 2668 filter_string=token.get("filter_string") or None, 2669 order_by=token.get("order_by") or None, 2670 offset=token.get("offset") or 0, 2671 ) 2672 2673 def validate( 2674 self, 2675 experiment_ids: list[str], 2676 filter_string: str | None, 2677 order_by: list[dict[str, Any]] | None, 2678 ) -> None: 2679 if self.experiment_ids != experiment_ids: 2680 raise MlflowException.invalid_parameter_value( 2681 f"Experiment IDs in the page token do not match the requested experiment IDs. " 2682 f"Expected: {experiment_ids}. Found: {self.experiment_ids}" 2683 ) 2684 2685 if self.filter_string != filter_string: 2686 raise MlflowException.invalid_parameter_value( 2687 f"Filter string in the page token does not match the requested filter string. " 2688 f"Expected: {filter_string}. Found: {self.filter_string}" 2689 ) 2690 2691 if self.order_by != order_by: 2692 raise MlflowException.invalid_parameter_value( 2693 f"Order by in the page token does not match the requested order by. " 2694 f"Expected: {order_by}. Found: {self.order_by}" 2695 ) 2696 2697 2698 class SearchIssuesUtils(SearchUtils): 2699 """Utility class for parsing issue search filters.""" 2700 2701 VALID_SEARCH_ATTRIBUTE_KEYS = {"status", "source_run_id"} 2702 VALID_STRING_ATTRIBUTE_COMPARATORS = {"=", "!="} 2703 2704 @classmethod 2705 def _invalid_statement_token(cls, token): 2706 """Check if a token is invalid for issue search filters.""" 2707 if ( 2708 isinstance(token, Comparison) 2709 or token.is_whitespace 2710 or token.match(ttype=TokenType.Keyword, values=["AND"]) 2711 ): 2712 return False 2713 return True 2714 2715 @classmethod 2716 def _process_statement(cls, statement): 2717 """Process SQL statement and extract comparisons.""" 2718 tokens = _join_in_comparison_tokens(statement.tokens) 2719 invalids = list(filter(cls._invalid_statement_token, tokens)) 2720 if len(invalids) > 0: 2721 invalid_clauses = ", ".join(map(str, invalids)) 2722 raise MlflowException.invalid_parameter_value( 2723 f"Invalid clause(s) in filter string: {invalid_clauses}" 2724 ) 2725 return [cls._get_comparison(t) for t in tokens if isinstance(t, Comparison)] 2726 2727 @classmethod 2728 def _get_comparison(cls, comparison): 2729 """Extract comparison details from a Comparison token.""" 2730 stripped_comparison = [token for token in comparison.tokens if not token.is_whitespace] 2731 2732 if len(stripped_comparison) != 3: 2733 raise MlflowException.invalid_parameter_value( 2734 f"Invalid comparison: expected 3 tokens, got {len(stripped_comparison)}" 2735 ) 2736 2737 left, comparator_token, right = stripped_comparison 2738 2739 # Get field name 2740 if not isinstance(left, Identifier): 2741 raise MlflowException.invalid_parameter_value( 2742 f"Invalid comparison: left side must be an identifier, got {type(left)}" 2743 ) 2744 2745 key = cls._strip_quotes(left.value).strip() 2746 if key not in cls.VALID_SEARCH_ATTRIBUTE_KEYS: 2747 raise MlflowException.invalid_parameter_value( 2748 f"Invalid filter field '{key}'. Supported fields: {cls.VALID_SEARCH_ATTRIBUTE_KEYS}" 2749 ) 2750 2751 # Get comparator 2752 comparator = comparator_token.value.upper() 2753 if comparator not in cls.VALID_STRING_ATTRIBUTE_COMPARATORS: 2754 raise MlflowException.invalid_parameter_value( 2755 f"Invalid comparator '{comparator}'. " 2756 f"Supported comparators: {cls.VALID_STRING_ATTRIBUTE_COMPARATORS}" 2757 ) 2758 2759 # Get value 2760 value = cls._strip_quotes(right.value).strip() 2761 2762 return { 2763 "type": cls._ATTRIBUTE_IDENTIFIER, 2764 "key": key, 2765 "comparator": comparator, 2766 "value": value, 2767 }