# src/evidently/llm/optimization/optimizer.py
  1  import datetime
  2  import random
  3  import uuid
  4  from abc import ABC
  5  from abc import abstractmethod
  6  from asyncio import Lock
  7  from typing import Any
  8  from typing import ClassVar
  9  from typing import Dict
 10  from typing import Generic
 11  from typing import Iterable
 12  from typing import List
 13  from typing import Optional
 14  from typing import Tuple
 15  from typing import Type
 16  from typing import TypeVar
 17  
 18  import numpy as np
 19  import pandas as pd
 20  from sklearn.model_selection import train_test_split
 21  
 22  from evidently._pydantic_compat import BaseModel
 23  from evidently._pydantic_compat import Field
 24  from evidently._pydantic_compat import PrivateAttr
 25  from evidently.legacy.core import new_id
 26  from evidently.legacy.options.base import Options
 27  from evidently.legacy.utils.llm.wrapper import LLMWrapper
 28  from evidently.legacy.utils.llm.wrapper import get_llm_wrapper
 29  from evidently.llm.optimization.errors import OptimizationConfigurationError
 30  from evidently.pydantic_utils import AutoAliasMixin
 31  from evidently.pydantic_utils import EvidentlyBaseModel
 32  
 33  
class Params:
    """Parameter keys used throughout the optimizer context and configuration.

    Values are plain strings used as keys into ``OptimizerContext.params``
    (see ``OptimizerContext.set_param`` / ``OptimizerContext.get_param``).
    """

    BasePrompt = "base_prompt"
    EarlyStop = "early_stop"
    LLMClassification = "llm_classification"
    Options = "options"
    Scorer = "scorer"
    Executor = "executor"
    Dataset = "dataset"
    TargetValue = "target_value"
    Task = "task"
    OptimizerPromptInstructions = "optimizer_prompt_instructions"
    DataSplitShares = "data_split_shares"
 48  
 49  
class LLMDatasetSplit:
    """Names of dataset splits, used as keys into ``LLMDataset.split_masks``."""

    Train = "train"
    Val = "val"
    Test = "test"
    # Special name meaning "no filtering": split views over this name return full columns.
    All = "all"
 55  
 56  
class LLMDatasetColumns:
    """Canonical column names for LLM optimization datasets and result datasets."""

    InputValues = "input_values"
    Target = "target"
    Reasoning = "reasoning"
    Scores = "scores"
    Pred = "preds"
    PredReasoning = "preds_reasoning"
    PredScores = "preds_scores"
 65  
 66  
 67  class LLMDatasetSplitView:
 68      """View of a specific split in an LLM dataset.
 69  
 70      Provides filtered access to dataset columns for a specific split
 71      (train, val, test, or all).
 72      """
 73  
 74      def __init__(self, dataset: "LLMDataset", split_name: str):
 75          """Initialize a split view.
 76  
 77          Args:
 78          * `dataset`: `LLMDataset` to view.
 79          * `split_name`: Name of the split to view.
 80          """
 81          self.split_name = split_name
 82          self.dataset = dataset
 83  
 84      @property
 85      def input_values(self) -> Optional[pd.Series]:
 86          """Get input values for this split.
 87  
 88          Returns:
 89          * `pd.Series` with input values, or `None` if not available.
 90          """
 91          if (
 92              self.split_name == LLMDatasetSplit.All
 93              or self.split_name not in self.dataset.split_masks
 94              or self.dataset.input_values is None
 95          ):
 96              return self.dataset.input_values
 97          return self.dataset.input_values[self.dataset.split_masks[self.split_name]]
 98  
 99      @property
