/ src / api / models.py
models.py
  1  """
  2  Pydantic request/response models for Ag3ntum API.
  3  
  4  Defines the API contract for all endpoints.
  5  All parameters from CLI are available via HTTP request.
  6  """
  7  from datetime import datetime, timezone
  8  from typing import Literal, Optional
  9  import re
 10  
 11  from pydantic import BaseModel, Field, field_validator
 12  
 13  from .waf_filter import truncate_text_content
 14  
 15  
 16  class TokenResponse(BaseModel):
 17      """Response from POST /auth/login."""
 18      access_token: str = Field(description="JWT access token")
 19      token_type: str = Field(default="bearer", description="Token type")
 20      user_id: str = Field(description="User ID associated with the token")
 21      expires_in: int = Field(description="Token expiry in seconds")
 22  
 23  
 24  class UserResponse(BaseModel):
 25      """Response from GET /auth/me."""
 26      id: str = Field(description="User ID")
 27      username: str = Field(description="Username")
 28      email: str = Field(description="Email address")
 29      role: str = Field(description="User role (admin/user)")
 30      reseller_id: Optional[str] = Field(default=None, description="Reseller ID if user belongs to a reseller")
 31      created_at: datetime = Field(description="Account creation timestamp")
 32  
 33  
 34  class HealthResponse(BaseModel):
 35      """Response from GET /health."""
 36      status: str = Field(default="ok", description="Health status")
 37      version: str = Field(description="API version")
 38      timestamp: datetime = Field(
 39          default_factory=lambda: datetime.now(timezone.utc),
 40          description="Current server time"
 41      )
 42  
 43  
 44  class ComponentHealth(BaseModel):
 45      """Health status for a single component."""
 46      status: str = Field(description="Component status: ok, degraded, or unhealthy")
 47      latency_ms: float | None = Field(default=None, description="Response latency in milliseconds")
 48      error: str | None = Field(default=None, description="Error message if unhealthy")
 49  
 50  
 51  class DeepHealthResponse(BaseModel):
 52      """Response from GET /health/deep with detailed component health."""
 53      status: str = Field(description="Overall status: ok, degraded, or unhealthy")
 54      version: str = Field(description="API version")
 55      timestamp: datetime = Field(
 56          default_factory=lambda: datetime.now(timezone.utc),
 57          description="Current server time"
 58      )
 59      database: ComponentHealth = Field(description="Database health status")
 60      redis: ComponentHealth = Field(description="Redis health status")
 61  
 62  
 63  class ConfigResponse(BaseModel):
 64      """Response from GET /config."""
 65      models_available: list[str] = Field(
 66          description="List of available Claude models"
 67      )
 68      default_model: str = Field(
 69          description="Default model to use"
 70      )
 71      thinking_tokens: Optional[int] = Field(
 72          default=None,
 73          description="Token budget for extended thinking mode"
 74      )
 75  
 76  
 77  # =============================================================================
 78  # Agent Configuration Overrides (matches agent.yaml + CLI args)
 79  # =============================================================================
 80  
 81  class AgentConfigOverrides(BaseModel):
 82      """
 83      Configuration overrides for agent execution.
 84  
 85      All fields are optional - if not provided, values from agent.yaml are used.
 86      These match the CLI arguments: --model, --max-turns, --timeout, etc.
 87      """
 88      # Model override (CLI: --model)
 89      model: Optional[str] = Field(
 90          default=None,
 91          description="Claude model to use (overrides agent.yaml)"
 92      )
 93  
 94      # Execution limits (CLI: --max-turns, --timeout)
 95      max_turns: Optional[int] = Field(
 96          default=None,
 97          description="Maximum conversation turns (overrides agent.yaml)"
 98      )
 99      timeout_seconds: Optional[int] = Field(
100          default=None,
101          description="Execution timeout in seconds (overrides agent.yaml)"
102      )
103  
104      # Feature toggles (CLI: --no-skills maps to enable_skills=false)
105      enable_skills: Optional[bool] = Field(
106          default=None,
107          description="Enable custom skills (overrides agent.yaml)"
108      )
109      enable_file_checkpointing: Optional[bool] = Field(
110          default=None,
111          description="Enable file change tracking (overrides agent.yaml)"
112      )
113  
114      # Permission settings (CLI: --permission-mode, --profile)
115      permission_mode: Optional[str] = Field(
116          default=None,
117          description="Permission mode: default, acceptEdits, plan, bypassPermissions"
118      )
119      profile: Optional[str] = Field(
120          default=None,
121          description="Permission profile file path"
122      )
123  
124      # Role template (from agent.yaml)
125      role: Optional[str] = Field(
126          default=None,
127          description="Role template name (loads prompts/roles/<role>.md)"
128      )
129  
130      # SDK options (from agent.yaml)
131      max_buffer_size: Optional[int] = Field(
132          default=None,
133          description="Maximum buffer size for streaming"
134      )
135      output_format: Optional[str] = Field(
136          default=None,
137          description="Output format: text, json, stream-json"
138      )
139      include_partial_messages: Optional[bool] = Field(
140          default=None,
141          description="Include partial/incomplete messages in output"
142      )
143  
144      # Extended thinking configuration
145      thinking_tokens: Optional[int] = Field(
146          default=None,
147          description="Token budget for extended thinking mode (overrides agent.yaml)"
148      )
149  
150  
151  # =============================================================================
152  # Session Requests
153  # =============================================================================
154  
class RunTaskRequest(BaseModel):
    """
    Request body for POST /sessions/run - unified endpoint to run agent tasks.

    This is the primary endpoint that combines session creation and task execution.
    Supports both new sessions and resuming existing sessions.

    Matches CLI capabilities:
      python agent.py --task "..." --model "..." --max-turns 50
      python agent.py --resume SESSION_ID --task "Continue..."
    """

    # Required task text; passed through the WAF truncation validator below.
    task: str = Field(description="Task description to execute")

    # CLI: --add-dir
    additional_dirs: list[str] = Field(
        default_factory=list,
        description="Additional directories the agent can access",
    )

    # CLI: --resume / --fork-session
    resume_session_id: str | None = Field(
        None, description="Session ID to resume (e.g., 20260103_210631_ecc41d66)"
    )
    fork_session: bool = Field(
        False, description="Fork to new session when resuming instead of continuing"
    )

    # Per-request overrides of agent.yaml / CLI settings.
    config: AgentConfigOverrides = Field(
        default_factory=AgentConfigOverrides,
        description="Agent configuration overrides (model, max_turns, etc.)",
    )

    # String forward reference: DynamicMountRequest is declared further down
    # in this module. At most 10 mounts are accepted per request.
    dynamic_mounts: Optional[list["DynamicMountRequest"]] = Field(
        default=None,
        max_length=10,
        description="Dynamic folder mounts for this session only",
    )

    @field_validator("task")
    @classmethod
    def truncate_task(cls, v: str) -> str:
        """Run the task text through the WAF truncation filter.

        A falsy filter result is normalised to the empty string so the field
        always stays a ``str``.
        """
        filtered = truncate_text_content(v, "task")
        return filtered if filtered else ""
