# models.py
from pydantic import BaseModel, Field, validator
from typing import Dict, List, Optional, Any
from datetime import datetime

class InferenceRequest(BaseModel):
    """Model for inference requests."""
    text: str = Field(..., description="Text to analyze")
    use_segmentation: bool = Field(True, description="Use text segmentation")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    batch_parallel: bool = Field(True, description="Run tasks in parallel")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")

    @validator('text')
    def text_must_not_be_empty(cls, v):
        if not v.strip():
            raise ValueError("Text cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v

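# A quick usage sketch (illustrative values only): everything except `text` has
# a default, and the validators above reject empty text, out-of-range token
# counts, and out-of-range timeouts at construction time.
#
#   req = InferenceRequest(text="Analyze this contract.")  # ok, all defaults
#   InferenceRequest(text="   ")                           # ValidationError: empty text
#   InferenceRequest(text="ok", max_new_tokens=50)         # ValidationError: 100-24000
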
class InferenceResponse(BaseModel):
    """Response model for inference requests."""
    task_id: str = Field(..., description="Task identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    created_at: Optional[float] = Field(None, description="Creation timestamp")

class SessionRequest(BaseModel):
    """Model for session-specific requests."""
    system_prompt: str = Field(..., description="System prompt to use")
    user_input: Optional[str] = Field("", description="User input to provide")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")

    @validator('system_prompt')
    def prompt_not_empty(cls, v):
        if not v.strip():
            raise ValueError("System prompt cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v

class SessionResponse(BaseModel):
    """Response model for session-specific requests."""
    task_id: str = Field(..., description="Session task identifier")
    parent_task_id: str = Field(..., description="Parent task identifier")
    session_name: str = Field(..., description="Session name")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")

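# Sketch of the parent/child linkage (IDs and session name are hypothetical):
# a session task always carries the identifier of the inference task that
# spawned it.
#
#   SessionResponse(task_id="sess-01", parent_task_id="task-42",
#                   session_name="summary", status="pending",
#                   message="Session queued")
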
class InferenceStatus(BaseModel):
    """Status model for an inference task."""
    task_id: str = Field(..., description="Task identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    progress: float = Field(0.0, description="Progress percentage (0-100)")
    created_at: float = Field(..., description="Creation timestamp")
    started_at: Optional[float] = Field(None, description="Execution start timestamp")
    completed_at: Optional[float] = Field(None, description="Execution end timestamp")
    results: Optional[Dict[str, Any]] = Field(None, description="Inference results")
    metrics: Optional[Dict[str, Any]] = Field(None, description="Inference metrics")
    result_file: Optional[str] = Field(None, description="Path to results file")
    error: Optional[str] = Field(None, description="Error message in case of failure")
    error_type: Optional[str] = Field(None, description="Error type in case of failure")

    def formatted_timestamps(self) -> Dict[str, str]:
        """Return the available timestamps in ISO 8601 format."""
        # created_at is required, so it is always present; the optional fields
        # are included only when set (`is not None` also keeps a 0.0 epoch value).
        timestamps = {"created_at": datetime.fromtimestamp(self.created_at).isoformat()}
        if self.started_at is not None:
            timestamps["started_at"] = datetime.fromtimestamp(self.started_at).isoformat()
        if self.completed_at is not None:
            timestamps["completed_at"] = datetime.fromtimestamp(self.completed_at).isoformat()
        return timestamps

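# Sketch of `formatted_timestamps` output (timestamp is illustrative): Unix
# epoch floats become ISO 8601 strings, and unset optional fields are omitted.
#
#   InferenceStatus(task_id="t1", status="running", message="...",
#                   created_at=1700000000.0).formatted_timestamps()
#   # -> {"created_at": "2023-11-14T22:13:20"}  (local-timezone dependent)
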
class BatchInferenceRequest(BaseModel):
    """Model for batch inference requests."""
    texts: List[str] = Field(..., description="List of texts to analyze")
    use_segmentation: bool = Field(True, description="Use text segmentation")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    batch_parallel: bool = Field(True, description="Run tasks in parallel")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")
    max_concurrent: int = Field(3, description="Maximum number of concurrent tasks")

    @validator('texts')
    def texts_not_empty(cls, v):
        if not v:
            raise ValueError("The list of texts cannot be empty")
        for i, text in enumerate(v):
            if not text.strip():
                raise ValueError(f"Text at index {i} cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v

    @validator('max_concurrent')
    def concurrent_in_range(cls, v):
        if v < 1 or v > 10:
            raise ValueError("Number of concurrent tasks must be between 1 and 10")
        return v

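# Batch usage sketch (illustrative values): each text is validated
# individually, so the error message pinpoints the offending index.
#
#   BatchInferenceRequest(texts=["first text", "second text"], max_concurrent=2)  # ok
#   BatchInferenceRequest(texts=["ok", "  "])  # ValidationError: "Text at index 1 cannot be empty"
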
class BatchInferenceResponse(BaseModel):
    """Response model for batch inference requests."""
    batch_id: str = Field(..., description="Batch identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    batch_size: int = Field(..., description="Number of texts in the batch")
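
# Minimal self-check sketch (illustrative, not part of the service API): builds
# a status object and exercises a validator. Note that pydantic v1's
# ValidationError subclasses ValueError, so the `except` below catches it.
if __name__ == "__main__":
    now = datetime.now().timestamp()
    status = InferenceStatus(
        task_id="demo-task",
        status="completed",
        message="Demo run",
        progress=100.0,
        created_at=now,
        completed_at=now + 2.0,
    )
    print(status.formatted_timestamps())

    try:
        InferenceRequest(text="   ")
    except ValueError as exc:
        print(f"Rejected as expected: {exc}")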