100      def target(self) -> Optional[pd.Series]:
101          """Get target values for this split.
102  
103          Returns:
104          * `pd.Series` with target values, or `None` if not available.
105          """
106          if (
107              self.split_name == LLMDatasetSplit.All
108              or self.split_name not in self.dataset.split_masks
109              or self.dataset.target is None
110          ):
111              return self.dataset.target
112          return self.dataset.target[self.dataset.split_masks[self.split_name]]
113  
114      @property
115      def reasoning(self) -> Optional[pd.Series]:
116          """Get reasoning for this split.
117  
118          Returns:
119          * `pd.Series` with reasoning, or `None` if not available.
120          """
121          if (
122              self.split_name == LLMDatasetSplit.All
123              or self.split_name not in self.dataset.split_masks
124              or self.dataset.reasoning is None
125          ):
126              return self.dataset.reasoning
127          return self.dataset.reasoning[self.dataset.split_masks[self.split_name]]
128  
129      @property
130      def predictions(self) -> Optional[pd.Series]:
131          """Get predictions for this split.
132  
133          Returns:
134          * `pd.Series` with predictions, or `None` if not available.
135          """
136          if (
137              self.split_name == LLMDatasetSplit.All
138              or self.split_name not in self.dataset.split_masks
139              or self.dataset.predictions is None
140          ):
141              return self.dataset.predictions
142          return self.dataset.predictions[self.dataset.split_masks[self.split_name]]
143  
144      @property
145      def prediction_reasoning(self) -> Optional[pd.Series]:
146          """Get prediction reasoning for this split.
147  
148          Returns:
149          * `pd.Series` with prediction reasoning, or `None` if not available.
150          """
151          if (
152              self.split_name == LLMDatasetSplit.All
153              or self.split_name not in self.dataset.split_masks
154              or self.dataset.prediction_reasoning is None
155          ):
156              return self.dataset.prediction_reasoning
157          return self.dataset.prediction_reasoning[self.dataset.split_masks[self.split_name]]
158  
159  
class LLMDataset(BaseModel):
    """Dataset for LLM optimization tasks.

    Contains input values, targets, reasoning, predictions, and split masks
    for train/val/test splits.
    """

    input_values: pd.Series
    """Input values (e.g., prompts)."""
    target: Optional[pd.Series] = None
    """Optional target values (e.g., expected outputs)."""
    reasoning: Optional[pd.Series] = None
    """Optional reasoning for targets."""
    predictions: Optional[pd.Series] = None
    """Optional model predictions."""
    prediction_reasoning: Optional[pd.Series] = None
    """Optional reasoning for predictions."""
    split_masks: Dict[str, pd.Series] = Field(default_factory=dict)
    """Dictionary mapping split names to boolean masks."""

    def __getitem__(self, split_name: str) -> LLMDatasetSplitView:
        """Get a view of a specific split.

        Args:
        * `split_name`: Name of the split to view.

        Returns:
        * `LLMDatasetSplitView` for the specified split.
        """
        return LLMDatasetSplitView(self, split_name)

    def get_mask(self, split_name: str) -> pd.Series:
        """Get the boolean mask for a split.

        Args:
        * `split_name`: Name of the split.

        Returns:
        * `pd.Series` with boolean values indicating which rows belong to this split.
          An all-`True` mask is returned for the `all` split or unknown split names.
        """
        if split_name not in self.split_masks or split_name == LLMDatasetSplit.All:
            return pd.Series(True, index=np.arange(len(self.input_values)))
        return self.split_masks[split_name]

    def split(self, shares: Dict[str, Optional[float]], seed: Optional[int]) -> None:
        """Split the dataset into train/val/test splits.

        Creates boolean masks for each split based on the specified shares.
        Uses stratified splitting if target values are available.

        Args:
        * `shares`: Dictionary mapping split names to proportions (must sum to 1.0).
          Entries with a `None` share are skipped.
        * `seed`: Optional random seed for reproducibility.

        Raises:
        * `ValueError`: If shares don't sum to 1.0.
        """
        n = len(self.input_values)
        indices = np.arange(n)

        # normalize shares (ignoring None)
        specified = {k: v for k, v in shares.items() if v is not None}
        total = sum(specified.values())
        if specified and not np.isclose(total, 1.0):
            raise ValueError("Specified shares must sum to 1.0")

        remaining_indices: np.ndarray[Any, Any] = indices
        remaining_shares = dict(specified)

        self.split_masks = {}

        # Enumerate over `specified` (not `shares`) so that `None` entries do not
        # shift the positional index: comparing an index taken over all of
        # `shares` against len(specified) - 1 could misidentify the "last" split
        # and produce overlapping masks.
        for i, (name, share) in enumerate(specified.items()):
            if i == len(specified) - 1:
                # Last specified split absorbs all remaining rows, avoiding
                # rounding loss from repeated train_test_split calls.
                split_indices = remaining_indices
            else:
                test_size = share / sum(remaining_shares.values())
                _, split_indices = train_test_split(
                    remaining_indices,
                    test_size=test_size,
                    stratify=self.target[remaining_indices] if self.target is not None else None,  # type: ignore[index]
                    random_state=seed,
                )
                remaining_indices = np.setdiff1d(remaining_indices, split_indices)
                remaining_shares.pop(name)

            mask = pd.Series(False, index=np.arange(n))
            mask.iloc[split_indices] = True
            self.split_masks[name] = mask
