base.py
from abc import ABC, abstractmethod
from functools import lru_cache
from pathlib import Path
from typing import Any, AsyncGenerator, Callable

from mlflow.assistant.config import AssistantConfig, ProviderConfig


@lru_cache(maxsize=10)
def load_config(name: str) -> ProviderConfig:
    """Fetch the configuration registered for a single provider.

    Results are memoized per ``name`` (up to 10 distinct names). Every call
    with the same ``name`` therefore returns the same ``ProviderConfig``
    object until ``clear_config_cache`` is called, so callers should not
    mutate it. A lookup that raises is not stored by ``lru_cache``, so a
    missing provider is re-checked on the next call.

    Args:
        name: Key of the provider inside ``AssistantConfig.providers``.

    Returns:
        The ``ProviderConfig`` stored under ``name``.

    Raises:
        RuntimeError: If ``AssistantConfig.load()`` returns a falsy value, or
            the loaded config has no provider entry named ``name``.
    """
    assistant_config = AssistantConfig.load()
    # Short-circuits: ``.providers`` is only touched when a config was loaded.
    has_entry = bool(assistant_config) and name in assistant_config.providers
    if not has_entry:
        raise RuntimeError(f"Provider configuration not found for {name}")
    return assistant_config.providers[name]


def clear_config_cache() -> None:
    """Forget all memoized ``load_config`` results so config edits are re-read."""
    load_config.cache_clear()


class ProviderNotConfiguredError(Exception):
    """Base error signalling that a provider cannot be used as configured."""


class CLINotInstalledError(ProviderNotConfiguredError):
    """The command-line tool the provider depends on is not installed."""


class NotAuthenticatedError(ProviderNotConfiguredError):
    """The user has not authenticated with the provider."""


class AssistantProvider(ABC):
    """Interface every assistant provider backend must implement."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Machine-readable identifier of the provider, e.g. ``'claude_code'``."""
        ...

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-friendly provider name, e.g. ``'Claude Code'``."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """One-line summary describing the provider."""
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the provider is ready to be used right now."""
        ...

    @abstractmethod
    def check_connection(self, echo: Callable[[str], None] | None = None) -> None:
        """Verify that the provider is configured and reachable.

        Args:
            echo: Optional callback that receives human-readable status lines.

        Raises:
            ProviderNotConfiguredError: When the provider is not set up
                correctly (subclasses such as ``CLINotInstalledError`` or
                ``NotAuthenticatedError`` may be raised for specific causes).
        """
        ...

    @abstractmethod
    def resolve_skills_path(self, base_directory: Path) -> Path:
        """Work out where skills should be installed.

        Args:
            base_directory: Directory the skills path is resolved against.

        Returns:
            Absolute path at which skills are to be installed.
        """
        ...

    @abstractmethod
    def astream(
        self,
        prompt: str,
        tracking_uri: str,
        session_id: str | None = None,
        cwd: Path | None = None,
        context: dict[str, Any] | None = None,
    ) -> AsyncGenerator[dict[str, Any], None]:
        """Send ``prompt`` to the assistant and stream back its events.

        Args:
            prompt: Text prompt forwarded to the assistant.
            tracking_uri: URI of the MLflow tracking server the assistant
                should talk to.
            session_id: Identifier used to continue an existing conversation.
            cwd: Working directory the assistant runs in.
            context: Extra information for the assistant, for instance details
                about the UI page currently open (e.g. experimentId, traceId).

        Yields:
            Dictionaries carrying a 'type' key and a 'data' key, where 'type'
            is one of 'message', 'status', 'done' or 'error'.
        """
        ...