tests/examples/all_llm_provider/test_all_llm_provider.py
test_all_llm_provider.py
 1  import os
 2  import pytest
 3  import sys
 4  sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
 5  
 6  from examples.test_utils.get_trace_data import (
 7      run_command,
 8      extract_information,
 9      load_trace_data
10  )
11  
12  from examples.test_utils.get_components import (
13      get_component_structure_and_sequence
14  )
15  
# Matrix of (provider, model, async_mode) cases. async_mode is forwarded to the
# example script's ``--async_llm`` flag. Disabled providers are kept as skipped
# ``pytest.param`` entries (rather than commented-out code) so they remain
# visible in test reports.
@pytest.mark.parametrize("provider, model, async_mode", [
    # OpenAI
    ("openai", "gpt-4o-mini", True),
    ("openai", "gpt-4o-mini", False),

    # Anthropic (disabled)
    pytest.param("anthropic", "claude-3-opus-20240229", True,
                 marks=pytest.mark.skip(reason="provider disabled in this test matrix")),
    pytest.param("anthropic", "claude-3-opus-20240229", False,
                 marks=pytest.mark.skip(reason="provider disabled in this test matrix")),

    # Groq (disabled)
    pytest.param("groq", "llama3-8b-8192", True,
                 marks=pytest.mark.skip(reason="provider disabled in this test matrix")),
    pytest.param("groq", "llama3-8b-8192", False,
                 marks=pytest.mark.skip(reason="provider disabled in this test matrix")),

    # LiteLLM
    ("litellm", "gpt-4o-mini", True),
    ("litellm", "gpt-4o-mini", False),

    # Azure
    ("azure", "azure-gpt-4o-mini", True),
    ("azure", "azure-gpt-4o-mini", False),

    # Google
    ("google", "gemini-1.5-flash", True),
    ("google", "gemini-1.5-flash", False),

    # Chat Google
    ("chat_google", "gemini-1.5-flash", True),
    ("chat_google", "gemini-1.5-flash", False),
])
def test_all_llm_provider(provider: str, model: str, async_mode: bool):
    """Run ``all_llm_provider.py`` for one provider/model and validate its trace.

    The example script is executed via ``run_command`` from this test file's
    directory. The trace file location is extracted from the command output,
    the trace data is loaded, and the test asserts that the resulting
    component sequence contains exactly one entry.

    Args:
        provider: Value passed to the script's ``--provider`` option.
        model: Value passed to the script's ``--model`` option.
        async_mode: Value passed to ``--async_llm``; interpolated into the
            command line as the text ``True`` or ``False``.
    """
    # Use the interpreter running pytest instead of whatever ``python`` is on
    # PATH, so the script runs in the same environment with the same packages.
    command = (
        f"{sys.executable} all_llm_provider.py "
        f"--model {model} --provider {provider} --async_llm {async_mode}"
    )
    cwd = os.path.dirname(os.path.abspath(__file__))  # directory of this test file
    output = run_command(command, cwd=cwd)

    # Extract trace file location from the command output
    locations = extract_information(output)

    # Load the trace data from the extracted location(s)
    data = load_trace_data(locations)

    # Get component structure and sequence
    component_sequence = get_component_structure_and_sequence(data)

    # Print component sequence (visible with ``pytest -s`` or on failure)
    print("Component sequence:", component_sequence)

    # Validate component sequence
    assert len(component_sequence) == 1, f"Expected 1 component, got {len(component_sequence)}"
65      assert len(component_sequence) == 1, f"Expected 1 component, got {len(component_sequence)}"
66  
67  
68