251  
252  
class LLMResultDataset(BaseModel):
    """Dataset containing LLM optimization results.

    Stores predictions, reasoning, and scores from optimization runs.
    """

    predictions: Optional[pd.Series] = None
    """Optional model predictions."""
    reasoning: Optional[pd.Series] = None
    """Optional reasoning for predictions."""
    scores: Optional[pd.Series] = None
    """Optional scores for predictions."""

    def _require(self, field_name: str, mask: Optional[pd.Series]) -> pd.Series:
        """Return the named series, optionally filtered by `mask`.

        Args:
        * `field_name`: Name of the series field to fetch.
        * `mask`: Optional boolean mask to filter the series.

        Returns:
        * `pd.Series` with the requested values.

        Raises:
        * `KeyError`: If the field is not set.
        """
        value: Optional[pd.Series] = getattr(self, field_name)
        if value is None:
            # Error message intentionally matches the field name ("predictions",
            # "reasoning", "scores") to preserve the original messages.
            raise KeyError(f"Dataset has no {field_name}")
        if mask is not None:
            return value[mask]
        return value

    @property
    def has_predictions(self) -> bool:
        """Check if predictions are available.

        Returns:
        * `True` if predictions exist, `False` otherwise.
        """
        return self.predictions is not None

    @property
    def has_reasoning(self) -> bool:
        """Check if reasoning is available.

        Returns:
        * `True` if reasoning exists, `False` otherwise.
        """
        return self.reasoning is not None

    @property
    def has_scores(self) -> bool:
        """Check if scores are available.

        Returns:
        * `True` if scores exist, `False` otherwise.
        """
        return self.scores is not None

    def get_predictions(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get predictions, optionally filtered by mask.

        Args:
        * `mask`: Optional boolean mask to filter predictions.

        Returns:
        * `pd.Series` with predictions.

        Raises:
        * `KeyError`: If predictions are not available.
        """
        return self._require("predictions", mask)

    def get_reasoning(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get reasoning, optionally filtered by mask.

        Args:
        * `mask`: Optional boolean mask to filter reasoning.

        Returns:
        * `pd.Series` with reasoning.

        Raises:
        * `KeyError`: If reasoning is not available.
        """
        return self._require("reasoning", mask)

    def get_scores(self, mask: Optional[pd.Series] = None) -> pd.Series:
        """Get scores, optionally filtered by mask.

        Args:
        * `mask`: Optional boolean mask to filter scores.

        Returns:
        * `pd.Series` with scores.

        Raises:
        * `KeyError`: If scores are not available.
        """
        return self._require("scores", mask)

    def items(self) -> Iterable[Tuple[str, pd.Series]]:
        """Iterate over all Series fields.

        Yields:
        * Tuples of (field_name, pd.Series) for each Series field.
        """
        for field_name in self.__fields__:
            value = getattr(self, field_name)
            if isinstance(value, pd.Series):
                yield field_name, value
357  
358  
class OptimizerConfig(AutoAliasMixin, EvidentlyBaseModel):
    """Configuration for the optimizer, including provider and model."""

    __alias_type__: ClassVar = "optimizer_config"

    class Config:
        # NOTE(review): presumably marks this model as a polymorphic base type
        # for evidently's (de)serialization machinery — confirm against
        # evidently.pydantic_utils.
        is_base_type = True

    provider: str = "openai"
    """LLM provider name."""
    model: str = "gpt-4o-mini"
    """LLM model name."""
    verbose: bool = False
    """Whether to print optimization progress."""
    seed: Optional[int] = None
    """Optional random seed for reproducibility."""
375  
376  
# Type alias: logs are identified by UUIDs (see OptimizerLog.id).
LogID = uuid.UUID
# Generic type variable used for typed parameter access (OptimizerContext.get_param).
T = TypeVar("T")
379  
380  
class OptimizerLog(AutoAliasMixin, EvidentlyBaseModel, ABC):
    """Base class for all optimizer logs.

    Logs track events and steps during optimization runs.
    """

    __alias_type__: ClassVar = "optimizer_log"
    # Subclasses representing optimization "steps" override this with True;
    # OptimizerRun.print_stats counts logs whose __is_step__ is True.
    __is_step__: ClassVar[bool] = False

    class Config:
        is_base_type = True

    id: LogID = Field(default_factory=new_id)
    """Unique log identifier."""
    timestamp: datetime.datetime = Field(default_factory=datetime.datetime.now)
    """Timestamp when the log was created."""

    @abstractmethod
    def message(self) -> str:
        """Get a human-readable message for this log.

        Returns:
        * String message describing the log event.
        """
        raise NotImplementedError()

    def full_message(self) -> str:
        """Get the full message for this log.

        Returns:
        * String message (may be overridden by subclasses for additional context).
        """
        return self.message()
414  
415  
class LLMCallOptimizerLog(OptimizerLog, ABC):
    """Log entry for an LLM API call during optimization.

    Tracks token usage; OptimizerRun.print_stats sums these counters
    across all call logs of a run.
    """

    input_tokens: int
    """Number of input tokens used."""
    output_tokens: int
    """Number of output tokens used."""
426  
427  
# Type variable bound to OptimizerLog, for typed log retrieval (get_logs/get_last_log).
TLog = TypeVar("TLog", bound=OptimizerLog)
# Mapping of log IDs to log entries.
LogsDict = Dict[LogID, OptimizerLog]
# Runs are identified by their position in OptimizerContext.runs.
RunID = int
431  
432  
class OptimizerRun(BaseModel):
    """A single optimization run with logs and statistics.

    Tracks all events, steps, and LLM calls during an optimization run.
    """

    run_id: RunID
    """Unique identifier for this run."""
    logs: LogsDict = Field(default_factory=dict)
    """Dictionary of log entries by log ID."""
    seed: Optional[int]
    """Random seed used for this run."""
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.now)
    """Timestamp when the run started."""
    _context: "OptimizerContext" = PrivateAttr()

    def bind(self, context: "OptimizerContext") -> "OptimizerRun":
        """Bind this run to an optimizer context.

        Args:
        * `context`: `OptimizerContext` to bind to.

        Returns:
        * Self for method chaining.
        """
        self._context = context
        return self

    @property
    def context(self) -> "OptimizerContext":
        """Get the optimizer context.

        Returns:
        * `OptimizerContext` associated with this run.
        """
        return self._context

    def add_log(self, log: OptimizerLog):
        """Add a log entry to the run and print its message when verbose.

        Args:
        * `log`: `OptimizerLog` to add.
        """
        if self._context.config.verbose:
            print(f"[{self.run_id}]", log.message())
        self.logs[log.id] = log

    def get_log(self, log_id: LogID) -> OptimizerLog:
        """Retrieve a log entry by its ID.

        Args:
        * `log_id`: ID of the log to retrieve.

        Returns:
        * `OptimizerLog` with the specified ID.

        Raises:
        * `KeyError`: If log ID not found.
        """
        if log_id not in self.logs:
            raise KeyError(f"Log with id {log_id} not found")
        return self.logs[log_id]

    def get_logs(self, log_type: Type[TLog]) -> List[TLog]:
        """Get all logs of a specific type.

        Args:
        * `log_type`: Type of logs to retrieve.

        Returns:
        * List of logs of the specified type.
        """
        return [log for log in self.logs.values() if isinstance(log, log_type)]

    def get_last_log(self, log_type: Type[TLog]) -> Optional[TLog]:
        """Get the most recent log of a specific type, or None if not found.

        Args:
        * `log_type`: Type of log to retrieve.

        Returns:
        * Most recent log of the specified type, or `None` if not found.
        """
        # dicts preserve insertion order, so the reversed view walks newest-first
        for log in reversed(self.logs.values()):
            if isinstance(log, log_type):
                return log
        return None

    def print_stats(self):
        """Print statistics about this optimization run.

        Displays run ID, seed, number of steps, elapsed time, token usage,
        and a timeline of all log events. Safe to call on a run with no logs.
        """
        print(f"Optimizer Run [{self.run_id}], seed [{self.seed}]")
        log_list = list(self.logs.values())
        # Guard against an empty run: fall back to start_time so we print 0.0s
        # instead of raising IndexError on log_list[-1].
        last_timestamp = log_list[-1].timestamp if log_list else self.start_time
        print(f"Steps: {sum(1 for log in log_list if log.__is_step__)}")
        print(f"Time: {(last_timestamp - self.start_time).total_seconds():.1f}s")
        llm_calls = self.get_logs(LLMCallOptimizerLog)
        print("Input tokens:", sum(log.input_tokens for log in llm_calls))
        print("Output tokens:", sum(log.output_tokens for log in llm_calls))
        start_time = self.start_time
        for log in log_list:
            elapsed = (log.timestamp - start_time).total_seconds()
            start_time = log.timestamp
            print(f"\t[{elapsed:.1f}s] {log.full_message()}")
