# models.py
from pydantic import BaseModel, Field, validator
from typing import Any, Dict, List, Optional
from datetime import datetime


class InferenceRequest(BaseModel):
    """Model for inference requests."""
    text: str = Field(..., description="Text to analyze")
    use_segmentation: bool = Field(True, description="Use text segmentation")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    batch_parallel: bool = Field(True, description="Run tasks in parallel")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")

    @validator('text')
    def text_must_not_be_empty(cls, v):
        if not v.strip():
            raise ValueError("Text cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v


class InferenceResponse(BaseModel):
    """Model for inference request responses."""
    task_id: str = Field(..., description="Task identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    created_at: Optional[float] = Field(None, description="Creation timestamp")


class SessionRequest(BaseModel):
    """Model for session-specific requests."""
    system_prompt: str = Field(..., description="System prompt to use")
    user_input: Optional[str] = Field("", description="User input to provide")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")

    @validator('system_prompt')
    def prompt_not_empty(cls, v):
        if not v.strip():
            raise ValueError("System prompt cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v


class SessionResponse(BaseModel):
    """Model for session-specific request responses."""
    task_id: str = Field(..., description="Session task identifier")
    parent_task_id: str = Field(..., description="Parent task identifier")
    session_name: str = Field(..., description="Session name")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")


class InferenceStatus(BaseModel):
    """Model for inference task statuses."""
    task_id: str = Field(..., description="Task identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    progress: float = Field(0, description="Progress percentage (0-100)")
    created_at: float = Field(..., description="Creation timestamp")
    started_at: Optional[float] = Field(None, description="Execution start timestamp")
    completed_at: Optional[float] = Field(None, description="Execution end timestamp")
    results: Optional[Dict[str, Any]] = Field(None, description="Inference results")
    metrics: Optional[Dict[str, Any]] = Field(None, description="Inference metrics")
    result_file: Optional[str] = Field(None, description="Path to results file")
    error: Optional[str] = Field(None, description="Error message in case of failure")
    error_type: Optional[str] = Field(None, description="Error type in case of failure")

    def formatted_timestamps(self) -> Dict[str, str]:
        """Return the task timestamps as ISO 8601 strings."""
        timestamps = {}
        if self.created_at:
            timestamps["created_at"] = datetime.fromtimestamp(self.created_at).isoformat()
        if self.started_at:
            timestamps["started_at"] = datetime.fromtimestamp(self.started_at).isoformat()
        if self.completed_at:
            timestamps["completed_at"] = datetime.fromtimestamp(self.completed_at).isoformat()
        return timestamps


class BatchInferenceRequest(BaseModel):
    """Model for batch inference requests."""
    texts: List[str] = Field(..., description="List of texts to analyze")
    use_segmentation: bool = Field(True, description="Use text segmentation")
    max_new_tokens: int = Field(8000, description="Maximum number of generated tokens")
    batch_parallel: bool = Field(True, description="Run tasks in parallel")
    timeout_seconds: int = Field(300, description="Timeout in seconds")
    engine_id: Optional[str] = Field(None, description="ID of the inference engine to use")
    max_concurrent: int = Field(3, description="Maximum number of concurrent tasks")

    @validator('texts')
    def texts_not_empty(cls, v):
        if not v:
            raise ValueError("The list of texts cannot be empty")
        for i, text in enumerate(v):
            if not text.strip():
                raise ValueError(f"Text at index {i} cannot be empty")
        return v

    @validator('max_new_tokens')
    def max_tokens_in_range(cls, v):
        if v < 100 or v > 24000:
            raise ValueError("Number of tokens must be between 100 and 24000")
        return v

    @validator('timeout_seconds')
    def timeout_in_range(cls, v):
        if v < 10 or v > 3600:
            raise ValueError("Timeout must be between 10 and 3600 seconds")
        return v

    @validator('max_concurrent')
    def concurrent_in_range(cls, v):
        if v < 1 or v > 10:
            raise ValueError("Number of concurrent tasks must be between 1 and 10")
        return v


class BatchInferenceResponse(BaseModel):
    """Model for batch inference request responses."""
    batch_id: str = Field(..., description="Batch identifier")
    status: str = Field(..., description="Task status (pending, running, completed, failed)")
    message: str = Field(..., description="Descriptive message")
    batch_size: int = Field(..., description="Number of texts in the batch")
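

# A minimal usage sketch, not part of the module's public API: it shows that
# field validation runs at construction time (out-of-range values raise a
# pydantic ValidationError) and how formatted_timestamps() renders epoch
# floats as ISO 8601 strings. All sample values below are illustrative only.
if __name__ == "__main__":
    import time
    from pydantic import ValidationError

    # Valid request: defaults fill in use_segmentation, batch_parallel, etc.
    req = InferenceRequest(text="Analyze this passage.", max_new_tokens=4000)
    print(req.json())

    # Blank text fails the text_must_not_be_empty validator.
    try:
        InferenceRequest(text="   ")
    except ValidationError as exc:
        print(exc)

    # Hypothetical status object demonstrating timestamp formatting.
    status = InferenceStatus(
        task_id="task-123",
        status="completed",
        message="Done",
        created_at=time.time() - 5,
        started_at=time.time() - 4,
        completed_at=time.time(),
    )
    print(status.formatted_timestamps())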