205  
206  
class CreateSessionRequest(BaseModel):
    """
    Request body for POST /sessions - creates a session without starting.

    Use this if you want to create a session first and start later.
    For most cases, use POST /sessions/run instead.
    """

    task: str = Field(description="Task description for the agent")
    # None means "use the configured default model".
    model: str | None = Field(
        None, description="Claude model to use (overrides config)"
    )

    @field_validator("task")
    @classmethod
    def truncate_task(cls, v: str) -> str:
        """Run the task text through the WAF truncation filter (never returns None)."""
        filtered = truncate_text_content(v, "task")
        return filtered if filtered else ""
225  
226  
class StartTaskRequest(BaseModel):
    """
    Request body for POST /sessions/{id}/task - starts task on existing session.

    All fields are optional - uses session's stored values if not provided.
    Matches CLI capabilities for task continuation.
    """

    # None -> fall back to the task stored on the session.
    task: str | None = Field(
        None,
        description="Task to execute (optional, uses session task if not provided)",
    )

    # CLI: --add-dir
    additional_dirs: list[str] = Field(
        default_factory=list,
        description="Additional directories the agent can access",
    )

    # Optional resumption from another session.
    resume_session_id: str | None = Field(
        None, description="Resume from a different session (optional)"
    )
    fork_session: bool = Field(False, description="Fork to new session when resuming")

    # Per-request overrides of agent.yaml / CLI settings.
    config: AgentConfigOverrides = Field(
        default_factory=AgentConfigOverrides,
        description="Agent configuration overrides",
    )

    @field_validator("task")
    @classmethod
    def truncate_task(cls, v: Optional[str]) -> Optional[str]:
        """Run a supplied task through the WAF truncation filter; pass None through."""
        return v if v is None else truncate_text_content(v, "task")