539  
540  
def get_seeded_nth_int(seed: int, n: int) -> int:
    """Return the n-th 32-bit integer of a PRNG stream seeded with `seed`.

    Used to derive deterministic per-run seeds from a single base seed.

    Args:
    * `seed`: Base random seed.
    * `n`: 1-based position in the generator stream; when `n <= 0` no value is
      drawn and `seed` itself is returned.

    Returns:
    * The n-th 32-bit random integer, or `seed` when `n <= 0`.
    """
    rng = random.Random(seed)
    value: Optional[int] = None
    for _ in range(n):
        value = rng.getrandbits(32)
    # Explicit None check: `value or seed` would wrongly discard a drawn value
    # of 0 (a valid, if rare, getrandbits result) and return the seed instead.
    return seed if value is None else value
547  
548  
# Module-level asyncio lock serializing OptimizerContext.new_run so concurrent
# coroutines receive distinct run IDs and seeds.
# NOTE(review): created at import time — on Python < 3.10 asyncio.Lock() may
# bind to the event loop current at creation; confirm target runtime versions.
_run_lock = Lock()
550  
551  
class OptimizerContext(BaseModel):
    """Holds the state, parameters, and logs for an optimization run.

    Manages configuration, parameters, and multiple optimization runs.
    Can be locked to prevent parameter changes during optimization.
    """

    config: OptimizerConfig
    """Optimizer configuration."""
    params: Dict[str, Any]
    """Dictionary of optimization parameters."""
    runs: List[OptimizerRun]
    """List of completed optimization runs."""
    locked: bool = False
    """Whether the context is locked (prevents parameter changes)."""

    async def new_run(self) -> OptimizerRun:
        """Create a new optimization run.

        Generates a new run with a unique ID and seed (if configured).
        Guarded by the module-level lock so concurrent callers get
        distinct run IDs and seeds.

        Returns:
        * New `OptimizerRun` bound to this context.
        """
        async with _run_lock:
            seed = None if self.config.seed is None else get_seeded_nth_int(self.config.seed, len(self.runs))
            run = OptimizerRun(run_id=len(self.runs), seed=seed).bind(self)
            self.runs.append(run)
            return run

    @property
    def llm_wrapper(self) -> LLMWrapper:
        """Get the LLM wrapper for this context.

        Returns:
        * `LLMWrapper` configured with the context's provider and model.
        """
        return get_llm_wrapper(self.config.provider, self.config.model, self.params[Params.Options])

    @property
    def options(self) -> Options:
        """Get the processing options.

        Returns:
        * `Options` from the context parameters.
        """
        return self.params[Params.Options]

    def set_param(self, name: str, value: Any):
        """Set a parameter in the context. Raises if context is locked.

        Args:
        * `name`: Parameter name.
        * `value`: Parameter value.

        Raises:
        * `OptimizationConfigurationError`: If context is locked.
        """
        if self.locked:
            raise OptimizationConfigurationError("OptimizerContext is locked")
        self.params[name] = value
        # Let context-aware values react to being registered (see InitContextMixin).
        if isinstance(value, InitContextMixin):
            value.on_param_set(self)

    def get_param(self, name: str, cls: Optional[Type[T]] = None, missing_error_message: Optional[str] = None) -> T:
        """Get a parameter, optionally checking type and raising with a custom message if missing.

        Args:
        * `name`: Parameter name.
        * `cls`: Optional type to validate against.
        * `missing_error_message`: Optional custom error message if parameter is missing.

        Returns:
        * Parameter value.

        Raises:
        * `OptimizationConfigurationError`: If context is not locked, parameter is missing, or type mismatch.
        """
        # Reading params before lock() is treated as a programming error:
        # parameters may still change until the context is locked.
        if not self.locked:
            raise OptimizationConfigurationError("Attempted to get param from unlocked OptimizerContext")
        value = self.params.get(name, None)
        if value is None and missing_error_message is not None:
            raise OptimizationConfigurationError(missing_error_message)
        if cls is not None and not isinstance(value, cls):
            raise OptimizationConfigurationError(f"Expected {cls.__name__}, got {type(value).__name__}")
        return value

    def lock(self):
        """Lock the context to prevent further parameter changes.

        After locking, parameters can only be read, not modified.
        Notifies context-aware parameter values via `on_context_lock`.
        """
        self.locked = True
        for value in self.params.values():
            if isinstance(value, InitContextMixin):
                value.on_context_lock(self)

    def has_param(self, name: str) -> bool:
        """Check if a parameter exists.

        Args:
        * `name`: Parameter name to check.

        Returns:
        * `True` if parameter exists, `False` otherwise.
        """
        return name in self.params
