/ mlflow / assistant / providers / base.py
base.py
  1  from abc import ABC, abstractmethod
  2  from functools import lru_cache
  3  from pathlib import Path
  4  from typing import Any, AsyncGenerator, Callable
  5  
  6  from mlflow.assistant.config import AssistantConfig, ProviderConfig
  7  
  8  
  9  @lru_cache(maxsize=10)
 10  def load_config(name: str) -> ProviderConfig:
 11      cfg = AssistantConfig.load()
 12      if not cfg or name not in cfg.providers:
 13          raise RuntimeError(f"Provider configuration not found for {name}")
 14      return cfg.providers[name]
 15  
 16  
 17  def clear_config_cache() -> None:
 18      """Clear the config cache to pick up config changes."""
 19      load_config.cache_clear()
 20  
 21  
 22  class ProviderNotConfiguredError(Exception):
 23      """Raised when a provider is not properly configured."""
 24  
 25  
 26  class CLINotInstalledError(ProviderNotConfiguredError):
 27      """Raised when the provider CLI is not installed."""
 28  
 29  
 30  class NotAuthenticatedError(ProviderNotConfiguredError):
 31      """Raised when the user is not authenticated with the provider."""
 32  
 33  
 34  class AssistantProvider(ABC):
 35      """Abstract base class for assistant providers."""
 36  
 37      @property
 38      @abstractmethod
 39      def name(self) -> str:
 40          """Return the provider identifier (e.g., 'claude_code')."""
 41  
 42      @property
 43      @abstractmethod
 44      def display_name(self) -> str:
 45          """Return the human-readable provider name (e.g., 'Claude Code')."""
 46  
 47      @property
 48      @abstractmethod
 49      def description(self) -> str:
 50          """Return a short description of the provider."""
 51  
 52      @abstractmethod
 53      def is_available(self) -> bool:
 54          """Check if the provider is available and ready to use."""
 55  
 56      @abstractmethod
 57      def check_connection(self, echo: Callable[[str], None] | None = None) -> None:
 58          """
 59          Check if the provider is properly configured and can connect.
 60  
 61          Args:
 62              echo: Optional function to print status messages.
 63  
 64          Raises:
 65              ProviderNotConfiguredError: If the provider is not properly configured.
 66          """
 67  
 68      @abstractmethod
 69      def resolve_skills_path(self, base_directory: Path) -> Path:
 70          """Resolve the skills installation path.
 71  
 72          Args:
 73              base_directory: Base directory to resolve skills path from.
 74  
 75          Returns:
 76              Resolved absolute path for skills installation.
 77          """
 78  
 79      @abstractmethod
 80      def astream(
 81          self,
 82          prompt: str,
 83          tracking_uri: str,
 84          session_id: str | None = None,
 85          cwd: Path | None = None,
 86          context: dict[str, Any] | None = None,
 87      ) -> AsyncGenerator[dict[str, Any], None]:
 88          """
 89          Stream responses from the assistant asynchronously.
 90  
 91          Args:
 92              prompt: The prompt to send to the assistant
 93              tracking_uri: MLflow tracking server URI for the assistant to use
 94              session_id: Session ID for conversation continuity
 95              cwd: Working directory for the assistant
 96              context: Additional context for the assistant, such as information from
 97                  the current UI page the user is viewing (e.g., experimentId, traceId)
 98  
 99          Yields:
100              Event dictionaries with 'type' and 'data' keys.
101              Event types: 'message', 'status', 'done', 'error'
102          """