269  
270  
271  # =============================================================================
272  # Session Responses
273  # =============================================================================
274  
class SessionResponse(BaseModel):
    """Response representing a session."""

    # --- Identity and lifecycle ---
    id: str = Field(description="Session ID")
    status: str = Field(description="Session status")
    task: str | None = Field(None, description="Task description")
    model: str | None = Field(None, description="Model used")
    created_at: datetime = Field(description="Creation timestamp")
    updated_at: datetime = Field(description="Last update timestamp")
    completed_at: datetime | None = Field(None, description="Completion timestamp")

    # --- Metrics for the most recent run ---
    num_turns: int = Field(0, description="Number of conversation turns")
    duration_ms: int | None = Field(None, description="Duration in milliseconds")
    total_cost_usd: float | None = Field(None, description="Total cost in USD")

    # --- Control flags ---
    cancel_requested: bool = Field(
        False, description="Whether cancellation was requested"
    )
    resumable: bool | None = Field(
        None,
        description="Whether session can be resumed (has established Claude session)",
    )

    # --- Resumption ---
    claude_session_id: str | None = Field(
        None, description="Claude SDK session ID for resumption"
    )

    # --- Totals accumulated over every resumption ---
    cumulative_turns: int = Field(0, description="Total turns across all resumptions")
    cumulative_duration_ms: int = Field(
        0, description="Total duration across all resumptions in milliseconds"
    )
    cumulative_cost_usd: float = Field(
        0.0, description="Total cost across all resumptions in USD"
    )
    cumulative_input_tokens: int = Field(
        0, description="Total input tokens across all resumptions"
    )
    cumulative_output_tokens: int = Field(
        0, description="Total output tokens across all resumptions"
    )

    # --- Forking ---
    parent_session_id: str | None = Field(
        None, description="Parent session ID if this session was forked"
    )

    # --- Queueing ---
    queue_position: int | None = Field(
        None, description="Position in queue (if status is 'queued')"
    )
    queued_at: datetime | None = Field(None, description="Time when task was queued")
    is_auto_resume: bool = Field(
        False, description="Whether this is an auto-resumed session"
    )
352  
353  
class SessionListResponse(BaseModel):
    """Response for GET /sessions."""

    # Page of sessions (empty list by default) and the overall count.
    sessions: list[SessionResponse] = Field(
        description="List of sessions", default_factory=list
    )
    total: int = Field(description="Total number of sessions")
361  
362  
class TaskStartedResponse(BaseModel):
    """Response from POST /sessions/run or POST /sessions/{id}/task."""

    session_id: str = Field(description="Session ID")
    status: str = Field(description="Session status (running or queued)")
    message: str = Field(description="Status message")
    # Optional context; both default to None.
    resumed_from: str | None = Field(
        None, description="Session ID that was resumed (if applicable)"
    )
    queue_position: int | None = Field(
        None, description="Queue position (if status is 'queued')"
    )
376  
377  
class CancelResponse(BaseModel):
    """Response from POST /sessions/{id}/cancel."""

    # All three fields are required.
    session_id: str = Field(description="Session ID")
    status: str = Field(description="Session status after cancellation")
    message: str = Field(description="Cancellation message")
383  
384  
class QueuedSessionInfo(BaseModel):
    """Information about a queued session."""
    session_id: str = Field(description="Session ID")
    # In Pydantic v2 an ``Optional[...]`` annotation without a default is
    # *required* (it only allows an explicit None). Default to None so these
    # match the identically-named SessionResponse fields and the rest of the
    # Optional fields in this module.
    queue_position: Optional[int] = Field(default=None, description="Position in queue")
    queued_at: Optional[datetime] = Field(default=None, description="Time queued")
    is_auto_resume: bool = Field(default=False, description="Auto-resume session")