666  
667  
class InitContextMixin(ABC):
    """Mixin for objects that need to alter OptimizerContext on init.

    Provides hooks for objects to react to context parameter changes
    (via ``OptimizerContext.set_param``) and context locking
    (via ``OptimizerContext.lock``). Both hooks are no-ops by default.
    """

    def on_param_set(self, context: OptimizerContext):
        """Called when this object is set as a context parameter.

        Args:
        * `context`: `OptimizerContext` that this parameter was set in.
        """
        pass

    def on_context_lock(self, context: OptimizerContext):
        """Called when the context is locked.

        Args:
        * `context`: `OptimizerContext` that was locked.
        """
        pass
690  
691  
# Type variable letting BaseOptimizer subclasses declare the concrete
# OptimizerConfig subclass they accept.
TOptimizerConfig = TypeVar("TOptimizerConfig", bound=OptimizerConfig)
693  
694  
class BaseOptimizer(ABC, Generic[TOptimizerConfig]):
    """Base class for all optimizers, handling context and parameter management.

    Provides common functionality for managing optimizer configuration,
    parameters, and runs.
    """

    def __init__(self, name: str, config: TOptimizerConfig, checkpoint_path: Optional[str] = None):
        """Initialize the optimizer.

        Args:
        * `name`: Name of the optimizer.
        * `config`: `OptimizerConfig` with optimizer settings.
        * `checkpoint_path`: Optional path for saving/loading checkpoints.
          Currently only stored; checkpoint loading/saving is not implemented.
        """
        self.name = name
        self.checkpoint_path = checkpoint_path or f".optimizer_checkpoint_{name}"
        self.context = OptimizerContext(config=config, params={}, runs=[])

    def _lock(self):
        """Lock the context to prevent further parameter changes."""
        self.context.lock()

    def set_param(self, name: str, value: Any):
        """Set a parameter in the optimizer context.

        Args:
        * `name`: Parameter name.
        * `value`: Parameter value.
        """
        self.context.set_param(name, value)

    def get_param(self, name: str, cls: Optional[Type[T]] = None, missing_error_message: Optional[str] = None) -> T:
        """Get a parameter from the optimizer context.

        Args:
        * `name`: Parameter name.
        * `cls`: Optional type to validate against.
        * `missing_error_message`: Optional custom error message if parameter is missing.

        Returns:
        * Parameter value.
        """
        return self.context.get_param(name, cls, missing_error_message)

    def has_param(self, name: str) -> bool:
        """Check if a parameter exists.

        Args:
        * `name`: Parameter name to check.

        Returns:
        * `True` if parameter exists, `False` otherwise.
        """
        return self.context.has_param(name)