optimizer.py
import datetime
import random
import uuid
from abc import ABC
from abc import abstractmethod
from asyncio import Lock
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TypeVar

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import Field
from evidently._pydantic_compat import PrivateAttr
from evidently.legacy.core import new_id
from evidently.legacy.options.base import Options
from evidently.legacy.utils.llm.wrapper import LLMWrapper
from evidently.legacy.utils.llm.wrapper import get_llm_wrapper
from evidently.llm.optimization.errors import OptimizationConfigurationError
from evidently.pydantic_utils import AutoAliasMixin
from evidently.pydantic_utils import EvidentlyBaseModel


class Params:
    """Parameter keys used throughout the optimizer context and configuration."""

    BasePrompt = "base_prompt"
    EarlyStop = "early_stop"
    LLMClassification = "llm_classification"
    Options = "options"
    Scorer = "scorer"
    Executor = "executor"
    Dataset = "dataset"
    TargetValue = "target_value"
    Task = "task"
    OptimizerPromptInstructions = "optimizer_prompt_instructions"
    DataSplitShares = "data_split_shares"


class LLMDatasetSplit:
    """Names of the dataset splits recognized by `LLMDataset`."""

    Train = "train"
    Val = "val"
    Test = "test"
    All = "all"


class LLMDatasetColumns:
    """Column names used for LLM dataset fields and prediction results."""

    InputValues = "input_values"
    Target = "target"
    Reasoning = "reasoning"
    Scores = "scores"
    Pred = "preds"
    PredReasoning = "preds_reasoning"
    PredScores = "preds_scores"


class LLMDatasetSplitView:
    """View of a specific split in an LLM dataset.

    Provides filtered access to dataset columns for a specific split
    (train, val, test, or all).
    """

    def __init__(self, dataset: "LLMDataset", split_name: str):
        """Initialize a split view.

        Args:
            * `dataset`: `LLMDataset` to view.
            * `split_name`: Name of the split to view.
        """
        self.split_name = split_name
        self.dataset = dataset

    def _filter(self, series: Optional[pd.Series]) -> Optional[pd.Series]:
        """Apply this view's split mask to a dataset column.

        The series is returned unfiltered for the `All` split, for split
        names without a registered mask, or when the column is missing —
        mirroring the behavior of all column accessors below.
        """
        if (
            self.split_name == LLMDatasetSplit.All
            or self.split_name not in self.dataset.split_masks
            or series is None
        ):
            return series
        return series[self.dataset.split_masks[self.split_name]]

    @property
    def input_values(self) -> Optional[pd.Series]:
        """Get input values for this split.

        Returns:
            * `pd.Series` with input values, or `None` if not available.
        """
        return self._filter(self.dataset.input_values)

    @property
    def target(self) -> Optional[pd.Series]:
        """Get target values for this split.

        Returns:
            * `pd.Series` with target values, or `None` if not available.
        """
        return self._filter(self.dataset.target)

    @property
    def reasoning(self) -> Optional[pd.Series]:
        """Get reasoning for this split.

        Returns:
            * `pd.Series` with reasoning, or `None` if not available.
        """
        return self._filter(self.dataset.reasoning)

    @property
    def predictions(self) -> Optional[pd.Series]:
        """Get predictions for this split.

        Returns:
            * `pd.Series` with predictions, or `None` if not available.
        """
        return self._filter(self.dataset.predictions)

    @property
    def prediction_reasoning(self) -> Optional[pd.Series]:
        """Get prediction reasoning for this split.

        Returns:
            * `pd.Series` with prediction reasoning, or `None` if not available.
        """
        return self._filter(self.dataset.prediction_reasoning)