391  
392  
class QueueStatusResponse(BaseModel):
    """Response from GET /queue/status."""

    # Global (all users) load.
    global_queue_length: int = Field(description="Total tasks in queue")
    global_active_tasks: int = Field(description="Currently running tasks")
    # The calling user's load; queued tasks default to an empty list.
    user_active_tasks: int = Field(description="User's running tasks")
    user_queued_tasks: list[QueuedSessionInfo] = Field(
        description="User's queued tasks", default_factory=list
    )
    # Configured concurrency limits.
    max_concurrent_global: int = Field(description="Max global concurrent tasks")
    max_concurrent_user: int = Field(description="Max per-user concurrent tasks")
404  
405  
class TokenUsageResponse(BaseModel):
    """Token usage breakdown for a completed task."""

    # Raw counters; every one defaults to 0.
    input_tokens: int = Field(0, description="Input tokens (non-cached)")
    output_tokens: int = Field(0, description="Output tokens generated")
    cache_creation_input_tokens: int = Field(0, description="Tokens used to create cache")
    cache_read_input_tokens: int = Field(0, description="Tokens read from cache")

    # Derived totals. Plain properties, so they are not part of the
    # serialized model output.
    @property
    def total_input(self) -> int:
        """Sum of non-cached, cache-creation and cache-read input tokens."""
        return sum(
            (
                self.input_tokens,
                self.cache_creation_input_tokens,
                self.cache_read_input_tokens,
            )
        )

    @property
    def total(self) -> int:
        """All tokens: ``total_input`` plus ``output_tokens``."""
        return self.output_tokens + self.total_input
428  
429  
class ResultMetrics(BaseModel):
    """Execution metrics for a completed task."""

    # Everything except num_turns is optional and defaults to None.
    duration_ms: int | None = Field(None, description="Duration in milliseconds")
    num_turns: int = Field(0, description="Number of conversation turns")
    total_cost_usd: float | None = Field(None, description="Total cost in USD")
    model: str | None = Field(None, description="Model used for execution")
    usage: TokenUsageResponse | None = Field(
        None, description="Token usage breakdown"
    )
449  
450  
class ResultResponse(BaseModel):
    """Response from GET /sessions/{id}/result (event summary + metrics)."""

    session_id: str = Field(description="Session ID")
    status: str = Field(description="Task status: COMPLETE, PARTIAL, FAILED")
    # Text outputs default to empty strings rather than None.
    error: str = Field("", description="Error message if any")
    comments: str = Field("", description="Additional comments")
    output: str = Field("", description="Task output")
    result_files: list[str] = Field(
        description="Generated file paths", default_factory=list
    )
    metrics: ResultMetrics | None = Field(
        None, description="Execution metrics (duration, turns, cost)"
    )
466  
467  
class ErrorResponse(BaseModel):
    """Standard error response."""

    # Human-readable message plus an optional machine-readable code.
    detail: str = Field(description="Error message")
    code: str | None = Field(None, description="Error code")
