inference_router.py
1 """ 2 Router for AI model inferences 3 ------------------------------------------ 4 This module implements routes for running inferences with different AI models. 5 """ 6 # For access to post-processor configurations 7 from config import postprocessing_config 8 9 # For accessing the request and using the post-processor 10 from fastapi import Request 11 import os 12 import time 13 import logging 14 import uuid 15 import json 16 from typing import Dict, List, Any, Optional, Union 17 from pathlib import Path 18 19 from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, File, UploadFile, Form, Body 20 from fastapi.responses import JSONResponse, FileResponse 21 from pydantic import BaseModel, Field, validator 22 23 # Import response models 24 from .response_models import ( 25 SuccessResponse, 26 ErrorResponse, 27 TaskResponse, 28 TaskStatusResponse, 29 TaskListResponse 30 ) 31 32 # Import authentication utilities 33 from auth import get_current_active_user, User 34 35 # Import model manager 36 from model_manager import ModelManager 37 38 # Import prompt manager 39 from utils.prompt_manager import get_prompt_manager 40 41 # Import inference engine 42 from inference_engine import ( 43 run_inference, 44 get_task_status, 45 list_tasks, 46 cancel_task, 47 get_available_models, 48 ModelNotFoundException 49 ) 50 51 # Logging configuration 52 logger = logging.getLogger("api.inference") 53 54 # Create directories for results 55 RESULTS_DIR = Path("inference_results") 56 RESULTS_DIR.mkdir(parents=True, exist_ok=True) 57 58 # Create router 59 inference_router = APIRouter( 60 prefix="/inference", 61 tags=["Inference"], 62 responses={ 63 400: {"model": ErrorResponse, "description": "Invalid request"}, 64 401: {"model": ErrorResponse, "description": "Unauthorized"}, 65 404: {"model": ErrorResponse, "description": "Resource not found"}, 66 500: {"model": ErrorResponse, "description": "Server error"} 67 } 68 ) 69 70 # Pydantic models for requests 71 class TextInferenceRequest(BaseModel): 72 """Model for text inference requests""" 73 model: str 74 text: Optional[str] = None 75 prompt: Optional[str] = None 76 prompt_name: Optional[str] = None 77 language: Optional[str] = "fr" 78 context: Optional[str] = None 79 max_tokens: Optional[int] = 1024 80 temperature: Optional[float] = 0.7 81 top_p: Optional[float] = 1.0 82 n: Optional[int] = 1 83 stop: Optional[Union[str, List[str]]] = None 84 presence_penalty: Optional[float] = 0.0 85 frequency_penalty: Optional[float] = 0.0 86 87 @validator('prompt_name') 88 def validate_prompt_name(cls, v, values): 89 """Verifies that either a custom prompt, a prompt name, or text is specified""" 90 if not v and not values.get('prompt') and not values.get('text'): 91 raise ValueError("You must specify either 'prompt_name', 'prompt', or 'text'") 92 return v 93 94 class TextCompletionResponse(BaseModel): 95 """Model for text completion responses""" 96 id: str = Field(..., description="Unique task identifier") 97 text: str = Field(..., description="Generated text") 98 model: str = Field(..., description="Model used") 99 usage: Dict[str, int] = Field(..., description="Usage statistics") 100 101 class ImageGenerationRequest(BaseModel): 102 """Model for image generation requests""" 103 model: str 104 prompt: str 105 prompt_name: Optional[str] = None 106 text: Optional[str] = None 107 n: Optional[int] = 1 108 size: Optional[str] = "1024x1024" 109 response_format: Optional[str] = "url" 110 111 @validator('prompt') 112 def validate_prompt(cls, v, values): 113 """Verifies that either a custom prompt, or a prompt name with text is specified""" 114 if not v and not values.get('prompt_name') and not values.get('text'): 115 raise ValueError("You must specify either 'prompt', or 'prompt_name' with 'text'") 116 return v 117 118 class ImageGenerationResponse(BaseModel): 119 """Model for image generation responses""" 120 id: str = Field(..., description="Unique task identifier") 121 images: List[str] = Field(..., description="URLs of generated images") 122 model: str = Field(..., description="Model used") 123 124 class ModelsResponse(BaseModel): 125 """Model for the list of available models""" 126 text_models: Dict[str, Any] = Field(..., description="Available text models") 127 image_models: Dict[str, Any] = Field(..., description="Available image models") 128 embedding_models: Dict[str, Any] = Field(..., description="Available embedding models") 129 130 class PostProcessingOptions(BaseModel): 131 """Post-processing options for inferences""" 132 json_simplify: Optional[bool] = False 133 134 # Inference routes 135 @inference_router.post("/text", response_model=TaskResponse) 136 async def create_text_inference( 137 request: TextInferenceRequest, 138 background_tasks: BackgroundTasks, 139 current_user: User = Depends(get_current_active_user) 140 ): 141 """Creates an inference task for text generation""" 142 try: 143 task_id = str(uuid.uuid4()) 144 prompt_manager = get_prompt_manager() 145 final_prompt = None 146 147 # Prompt management with the PromptManager 148 if request.prompt_name: 149 # Use a predefined prompt with the provided text 150 placeholder_values = {} 151 if request.text: 152 placeholder_values["text"] = request.text 153 if request.language: 154 placeholder_values["language"] = request.language 155 if request.context: 156 placeholder_values["context"] = request.context 157 158 final_prompt = prompt_manager.format_prompt( 159 request.prompt_name, 160 **placeholder_values 161 ) 162 163 if not final_prompt: 164 raise HTTPException( 165 status_code=status.HTTP_400_BAD_REQUEST, 166 detail=f"Unable to format prompt '{request.prompt_name}'. Check required placeholders." 167 ) 168 elif request.prompt: 169 # Use the provided prompt directly 170 final_prompt = request.prompt 171 elif request.text: 172 # Use the default prompt with the provided text 173 final_prompt = prompt_manager.format_prompt("default", text=request.text) 174 if not final_prompt: 175 # Fallback if default prompt doesn't exist 176 final_prompt = f"Analyze the following text:\n\n{request.text}" 177 else: 178 raise HTTPException( 179 status_code=status.HTTP_400_BAD_REQUEST, 180 detail="You must provide either 'prompt_name', 'prompt', or 'text'." 181 ) 182 183 # Inference parameters 184 params = { 185 "model": request.model, 186 "prompt": final_prompt, 187 "max_tokens": request.max_tokens, 188 "temperature": request.temperature, 189 "top_p": request.top_p, 190 "n": request.n, 191 "stop": request.stop, 192 "presence_penalty": request.presence_penalty, 193 "frequency_penalty": request.frequency_penalty, 194 "user_id": current_user.username 195 } 196 197 # Launch task in background 198 background_tasks.add_task( 199 run_inference, 200 task_id=task_id, 201 task_type="text", 202 params=params 203 ) 204 205 return TaskResponse( 206 task_id=task_id, 207 status="pending", 208 message="Inference task created" 209 ) 210 211 except ModelNotFoundException as e: 212 raise HTTPException( 213 status_code=status.HTTP_400_BAD_REQUEST, 214 detail=f"Model not found: {str(e)}" 215 ) 216 except HTTPException: 217 raise 218 except Exception as e: 219 logger.error(f"Error creating inference task: {str(e)}") 220 raise HTTPException( 221 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 222 detail=f"Error creating task: {str(e)}" 223 ) 224 225 @inference_router.post("/image", response_model=TaskResponse) 226 async def create_image_generation( 227 request: ImageGenerationRequest, 228 background_tasks: BackgroundTasks, 229 current_user: User = Depends(get_current_active_user) 230 ): 231 """Creates an inference task for image generation""" 232 try: 233 task_id = str(uuid.uuid4()) 234 prompt_manager = get_prompt_manager() 235 final_prompt = None 236 237 # Prompt management with the PromptManager 238 if request.prompt_name and request.text: 239 # Use a predefined prompt with the provided text 240 final_prompt = prompt_manager.format_prompt(request.prompt_name, text=request.text) 241 if not final_prompt: 242 raise HTTPException( 243 status_code=status.HTTP_400_BAD_REQUEST, 244 detail=f"Unable to format prompt '{request.prompt_name}'. Check required placeholders." 245 ) 246 elif request.prompt: 247 # Use the provided prompt directly 248 final_prompt = request.prompt 249 else: 250 raise HTTPException( 251 status_code=status.HTTP_400_BAD_REQUEST, 252 detail="You must provide either 'prompt', or 'prompt_name' with 'text'." 253 ) 254 255 # Inference parameters 256 params = { 257 "model": request.model, 258 "prompt": final_prompt, 259 "n": request.n, 260 "size": request.size, 261 "response_format": request.response_format, 262 "user_id": current_user.username 263 } 264 265 # Launch task in background 266 background_tasks.add_task( 267 run_inference, 268 task_id=task_id, 269 task_type="image", 270 params=params 271 ) 272 273 return TaskResponse( 274 task_id=task_id, 275 status="pending", 276 message="Image generation task created" 277 ) 278 279 except ModelNotFoundException as e: 280 raise HTTPException( 281 status_code=status.HTTP_400_BAD_REQUEST, 282 detail=f"Model not found: {str(e)}" 283 ) 284 except HTTPException: 285 raise 286 except Exception as e: 287 logger.error(f"Error creating image generation task: {str(e)}") 288 raise HTTPException( 289 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 290 detail=f"Error creating task: {str(e)}" 291 ) 292 293 @inference_router.post("/embedding", response_model=TaskResponse) 294 async def create_embedding( 295 text: str = Body(..., embed=True), 296 model: str = Body(..., embed=True), 297 background_tasks: BackgroundTasks = None, 298 current_user: User = Depends(get_current_active_user) 299 ): 300 """Creates an inference task for embedding generation""" 301 try: 302 task_id = str(uuid.uuid4()) 303 304 # Inference parameters 305 params = { 306 "model": model, 307 "text": text, 308 "user_id": current_user.username 309 } 310 311 # Launch task in background 312 background_tasks.add_task( 313 run_inference, 314 task_id=task_id, 315 task_type="embedding", 316 params=params 317 ) 318 319 return TaskResponse( 320 task_id=task_id, 321 status="pending", 322 message="Embedding generation task created" 323 ) 324 325 except ModelNotFoundException as e: 326 raise HTTPException( 327 status_code=status.HTTP_400_BAD_REQUEST, 328 detail=f"Model not found: {str(e)}" 329 ) 330 except Exception as e: 331 logger.error(f"Error creating embedding task: {str(e)}") 332 raise HTTPException( 333 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 334 detail=f"Error creating task: {str(e)}" 335 ) 336 337 @inference_router.post("/chain", response_model=TaskResponse) 338 async def create_inference_chain( 339 text: str = Body(...), 340 prompt_sequence: List[str] = Body(...), 341 model: str = Body(...), 342 max_tokens: Optional[int] = Body(1024), 343 temperature: Optional[float] = Body(0.7), 344 background_tasks: BackgroundTasks = None, 345 current_user: User = Depends(get_current_active_user) 346 ): 347 """Creates a chain inference task (sequence of prompts)""" 348 try: 349 task_id = str(uuid.uuid4()) 350 prompt_manager = get_prompt_manager() 351 352 # Verify that all prompts in the sequence exist 353 for prompt_name in prompt_sequence: 354 if not prompt_manager.get_prompt(prompt_name): 355 raise HTTPException( 356 status_code=status.HTTP_400_BAD_REQUEST, 357 detail=f"Prompt '{prompt_name}' does not exist in the sequence" 358 ) 359 360 # Inference parameters 361 params = { 362 "model": model, 363 "text": text, 364 "prompt_sequence": prompt_sequence, 365 "max_tokens": max_tokens, 366 "temperature": temperature, 367 "user_id": current_user.username 368 } 369 370 # Launch task in background 371 background_tasks.add_task( 372 run_inference, 373 task_id=task_id, 374 task_type="chain", 375 params=params 376 ) 377 378 return TaskResponse( 379 task_id=task_id, 380 status="pending", 381 message="Chain inference task created" 382 ) 383 384 except ModelNotFoundException as e: 385 raise HTTPException( 386 status_code=status.HTTP_400_BAD_REQUEST, 387 detail=f"Model not found: {str(e)}" 388 ) 389 except HTTPException: 390 raise 391 except Exception as e: 392 logger.error(f"Error creating chain inference task: {str(e)}") 393 raise HTTPException( 394 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 395 detail=f"Error creating task: {str(e)}" 396 ) 397 398 @inference_router.post("/system-final", response_model=TaskResponse) 399 async def create_system_final_inference( 400 task_id_1: str = Body(...), 401 task_id_2: str = Body(...), 402 task_id_1_2: str = Body(...), 403 task_id_1_2_1: str = Body(...), 404 model: str = Body(...), 405 max_tokens: Optional[int] = Body(1024), 406 temperature: Optional[float] = Body(0.7), 407 background_tasks: BackgroundTasks = None, 408 current_user: User = Depends(get_current_active_user) 409 ): 410 """Crée une tâche d'inférence finale utilisant les résultats de tâches précédentes""" 411 try: 412 task_id = str(uuid.uuid4()) 413 414 # Vérifier l'existence des tâches précédentes 415 for task_id_check in [task_id_1, task_id_2, task_id_1_2, task_id_1_2_1]: 416 task = get_task_status(task_id_check) 417 if task is None: 418 raise HTTPException( 419 status_code=status.HTTP_404_NOT_FOUND, 420 detail=f"Task {task_id_check} not found" 421 ) 422 if task.get("status") != "completed": 423 raise HTTPException( 424 status_code=status.HTTP_400_BAD_REQUEST, 425 detail=f"Task {task_id_check} is not yet completed" 426 ) 427 428 # Paramètres d'inférence avec les IDs des tâches précédentes 429 params = { 430 "model": model, 431 "max_tokens": max_tokens, 432 "temperature": temperature, 433 "user_id": current_user.username, 434 "task_dependencies": { 435 "task_id_1": task_id_1, 436 "task_id_2": task_id_2, 437 "task_id_1_2": task_id_1_2, 438 "task_id_1_2_1": task_id_1_2_1 439 }, 440 "prompt_name": "system_final" 441 } 442 443 # Lancer la tâche en arrière-plan 444 background_tasks.add_task( 445 run_inference, 446 task_id=task_id, 447 task_type="system_final", 448 params=params 449 ) 450 451 return TaskResponse( 452 task_id=task_id, 453 status="pending", 454 message="System final inference task created" 455 ) 456 457 except ModelNotFoundException as e: 458 raise HTTPException( 459 status_code=status.HTTP_400_BAD_REQUEST, 460 detail=f"Model not found: {str(e)}" 461 ) 462 except HTTPException: 463 raise 464 except Exception as e: 465 logger.error(f"Error creating system final inference task: {str(e)}") 466 raise HTTPException( 467 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 468 detail=f"Error creating task: {str(e)}" 469 ) 470 471 @inference_router.post("/custom", response_model=TaskResponse) 472 async def create_custom_inference( 473 text: str = Body(...), 474 prompt_name: str = Body(...), 475 model: str = Body(...), 476 max_tokens: Optional[int] = Body(1024), 477 temperature: Optional[float] = Body(0.7), 478 language: Optional[str] = Body("fr"), 479 context: Optional[str] = Body(None), 480 content: Optional[str] = Body(None), 481 additional_context: Optional[Dict[str, Any]] = Body(None), 482 background_tasks: BackgroundTasks = None, 483 current_user: User = Depends(get_current_active_user) 484 ): 485 """Creates an inference task with custom placeholders""" 486 try: 487 task_id = str(uuid.uuid4()) 488 prompt_manager = get_prompt_manager() 489 490 # Verify that the prompt exists 491 if not prompt_manager.get_prompt(prompt_name): 492 raise HTTPException( 493 status_code=status.HTTP_400_BAD_REQUEST, 494 detail=f"Prompt '{prompt_name}' does not exist" 495 ) 496 497 # Prepare placeholders 498 placeholder_values = { 499 "text": text, 500 "language": language 501 } 502 503 # Add optional placeholders if present 504 if context: 505 placeholder_values["context"] = context 506 if content: 507 placeholder_values["content"] = content 508 509 # Add any additional context as placeholders 510 if additional_context: 511 placeholder_values.update(additional_context) 512 513 # Format the prompt 514 final_prompt = prompt_manager.format_prompt(prompt_name, **placeholder_values) 515 516 if not final_prompt: 517 raise HTTPException( 518 status_code=status.HTTP_400_BAD_REQUEST, 519 detail=f"Unable to format prompt '{prompt_name}'. Check required placeholders." 520 ) 521 522 # Inference parameters 523 params = { 524 "model": model, 525 "prompt": final_prompt, 526 "max_tokens": max_tokens, 527 "temperature": temperature, 528 "user_id": current_user.username 529 } 530 531 # Launch task in background 532 background_tasks.add_task( 533 run_inference, 534 task_id=task_id, 535 task_type="text", 536 params=params 537 ) 538 539 return TaskResponse( 540 task_id=task_id, 541 status="pending", 542 message="Custom inference task created" 543 ) 544 545 except ModelNotFoundException as e: 546 raise HTTPException( 547 status_code=status.HTTP_400_BAD_REQUEST, 548 detail=f"Model not found: {str(e)}" 549 ) 550 except HTTPException: 551 raise 552 except Exception as e: 553 logger.error(f"Error creating custom inference task: {str(e)}") 554 raise HTTPException( 555 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 556 detail=f"Error creating task: {str(e)}" 557 ) 558 559 @inference_router.get("/task/{task_id}", response_model=TaskStatusResponse) 560 async def get_task( 561 task_id: str, 562 current_user: User = Depends(get_current_active_user) 563 ): 564 """Retrieves the status of an inference task""" 565 try: 566 task = get_task_status(task_id) 567 568 if task is None: 569 raise HTTPException( 570 status_code=status.HTTP_404_NOT_FOUND, 571 detail=f"Task {task_id} not found" 572 ) 573 574 # Check if the user is authorized to access this task 575 if not current_user.is_admin and task.get("user_id") != current_user.username: 576 raise HTTPException( 577 status_code=status.HTTP_403_FORBIDDEN, 578 detail="You are not authorized to access this task" 579 ) 580 581 return TaskStatusResponse(**task) 582 583 except HTTPException: 584 raise 585 except Exception as e: 586 logger.error(f"Error retrieving task: {str(e)}") 587 raise HTTPException( 588 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 589 detail=f"Error retrieving task: {str(e)}" 590 ) 591 592 @inference_router.get("/tasks", response_model=TaskListResponse) 593 async def get_tasks( 594 current_user: User = Depends(get_current_active_user), 595 limit: int = 10, 596 offset: int = 0 597 ): 598 """Retrieves the list of inference tasks""" 599 try: 600 # Filter tasks by user (except for admins) 601 user_filter = None if current_user.is_admin else current_user.username 602 603 tasks = list_tasks(limit=limit, offset=offset, user_id=user_filter) 604 605 return TaskListResponse( 606 total=tasks.get("total", 0), 607 tasks=tasks.get("tasks", {}) 608 ) 609 610 except Exception as e: 611 logger.error(f"Error retrieving tasks: {str(e)}") 612 raise HTTPException( 613 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 614 detail=f"Error retrieving tasks: {str(e)}" 615 ) 616 617 618 @inference_router.delete("/task/{task_id}", response_model=SuccessResponse) 619 async def cancel_task_endpoint( 620 task_id: str, 621 current_user: User = Depends(get_current_active_user) 622 ): 623 """Cancels an inference task""" 624 try: 625 # Check if the task exists and belongs to the user 626 task = get_task_status(task_id) 627 628 if task is None: 629 raise HTTPException( 630 status_code=status.HTTP_404_NOT_FOUND, 631 detail=f"Task {task_id} not found" 632 ) 633 634 # Check if the user is authorized to cancel this task 635 if not current_user.is_admin and task.get("user_id") != current_user.username: 636 raise HTTPException( 637 status_code=status.HTTP_403_FORBIDDEN, 638 detail="You are not authorized to cancel this task" 639 ) 640 641 # Cancel the task 642 result = cancel_task(task_id) 643 644 if not result: 645 return SuccessResponse( 646 success=False, 647 message="The task cannot be canceled because it has already completed" 648 ) 649 650 return SuccessResponse( 651 success=True, 652 message=f"Task {task_id} successfully canceled" 653 ) 654 655 except HTTPException: 656 raise 657 except Exception as e: 658 logger.error(f"Error canceling task: {str(e)}") 659 raise HTTPException( 660 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 661 detail=f"Error canceling task: {str(e)}" 662 ) 663 664 @inference_router.get("/models", response_model=ModelsResponse) 665 async def get_models(): 666 """Retrieves the list of available inference models""" 667 try: 668 models = get_available_models() 669 670 return ModelsResponse( 671 text_models=models.get("text", {}), 672 image_models=models.get("image", {}), 673 embedding_models=models.get("embedding", {}) 674 ) 675 676 except Exception as e: 677 logger.error(f"Error retrieving models: {str(e)}") 678 raise HTTPException( 679 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 680 detail=f"Error retrieving models: {str(e)}" 681 ) 682 683 @inference_router.get("/prompts") 684 async def get_available_prompts( 685 current_user: User = Depends(get_current_active_user) 686 ): 687 """Retrieves the list of available prompts""" 688 try: 689 prompt_manager = get_prompt_manager() 690 available_prompts = prompt_manager.list_prompts() 691 692 prompts_with_placeholders = {} 693 for prompt_name in available_prompts: 694 prompts_with_placeholders[prompt_name] = { 695 "placeholders": prompt_manager.get_placeholder_names(prompt_name) 696 } 697 698 return { 699 "prompts": prompts_with_placeholders 700 } 701 702 except Exception as e: 703 logger.error(f"Error retrieving prompts: {str(e)}") 704 raise HTTPException( 705 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 706 detail=f"Error retrieving prompts: {str(e)}" 707 ) 708 709 @inference_router.get("/results/{task_id}") 710 async def get_results( 711 task_id: str, 712 request: Request, 713 current_user: User = Depends(get_current_active_user) 714 ): 715 """Retrieves the results of an inference task""" 716 try: 717 # Check if the task exists and belongs to the user 718 task = get_task_status(task_id) 719 720 if task is None: 721 raise HTTPException( 722 status_code=status.HTTP_404_NOT_FOUND, 723 detail=f"Task {task_id} not found" 724 ) 725 726 # Check if the user is authorized to access this task 727 if not current_user.is_admin and task.get("user_id") != current_user.username: 728 raise HTTPException( 729 status_code=status.HTTP_403_FORBIDDEN, 730 detail="You are not authorized to access these results" 731 ) 732 733 # Check if the task is completed 734 if task.get("status") != "completed": 735 raise HTTPException( 736 status_code=status.HTTP_400_BAD_REQUEST, 737 detail=f"Task {task_id} is not yet completed" 738 ) 739 740 # Check the result type 741 task_type = task.get("task_type") 742 results = task.get("results") 743 744 if not results: 745 raise HTTPException( 746 status_code=status.HTTP_404_NOT_FOUND, 747 detail=f"No results available for task {task_id}" 748 ) 749 750 # Prepare the response based on the task type 751 if task_type == "image": 752 # For images, return URLs or files 753 if results.get("format") == "url": 754 response_data = {"images": results.get("images", [])} 755 else: 756 # TODO: Handle image files 757 response_data = results 758 759 elif task_type == "text": 760 # For text, simply return the result 761 response_data = {"text": results.get("text", ""), "usage": results.get("usage", {})} 762 763 elif task_type == "embedding": 764 # For embeddings, return the vectors 765 response_data = {"embedding": results.get("embedding", [])} 766 767 elif task_type == "chain": 768 # For inference chains, return intermediate results 769 response_data = { 770 "final_result": results.get("final_result", ""), 771 "intermediate_results": results.get("intermediate_results", {}) 772 } 773 774 else: 775 # By default, return all results 776 response_data = results 777 778 # Apply JSONSimplifier post-processor if available and enabled 779 json_simplifier = getattr(request.app.state, "json_simplifier", None) 780 if json_simplifier and json_simplifier.should_process("inference"): 781 logger.debug(f"Applying JSONSimplifier post-processor to results of task {task_id}") 782 response_data = json_simplifier.process(response_data, "inference") 783 784 return response_data 785 786 except HTTPException: 787 raise 788 except Exception as e: 789 logger.error(f"Error retrieving results: {str(e)}") 790 raise HTTPException( 791 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 792 detail=f"Error retrieving results: {str(e)}" 793 ) 794 795 @inference_router.post("/text/with-options", response_model=TaskResponse) 796 async def create_text_inference_with_options( 797 request: TextInferenceRequest, 798 options: PostProcessingOptions = Body(...), 799 background_tasks: BackgroundTasks = None, 800 current_user: User = Depends(get_current_active_user) 801 ): 802 """Creates a text generation inference task with post-processing options""" 803 try: 804 task_id = str(uuid.uuid4()) 805 prompt_manager = get_prompt_manager() 806 final_prompt = None 807 808 # Prompt management with the PromptManager 809 if request.prompt_name: 810 # Use a predefined prompt with the provided text 811 placeholder_values = {} 812 if request.text: 813 placeholder_values["text"] = request.text 814 if request.language: 815 placeholder_values["language"] = request.language 816 if request.context: 817 placeholder_values["context"] = request.context 818 819 final_prompt = prompt_manager.format_prompt( 820 request.prompt_name, 821 **placeholder_values 822 ) 823 824 if not final_prompt: 825 raise HTTPException( 826 status_code=status.HTTP_400_BAD_REQUEST, 827 detail=f"Unable to format prompt '{request.prompt_name}'. Check required placeholders." 828 ) 829 elif request.prompt: 830 # Use the provided prompt directly 831 final_prompt = request.prompt 832 elif request.text: 833 # Use the default prompt with the provided text 834 final_prompt = prompt_manager.format_prompt("default", text=request.text) 835 if not final_prompt: 836 # Fallback if default prompt doesn't exist 837 final_prompt = f"Analyze the following text:\n\n{request.text}" 838 else: 839 raise HTTPException( 840 status_code=status.HTTP_400_BAD_REQUEST, 841 detail="You must provide either 'prompt_name', 'prompt', or 'text'." 842 ) 843 844 # Inference parameters 845 params = { 846 "model": request.model, 847 "prompt": final_prompt, 848 "max_tokens": request.max_tokens, 849 "temperature": request.temperature, 850 "top_p": request.top_p, 851 "n": request.n, 852 "stop": request.stop, 853 "presence_penalty": request.presence_penalty, 854 "frequency_penalty": request.frequency_penalty, 855 "user_id": current_user.username, 856 # Add post-processing options 857 "post_processing": { 858 "json_simplify": options.json_simplify 859 } 860 } 861 862 # Launch task in background 863 background_tasks.add_task( 864 run_inference, 865 task_id=task_id, 866 task_type="text", 867 params=params 868 ) 869 870 return TaskResponse( 871 task_id=task_id, 872 status="pending", 873 message="Inference task created with post-processing options" 874 ) 875 876 except ModelNotFoundException as e: 877 raise HTTPException( 878 status_code=status.HTTP_400_BAD_REQUEST, 879 detail=f"Model not found: {str(e)}" 880 ) 881 except HTTPException: 882 raise 883 except Exception as e: 884 logger.error(f"Error creating inference task: {str(e)}") 885 raise HTTPException( 886 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 887 detail=f"Error creating task: {str(e)}" 888 ) 889 890 @inference_router.get("/postprocessors") 891 async def get_postprocessors_config( 892 current_user: User = Depends(get_current_active_user) 893 ): 894 """Retrieves the configuration of active post-processors""" 895 try: 896 # Get active post-processor configuration 897 json_simplifier_config = postprocessing_config.get("json_simplifier", {}) 898 899 return { 900 "json_simplifier": { 901 "enabled": json_simplifier_config.get("enabled", False), 902 "model": json_simplifier_config.get("model"), 903 "apply_to": json_simplifier_config.get("apply_to", []) 904 } 905 } 906 907 except Exception as e: 908 logger.error(f"Error retrieving post-processors configuration: {str(e)}") 909 raise HTTPException( 910 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 911 detail=f"Error retrieving configuration: {str(e)}" 912 )