class LLMDataset(BaseModel):
    """Dataset for LLM optimization tasks.

    Contains input values, targets, reasoning, predictions, and split masks
    for train/val/test splits.
    """

    input_values: pd.Series
    """Input values (e.g., prompts)."""
    target: Optional[pd.Series] = None
    """Optional target values (e.g., expected outputs)."""
    reasoning: Optional[pd.Series] = None
    """Optional reasoning for targets."""
    predictions: Optional[pd.Series] = None
    """Optional model predictions."""
    prediction_reasoning: Optional[pd.Series] = None
    """Optional reasoning for predictions."""
    split_masks: Dict[str, pd.Series] = Field(default_factory=dict)
    """Dictionary mapping split names to boolean masks."""

    def __getitem__(self, split_name: str) -> LLMDatasetSplitView:
        """Get a view of a specific split.

        Args:
            * `split_name`: Name of the split to view.

        Returns:
            * `LLMDatasetSplitView` for the specified split.
        """
        return LLMDatasetSplitView(self, split_name)

    def get_mask(self, split_name: str) -> pd.Series:
        """Get the boolean mask for a split.

        Args:
            * `split_name`: Name of the split.

        Returns:
            * `pd.Series` with boolean values indicating which rows belong to this split.
              For the `All` split or an unknown split name, an all-True mask is returned.
        """
        if split_name not in self.split_masks or split_name == LLMDatasetSplit.All:
            return pd.Series(True, index=np.arange(len(self.input_values)))
        return self.split_masks[split_name]

    def split(self, shares: Dict[str, Optional[float]], seed: Optional[int]) -> None:
        """Split the dataset into train/val/test splits.

        Creates boolean masks for each split based on the specified shares.
        Uses stratified splitting if target values are available.

        Args:
            * `shares`: Dictionary mapping split names to proportions (must sum to 1.0).
              Entries with a `None` share are skipped.
            * `seed`: Optional random seed for reproducibility.

        Raises:
            * `ValueError`: If the specified (non-None) shares don't sum to 1.0.
        """
        n = len(self.input_values)
        indices = np.arange(n)

        # Normalize shares, ignoring None entries.
        specified = {k: v for k, v in shares.items() if v is not None}
        total = sum(specified.values())
        if specified and not np.isclose(total, 1.0):
            raise ValueError("Specified shares must sum to 1.0")

        remaining_indices: np.ndarray[Any, Any] = indices
        remaining_shares = dict(specified)

        self.split_masks = {}

        for name, share in shares.items():
            if share is None:
                continue

            if len(remaining_shares) == 1:
                # The last specified split takes everything that is left.
                # NOTE: tracking the count of remaining *specified* shares (rather
                # than the enumerate index over all of `shares`) keeps this correct
                # when None-valued entries precede specified ones.
                split_indices = remaining_indices
            else:
                # Rescale this split's share relative to the rows still unassigned.
                test_size = share / sum(remaining_shares.values())
                _, split_indices = train_test_split(
                    remaining_indices,
                    test_size=test_size,
                    stratify=self.target[remaining_indices] if self.target is not None else None,  # type: ignore[index]
                    random_state=seed,
                )
                remaining_indices = np.setdiff1d(remaining_indices, split_indices)
            remaining_shares.pop(name)

            mask = pd.Series(False, index=np.arange(n))
            mask.iloc[split_indices] = True
            self.split_masks[name] = mask