472  
473  
474  # =============================================================================
475  # Dynamic Mount Models
476  # =============================================================================
477  
class DynamicMountRequest(BaseModel):
    """Request to mount a dynamic path for this session."""

    # Field constraints (min/max length) run before the "after" validators below.
    base: str = Field(
        ...,
        description="Name of the dynamic base (from config)",
        min_length=1,
        max_length=64,
        examples=["logs", "projects", "user-home"]
    )

    subpath: Optional[str] = Field(
        default=None,
        description="Subdirectory within the base (optional)",
        max_length=512,
        examples=["nginx", "myapp/logs"]
    )

    alias: Optional[str] = Field(
        default=None,
        description="Name for the mount in workspace root. Auto-generated from host path if omitted.",
        max_length=64,
        examples=["app-logs", "my-project"]
    )

    mode: Optional[Literal["ro", "rw"]] = Field(
        default=None,
        description="Access mode (defaults to 'ro')"
    )

    # All pattern checks use re.fullmatch. The previous re.match(r'^...$')
    # form also accepted a single trailing newline, because '$' matches just
    # before a final '\n' (e.g. "logs" + newline passed as a base name).

    @field_validator("base")
    @classmethod
    def validate_base_name(cls, v: str) -> str:
        """Validate base name - alphanumeric, hyphen, underscore only.

        Raises:
            ValueError: if ``v`` contains any other character.
        """
        if not re.fullmatch(r'[a-zA-Z0-9_-]+', v):
            raise ValueError("Must contain only alphanumeric, hyphen, underscore")
        return v

    @field_validator("alias")
    @classmethod
    def validate_alias(cls, v: Optional[str]) -> Optional[str]:
        """Validate alias if provided - alphanumeric, hyphen, underscore only.

        Raises:
            ValueError: if a non-None ``v`` contains any other character.
        """
        if v is not None and not re.fullmatch(r'[a-zA-Z0-9_-]+', v):
            raise ValueError("Must contain only alphanumeric, hyphen, underscore")
        return v

    @field_validator("subpath")
    @classmethod
    def validate_subpath(cls, v: Optional[str]) -> Optional[str]:
        """Validate subpath - no path traversal.

        Rejects parent references (".."), absolute paths, NUL bytes and
        backslashes, then requires the whole value to consist of
        alphanumerics, '/', '_', '.' and '-'.

        Raises:
            ValueError: if the subpath is unsafe.
        """
        if v is None:
            return None
        # Reject dangerous patterns BEFORE any normalization
        if ".." in v or v.startswith("/") or "\x00" in v or "\\" in v:
            raise ValueError("Invalid subpath: contains forbidden characters")
        # Only allow safe characters (entire string, no trailing-newline loophole)
        if not re.fullmatch(r'[a-zA-Z0-9/_.-]+', v):
            raise ValueError("Invalid subpath: contains forbidden characters")
        return v
537  
538  
class DynamicMountInfo(BaseModel):
    """Information about a mounted dynamic path (response)."""

    # Where the mount appears inside the session workspace.
    alias: str = Field(description="Mount alias name")
    workspace_path: str = Field(description="Path in workspace (e.g., ./nginx-logs)")
    mode: str = Field(description="Access mode: ro or rw")
    # Where it came from; subpath and host path are optional.
    source_base: str = Field(description="Source base name")
    source_subpath: str | None = Field(None, description="Source subpath")
    host_path: str | None = Field(
        None, description="Original host path (e.g., /var/log)"
    )
548  
549  
class DynamicBaseInfo(BaseModel):
    """Information about an available dynamic base (for UI)."""

    name: str = Field(description="Base name")
    description: str = Field(description="Human-readable description")
    # Mount constraints advertised to the client; requires_subpath defaults to False.
    max_mode: str = Field(description="Maximum allowed mode (ro or rw)")
    requires_subpath: bool = Field(False, description="Whether subpath is required")
    host_path: str = Field(description="Original host path (e.g., /var/log)")
558  
559  
class AvailableDynamicMountsResponse(BaseModel):
    """Response for GET /sessions/dynamic-mounts/available."""

    enabled: bool = Field(description="Whether dynamic mounts feature is enabled")
    # Empty list by default.
    bases: list[DynamicBaseInfo] = Field(
        description="List of available dynamic bases for this user",
        default_factory=list,
    )
    # Defaults to 10.
    max_mounts_per_session: int = Field(
        10, description="Maximum mounts allowed per session"
    )
571  
572  
class SubmitAnswerRequest(BaseModel):
    """Request body for POST /sessions/{id}/answer."""

    # Both fields are required. The answer is not passed through a WAF
    # validator in this model.
    question_id: str = Field(description="ID of the question being answered")
    answer: str = Field(description="User's answer to the question")
577  
578  
class SubmitAnswerResponse(BaseModel):
    """Response from POST /sessions/{id}/answer."""

    success: bool = Field(description="Whether the answer was submitted successfully")
    message: str = Field(description="Status message")
    # Defaults to False.
    can_resume: bool = Field(
        False,
        description="Whether the session can be resumed now that the answer is submitted",
    )