/ tests / test_inference.py
test_inference.py
  1  import pytest
  2  import requests
  3  import json
  4  import os
  5  import time
  6  from typing import Dict, Any, Optional
  7  
  8  # URL de base de l'API (ajuster selon l'environnement)
  9  BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
 10  
 11  # Texte d'exemple pour les tests
 12  SAMPLE_TEXT = """
 13  The debate on artificial intelligence poses significant ethical challenges. On one hand, AI offers unprecedented opportunities for progress in healthcare, education, and environmental protection. On the other hand, it raises concerns about privacy, job displacement, and potential risks from autonomous systems. While some argue that AI development should proceed with minimal restrictions to maximize innovation, others advocate for a cautious approach with robust regulatory frameworks. The question isn't whether we should develop AI, but how we can do so responsibly.
 14  """
 15  
 16  # Variable pour stocker les informations de test
 17  test_data = {
 18      "api_key": None,
 19      "task_id": None,
 20      "batch_id": None
 21  }
 22  
 23  def setup_module():
 24      """Configuration initiale pour les tests d'inférence."""
 25      # Récupérer une clé API valide depuis les tests d'authentification ou utiliser une variable d'environnement
 26      try:
 27          # Essayer d'utiliser une variable d'environnement pour la clé API
 28          api_key = os.environ.get("TEST_API_KEY")
 29          
 30          if not api_key:
 31              # Importer et exécuter les tests d'authentification si nécessaire
 32              from test_auth import test_register_user, test_login, test_create_api_key
 33              
 34              # Créer un utilisateur et une clé API si nécessaire
 35              test_register_user()
 36              test_login()
 37              test_create_api_key()
 38              
 39              from test_auth import test_data as auth_test_data
 40              api_key = auth_test_data["api_key"]
 41      except Exception as e:
 42          # En cas d'erreur, une clé de test doit être fournie en variable d'environnement
 43          api_key = os.environ.get("TEST_API_KEY")
 44          if not api_key:
 45              raise Exception("Aucune clé API disponible pour les tests. Définissez TEST_API_KEY ou exécutez test_auth.py")
 46      
 47      test_data["api_key"] = api_key
 48  
 49  def get_auth_headers():
 50      """Retourne les en-têtes d'authentification avec la clé API."""
 51      return {"X-API-Key": test_data["api_key"]}
 52  
 53  def wait_for_task_completion(task_id: str, max_retries: int = 30, delay: int = 10) -> Optional[Dict[str, Any]]:
 54      """
 55      Attend qu'une tâche soit terminée et retourne son résultat.
 56      
 57      Args:
 58          task_id: ID de la tâche à attendre
 59          max_retries: Nombre maximal de tentatives
 60          delay: Délai entre les tentatives en secondes
 61          
 62      Returns:
 63          Dict contenant les données de la tâche ou None en cas d'échec
 64      """
 65      headers = get_auth_headers()
 66      
 67      for i in range(max_retries):
 68          response = requests.get(f"{BASE_URL}/api/tasks/{task_id}", headers=headers)
 69          
 70          if response.status_code != 200:
 71              print(f"Erreur lors de la récupération de la tâche: {response.status_code}")
 72              continue
 73              
 74          data = response.json()
 75          
 76          # Si la tâche a échoué, retourner None
 77          if data["status"] == "failed":
 78              print(f"La tâche a échoué: {data.get('error', 'Erreur inconnue')}")
 79              return None
 80          
 81          # Si la tâche est terminée, retourner les résultats
 82          if data["status"] == "completed":
 83              return data
 84              
 85          # Afficher la progression
 86          print(f"Attente de la fin de la tâche... Progression: {data.get('progress', 0):.0f}% (tentative {i+1}/{max_retries})")
 87              
 88          # Attendre avant de réessayer
 89          time.sleep(delay)
 90      
 91      return None
 92  
 93  def test_inference_without_api_key():
 94      """Teste une requête d'inférence sans clé API."""
 95      # Données pour l'inférence
 96      inference_data = {
 97          "text": SAMPLE_TEXT,
 98          "use_segmentation": True,
 99          "max_new_tokens": 500
100      }
101      
102      # Envoyer la requête sans clé API
103      response = requests.post(f"{BASE_URL}/api/inference/start", json=inference_data)
104      assert response.status_code in [401, 403], f"Code de statut inattendu: {response.status_code}"
105  
106  def test_start_inference():
107      """Teste le démarrage d'une tâche d'inférence."""
108      # S'assurer qu'une clé API est disponible
109      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
110      
111      # Configurer les headers avec la clé API
112      headers = get_auth_headers()
113      
114      # Données pour l'inférence
115      inference_data = {
116          "text": SAMPLE_TEXT,
117          "use_segmentation": True,
118          "max_new_tokens": 500,
119          "timeout_seconds": 120
120      }
121      
122      # Envoyer la requête d'inférence
123      response = requests.post(f"{BASE_URL}/api/inference/start", json=inference_data, headers=headers)
124      
125      # Vérifier si l'API n'est pas disponible ou en mode de compatibilité
126      if response.status_code == 400 and "mode de compatibilité" in response.text:
127          pytest.skip("API en mode de compatibilité")
128      
129      assert response.status_code == 202, f"Code de statut inattendu: {response.status_code}, {response.text}"
130      
131      # Vérifier la réponse
132      data = response.json()
133      assert "task_id" in data
134      assert data["status"] == "pending"
135      
136      # Sauvegarder l'ID de tâche pour les tests suivants
137      test_data["task_id"] = data["task_id"]
138  
139  def test_get_inference_status():
140      """Teste la récupération de l'état d'une tâche d'inférence."""
141      # S'assurer qu'une clé API et un ID de tâche sont disponibles
142      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
143      assert test_data["task_id"] is not None, "Aucun ID de tâche disponible pour le test"
144      
145      # Configurer les headers avec la clé API
146      headers = get_auth_headers()
147      
148      # Attendre que la tâche soit au moins en cours d'exécution
149      max_retries = 10
150      for _ in range(max_retries):
151          # Envoyer la requête pour récupérer l'état de la tâche
152          response = requests.get(f"{BASE_URL}/api/tasks/{test_data['task_id']}", headers=headers)
153          assert response.status_code == 200, f"Code de statut inattendu: {response.status_code}, {response.text}"
154          
155          # Vérifier la réponse
156          data = response.json()
157          assert "status" in data
158          assert "progress" in data
159          assert "type" in data
160          assert data["type"] == "text_inference"
161          
162          # Si la tâche est terminée ou en cours, le test est réussi
163          if data["status"] in ["running", "completed"]:
164              break
165              
166          # Attendre avant de réessayer
167          time.sleep(2)
168      
169      # Vérifier que la tâche a progressé
170      assert data["status"] in ["running", "completed"], f"La tâche est toujours en attente après {max_retries} tentatives"
171  
172  def test_wait_for_inference_completion():
173      """Teste l'attente de la fin d'une tâche d'inférence."""
174      # S'assurer qu'une clé API et un ID de tâche sont disponibles
175      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
176      assert test_data["task_id"] is not None, "Aucun ID de tâche disponible pour le test"
177      
178      # Attendre que la tâche soit terminée
179      result = wait_for_task_completion(test_data["task_id"])
180      
181      # Vérifier que la tâche est terminée avec succès
182      assert result is not None, "La tâche n'a pas été terminée avec succès"
183      assert result["status"] == "completed"
184      assert "results" in result
185      assert isinstance(result["results"], dict)
186      # Vérifier que les résultats contiennent au moins une clé
187      assert len(result["results"]) > 0
188  
189  def test_custom_session():
190      """Teste l'exécution d'une session personnalisée."""
191      # S'assurer qu'une clé API et un ID de tâche sont disponibles
192      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
193      assert test_data["task_id"] is not None, "Aucun ID de tâche disponible pour le test"
194      
195      # Configurer les headers avec la clé API
196      headers = get_auth_headers()
197      
198      # Données pour la session personnalisée
199      session_data = {
200          "system_prompt": "Summarize the following text in a few sentences:\n\n{text}",
201          "user_input": "",
202          "max_new_tokens": 200
203      }
204      
205      # Envoyer la requête pour exécuter une session personnalisée
206      response = requests.post(
207          f"{BASE_URL}/api/inference/session/{test_data['task_id']}/custom_summary",
208          json=session_data,
209          headers=headers
210      )
211      
212      # Vérifier si l'API n'est pas disponible ou en mode de compatibilité
213      if response.status_code == 400 and "mode de compatibilité" in response.text:
214          pytest.skip("API en mode de compatibilité")
215      
216      assert response.status_code == 202, f"Code de statut inattendu: {response.status_code}, {response.text}"
217      
218      # Vérifier la réponse
219      data = response.json()
220      assert "task_id" in data
221      assert "parent_task_id" in data
222      assert data["parent_task_id"] == test_data["task_id"]
223      assert data["session_name"] == "custom_summary"
224  
225  def test_start_batch_inference():
226      """Teste le démarrage d'une tâche d'inférence par lots."""
227      # S'assurer qu'une clé API est disponible
228      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
229      
230      # Configurer les headers avec la clé API
231      headers = get_auth_headers()
232      
233      # Données pour l'inférence par lots
234      batch_data = {
235          "texts": [
236              SAMPLE_TEXT,
237              "AI safety is a critical concern for researchers and policymakers alike.",
238              "The adoption of renewable energy is accelerating globally."
239          ],
240          "use_segmentation": True,
241          "max_new_tokens": 500,
242          "max_concurrent": 2
243      }
244      
245      # Envoyer la requête d'inférence par lots
246      response = requests.post(f"{BASE_URL}/api/inference/batch", json=batch_data, headers=headers)
247      
248      # Si l'utilisateur de test n'a pas accès au traitement par lots, ignorer ce test
249      if response.status_code == 403 and "plan actuel" in response.text:
250          pytest.skip("L'utilisateur de test n'a pas accès au traitement par lots")
251      
252      # Vérifier si l'API n'est pas disponible ou en mode de compatibilité
253      if response.status_code == 400 and "mode de compatibilité" in response.text:
254          pytest.skip("API en mode de compatibilité")
255      
256      assert response.status_code == 202, f"Code de statut inattendu: {response.status_code}, {response.text}"
257      
258      # Vérifier la réponse
259      data = response.json()
260      assert "task_id" in data  # Maintenant on utilise task_id au lieu de batch_id
261      assert "batch_size" in data
262      assert data["batch_size"] == len(batch_data["texts"])
263      
264      # Sauvegarder l'ID de lot pour les tests suivants
265      test_data["batch_id"] = data["task_id"]
266  
267  def test_get_batch_status():
268      """Teste la récupération de l'état d'une tâche d'inférence par lots."""
269      # S'assurer qu'une clé API et un ID de lot sont disponibles
270      if test_data.get("batch_id") is None:
271          pytest.skip("Aucun ID de lot disponible pour le test")
272      
273      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
274      
275      # Configurer les headers avec la clé API
276      headers = get_auth_headers()
277      
278      # Envoyer la requête pour récupérer l'état du lot
279      response = requests.get(f"{BASE_URL}/api/tasks/{test_data['batch_id']}", headers=headers)
280      assert response.status_code == 200, f"Code de statut inattendu: {response.status_code}, {response.text}"
281      
282      # Vérifier la réponse
283      data = response.json()
284      assert "status" in data
285      assert "progress" in data
286      assert "type" in data
287      assert data["type"] == "batch"
288      assert "params" in data
289      assert "batch_size" in data["params"]
290      
291      # Attendre que le lot soit au moins en cours d'exécution
292      assert data["status"] in ["pending", "running", "completed"], f"Statut du lot inattendu: {data['status']}"
293  
294  def test_list_tasks():
295      """Teste la récupération de la liste des tâches."""
296      # S'assurer qu'une clé API est disponible
297      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
298      
299      # Configurer les headers avec la clé API
300      headers = get_auth_headers()
301      
302      # Envoyer la requête pour récupérer la liste des tâches
303      response = requests.get(f"{BASE_URL}/api/tasks", headers=headers)
304      assert response.status_code == 200, f"Code de statut inattendu: {response.status_code}, {response.text}"
305      
306      # Vérifier la réponse
307      data = response.json()
308      assert "total" in data
309      assert "tasks" in data
310      
311      # La structure a changé, tasks est maintenant une liste d'objets
312      assert isinstance(data["tasks"], list)
313      
314      # Vérifier que la tâche créée précédemment est présente
315      task_ids = [task["task_id"] for task in data["tasks"]]
316      assert test_data["task_id"] in task_ids, "La tâche créée n'a pas été trouvée dans la liste"
317  
318  def test_cancel_task():
319      """Teste l'annulation d'une tâche en cours."""
320      # S'assurer qu'une clé API et un ID de tâche sont disponibles
321      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
322      assert test_data["batch_id"] is not None, "Aucun ID de tâche par lots disponible pour le test"
323      
324      # Configurer les headers avec la clé API
325      headers = get_auth_headers()
326      
327      # Envoyer la requête pour annuler la tâche
328      response = requests.post(f"{BASE_URL}/api/tasks/{test_data['batch_id']}/cancel", headers=headers)
329      
330      # Si la tâche est déjà terminée, ce test peut échouer normalement
331      if response.status_code == 400 and "not running" in response.text.lower():
332          pytest.skip("La tâche est déjà terminée et ne peut pas être annulée")
333      
334      assert response.status_code == 200, f"Code de statut inattendu: {response.status_code}, {response.text}"
335      
336      # Vérifier la réponse
337      data = response.json()
338      assert "success" in data
339      assert data["success"] is True
340      
341      # Vérifier que la tâche a bien été annulée
342      response = requests.get(f"{BASE_URL}/api/tasks/{test_data['batch_id']}", headers=headers)
343      assert response.status_code == 200
344      data = response.json()
345      assert data["status"] == "cancelled"
346  
347  def test_delete_task():
348      """Teste la suppression d'une tâche."""
349      # S'assurer qu'une clé API et un ID de tâche sont disponibles
350      assert test_data["api_key"] is not None, "Aucune clé API disponible pour le test"
351      assert test_data["task_id"] is not None, "Aucun ID de tâche disponible pour le test"
352      
353      # Configurer les headers avec la clé API
354      headers = get_auth_headers()
355      
356      # Envoyer la requête pour supprimer la tâche
357      response = requests.delete(f"{BASE_URL}/api/tasks/{test_data['task_id']}", headers=headers)
358      assert response.status_code == 200, f"Code de statut inattendu: {response.status_code}, {response.text}"
359      
360      # Vérifier la réponse
361      data = response.json()
362      assert "success" in data
363      assert data["success"] is True
364      
365      # Vérifier que la tâche a bien été supprimée
366      response = requests.get(f"{BASE_URL}/api/tasks/{test_data['task_id']}", headers=headers)
367      assert response.status_code == 404
368  
369  if __name__ == "__main__":
370      # Initialiser les tests
371      setup_module()
372      
373      # Exécuter les tests manuellement
374      test_inference_without_api_key()
375      test_start_inference()
376      test_get_inference_status()
377      test_wait_for_inference_completion()
378      test_custom_session()
379      
380      try:
381          test_start_batch_inference()
382          test_get_batch_status()
383      except Exception as e:
384          print(f"Les tests de traitement par lots ont échoué: {e}")
385      
386      test_list_tasks()
387      test_cancel_task()
388      test_delete_task()
389      
390      print("Tous les tests d'inférence ont réussi!")