class LLMResultDataset(BaseModel):
    """Dataset containing LLM optimization results.

    Stores predictions, reasoning, and scores from optimization runs.
    """

    predictions: Optional[pd.Series] = None
    """Optional model predictions."""
    reasoning: Optional[pd.Series] = None
    """Optional reasoning for predictions."""
    scores: Optional[pd.Series] = None
    """Optional scores for predictions."""

    @property
    def has_predictions(self) -> bool:
        """Check if predictions are available.

        Returns:
            * `True` if predictions exist, `False` otherwise.
        """
        return self.predictions is not None

    @property
    def has_reasoning(self) -> bool:
        """Check if reasoning is available.

        Returns:
            * `True` if reasoning exists, `False` otherwise.
        """
        return self.reasoning is not None

    @property
    def has_scores(self) -> bool:
        """Check if scores are available.

        Returns:
            * `True` if scores exist, `False` otherwise.
        """
        return self.scores is not None

    @staticmethod
    def _require(series: Optional[pd.Series], kind: str, mask: Optional[pd.Series]) -> pd.Series:
        """Return `series`, optionally filtered by `mask`.

        Raises:
            * `KeyError`: If the series is absent (message names the missing `kind`).
        """
        if series is None:
            raise KeyError(f"Dataset has no {kind}")
        if mask is not None:
            return series[mask]
        return series

    def get_predictions(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get predictions, optionally filtered by mask.

        Args:
            * `mask`: Optional boolean mask to filter predictions.

        Returns:
            * `pd.Series` with predictions.

        Raises:
            * `KeyError`: If predictions are not available.
        """
        return self._require(self.predictions, "predictions", mask)

    def get_reasoning(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get reasoning, optionally filtered by mask.

        Args:
            * `mask`: Optional boolean mask to filter reasoning.

        Returns:
            * `pd.Series` with reasoning.

        Raises:
            * `KeyError`: If reasoning is not available.
        """
        return self._require(self.reasoning, "reasoning", mask)

    def get_scores(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get scores, optionally filtered by mask.

        Args:
            * `mask`: Optional boolean mask to filter scores.

        Returns:
            * `pd.Series` with scores.

        Raises:
            * `KeyError`: If scores are not available.
        """
        return self._require(self.scores, "scores", mask)

    def items(self) -> Iterable[Tuple[str, pd.Series]]:
        """Iterate over all Series fields.

        Yields:
            * Tuples of (field_name, pd.Series) for each non-None Series field.
        """
        for field_name in self.__fields__:
            value = getattr(self, field_name)
            if isinstance(value, pd.Series):
                yield field_name, value


class OptimizerConfig(AutoAliasMixin, EvidentlyBaseModel):
    """Configuration for the optimizer, including provider and model."""

    __alias_type__: ClassVar = "optimizer_config"

    class Config:
        is_base_type = True

    provider: str = "openai"
    """LLM provider name."""
    model: str = "gpt-4o-mini"
    """LLM model name."""
    verbose: bool = False
    """Whether to print optimization progress."""
    seed: Optional[int] = None
    """Optional random seed for reproducibility."""


LogID = uuid.UUID
T = TypeVar("T")


class OptimizerLog(AutoAliasMixin, EvidentlyBaseModel, ABC):
    """Base class for all optimizer logs.

    Logs track events and steps during optimization runs.
    """

    __alias_type__: ClassVar = "optimizer_log"
    __is_step__: ClassVar[bool] = False

    class Config:
        is_base_type = True

    id: LogID = Field(default_factory=new_id)
    """Unique log identifier."""
    timestamp: datetime.datetime = Field(default_factory=datetime.datetime.now)
    """Timestamp when the log was created."""

    @abstractmethod
    def message(self) -> str:
        """Get a human-readable message for this log.

        Returns:
            * String message describing the log event.
        """
        raise NotImplementedError()

    def full_message(self) -> str:
        """Get the full message for this log.

        Returns:
            * String message (may be overridden by subclasses for additional context).
        """
        return self.message()


class LLMCallOptimizerLog(OptimizerLog, ABC):
    """Log entry for an LLM API call during optimization.

    Tracks token usage for cost monitoring and rate limiting.
    """

    input_tokens: int
    """Number of input tokens used."""
    output_tokens: int
    """Number of output tokens used."""


TLog = TypeVar("TLog", bound=OptimizerLog)
LogsDict = Dict[LogID, OptimizerLog]
RunID = int


class OptimizerRun(BaseModel):
    """A single optimization run with logs and statistics.

    Tracks all events, steps, and LLM calls during an optimization run.
    """

    run_id: RunID
    """Unique identifier for this run."""
    logs: LogsDict = {}
    """Dictionary of log entries by log ID."""
    seed: Optional[int]
    """Random seed used for this run."""
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.now)
    """Timestamp when the run started."""
    _context: "OptimizerContext" = PrivateAttr()

    def bind(self, context: "OptimizerContext") -> "OptimizerRun":
        """Bind this run to an optimizer context.

        Args:
            * `context`: `OptimizerContext` to bind to.

        Returns:
            * Self for method chaining.
        """
        self._context = context
        return self

    @property
    def context(self) -> "OptimizerContext":
        """Get the optimizer context.

        Returns:
            * `OptimizerContext` associated with this run.
        """
        return self._context

    def add_log(self, log: OptimizerLog):
        """Add a log entry to the run and print its message when verbose.

        Args:
            * `log`: `OptimizerLog` to add.
        """
        if self._context.config.verbose:
            print(f"[{self.run_id}]", log.message())
        self.logs[log.id] = log

    def get_log(self, log_id: LogID) -> OptimizerLog:
        """Retrieve a log entry by its ID.

        Args:
            * `log_id`: ID of the log to retrieve.

        Returns:
            * `OptimizerLog` with the specified ID.

        Raises:
            * `KeyError`: If log ID not found.
        """
        if log_id not in self.logs:
            raise KeyError(f"Log with id {log_id} not found")
        return self.logs[log_id]

    def get_logs(self, log_type: Type[TLog]) -> List[TLog]:
        """Get all logs of a specific type.

        Args:
            * `log_type`: Type of logs to retrieve.

        Returns:
            * List of logs of the specified type.
        """
        return [log for log in self.logs.values() if isinstance(log, log_type)]

    def get_last_log(self, log_type: Type[TLog]) -> Optional[TLog]:
        """Get the most recent log of a specific type, or None if not found.

        Args:
            * `log_type`: Type of log to retrieve.

        Returns:
            * Most recent log of the specified type, or `None` if not found.
        """
        # Dict preserves insertion order, so reversed() walks newest-first.
        for log in reversed(self.logs.values()):
            if isinstance(log, log_type):
                return log
        return None

    def print_stats(self):
        """Print statistics about this optimization run.

        Displays run ID, seed, number of steps, elapsed time, token usage,
        and a timeline of all log events.
        """
        print(f"Optimizer Run [{self.run_id}], seed [{self.seed}]")
        log_list = list(self.logs.values())
        if not log_list:
            # Guard against an empty run: no last log to compute elapsed time from.
            print("Steps: 0")
            return
        last_log = log_list[-1]
        print(f"Steps: {sum(1 for log in log_list if log.__is_step__)}")
        print(f"Time: {(last_log.timestamp - self.start_time).total_seconds():.1f}s")
        print("Input tokens:", sum(log.input_tokens for log in self.get_logs(LLMCallOptimizerLog)))
        print("Output tokens:", sum(log.output_tokens for log in self.get_logs(LLMCallOptimizerLog)))
        start_time = self.start_time
        for log in log_list:
            elapsed = (log.timestamp - start_time).total_seconds()
            start_time = log.timestamp
            print(f"\t[{elapsed:.1f}s] {log.full_message()}")


def get_seeded_nth_int(seed: int, n: int) -> int:
    """Return the n-th 32-bit integer drawn from a PRNG seeded with `seed`.

    Deterministic for a given (seed, n) pair. Falls back to `seed` itself
    when `n <= 0`, since the stream is never advanced.
    """
    rng = random.Random(seed)
    value: Optional[int] = None
    for _ in range(n):
        value = rng.getrandbits(32)
    # Explicit None check: 0 is a valid random draw and must not fall back to seed.
    return seed if value is None else value


_run_lock = Lock()


class OptimizerContext(BaseModel):
    """Holds the state, parameters, and logs for an optimization run.

    Manages configuration, parameters, and multiple optimization runs.
    Can be locked to prevent parameter changes during optimization.
    """

    config: OptimizerConfig
    """Optimizer configuration."""
    params: Dict[str, Any]
    """Dictionary of optimization parameters."""
    runs: List[OptimizerRun]
    """List of completed optimization runs."""
    locked: bool = False
    """Whether the context is locked (prevents parameter changes)."""

    # TODO: implement checkpointing (load(path) / save(path)).

    async def new_run(self) -> OptimizerRun:
        """Create a new optimization run.

        Generates a new run with a unique ID and a per-run seed derived
        from the configured seed (if any).

        Returns:
            * New `OptimizerRun` bound to this context.
        """
        # Serialize run creation so concurrent callers get distinct run IDs/seeds.
        async with _run_lock:
            seed = None if self.config.seed is None else get_seeded_nth_int(self.config.seed, len(self.runs))
            run = OptimizerRun(run_id=len(self.runs), seed=seed).bind(self)
            self.runs.append(run)
            return run

    @property
    def llm_wrapper(self) -> LLMWrapper:
        """Get the LLM wrapper for this context.

        Returns:
            * `LLMWrapper` configured with the context's provider and model.
        """
        return get_llm_wrapper(self.config.provider, self.config.model, self.params[Params.Options])

    @property
    def options(self) -> Options:
        """Get the processing options.

        Returns:
            * `Options` from the context parameters.
        """
        return self.params[Params.Options]

    def set_param(self, name: str, value: Any):
        """Set a parameter in the context. Raises if context is locked.

        Args:
            * `name`: Parameter name.
            * `value`: Parameter value.

        Raises:
            * `OptimizationConfigurationError`: If context is locked.
        """
        if self.locked:
            raise OptimizationConfigurationError("OptimizerContext is locked")
        self.params[name] = value
        if isinstance(value, InitContextMixin):
            value.on_param_set(self)

    def get_param(self, name: str, cls: Optional[Type[T]] = None, missing_error_message: Optional[str] = None) -> T:
        """Get a parameter, optionally checking type and raising with a custom message if missing.

        Args:
            * `name`: Parameter name.
            * `cls`: Optional type to validate against.
            * `missing_error_message`: Optional custom error message if parameter is missing.

        Returns:
            * Parameter value.

        Raises:
            * `OptimizationConfigurationError`: If context is not locked, parameter is missing, or type mismatch.
        """
        if not self.locked:
            raise OptimizationConfigurationError("Attempted to get param from unlocked OptimizerContext")
        value = self.params.get(name, None)
        if value is None and missing_error_message is not None:
            raise OptimizationConfigurationError(missing_error_message)
        if cls is not None and not isinstance(value, cls):
            raise OptimizationConfigurationError(f"Expected {cls.__name__}, got {type(value).__name__}")
        return value

    def lock(self):
        """Lock the context to prevent further parameter changes.

        After locking, parameters can only be read, not modified.
        """
        self.locked = True
        for value in self.params.values():
            if isinstance(value, InitContextMixin):
                value.on_context_lock(self)

    def has_param(self, name: str) -> bool:
        """Check if a parameter exists.

        Args:
            * `name`: Parameter name to check.

        Returns:
            * `True` if parameter exists, `False` otherwise.
        """
        return name in self.params


class InitContextMixin(ABC):
    """Mixin for objects that need to alter OptimizerContext on init.

    Provides hooks for objects to react to context parameter changes
    and context locking.
    """

    def on_param_set(self, context: OptimizerContext):
        """Called when this object is set as a context parameter.

        Args:
            * `context`: `OptimizerContext` that this parameter was set in.
        """
        pass

    def on_context_lock(self, context: OptimizerContext):
        """Called when the context is locked.

        Args:
            * `context`: `OptimizerContext` that was locked.
        """
        pass


TOptimizerConfig = TypeVar("TOptimizerConfig", bound=OptimizerConfig)


class BaseOptimizer(ABC, Generic[TOptimizerConfig]):
    """Base class for all optimizers, handling context and parameter management.

    Provides common functionality for managing optimizer configuration,
    parameters, and runs.
    """

    def __init__(self, name: str, config: TOptimizerConfig, checkpoint_path: Optional[str] = None):
        """Initialize the optimizer.

        Args:
            * `name`: Name of the optimizer.
            * `config`: `OptimizerConfig` with optimizer settings.
            * `checkpoint_path`: Optional path for saving/loading checkpoints.
              NOTE: checkpoint loading is not implemented yet; the path is only recorded.
        """
        self.name = name
        self.checkpoint_path = checkpoint_path or f".optimizer_checkpoint_{name}"
        # TODO: restore context from checkpoint_path once OptimizerContext.load/save exist.
        self.context = OptimizerContext(config=config, params={}, runs=[])

    def _lock(self):
        """Lock the context to prevent further parameter changes."""
        self.context.lock()

    def set_param(self, name: str, value: Any):
        """Set a parameter in the optimizer context.

        Args:
            * `name`: Parameter name.
            * `value`: Parameter value.
        """
        self.context.set_param(name, value)

    def get_param(self, name: str, cls: Optional[Type[T]] = None, missing_error_message: Optional[str] = None) -> T:
        """Get a parameter from the optimizer context.

        Args:
            * `name`: Parameter name.
            * `cls`: Optional type to validate against.
            * `missing_error_message`: Optional custom error message if parameter is missing.

        Returns:
            * Parameter value.
        """
        return self.context.get_param(name, cls, missing_error_message)

    def has_param(self, name: str) -> bool:
        """Check if a parameter exists.

        Args:
            * `name`: Parameter name to check.

        Returns:
            * `True` if parameter exists, `False` otherwise.
        """
        return self.context.has_param(name)