# conftest.py
import os
import re
import shutil
import sys

import pytest
from llama_index.core import (
    Document,
    KnowledgeGraphIndex,
    PromptTemplate,
    Settings,
    VectorStoreIndex,
)
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from mlflow.tracing.provider import trace_disabled

from tests.helper_functions import start_mock_openai_server


#### General ####
@pytest.fixture
def model_path(tmp_path):
    """Yield a per-test path for saving models; eagerly reclaim disk on CI."""
    path = tmp_path.joinpath("model")
    yield path

    # GitHub Actions runners have limited disk space, so remove the saved
    # model immediately rather than waiting for tmp_path garbage collection.
    if os.environ.get("GITHUB_ACTIONS") == "true":
        shutil.rmtree(path, ignore_errors=True)


@pytest.fixture(scope="module")
def spark():
    """Provide a module-scoped local SparkSession, stopped on teardown."""
    from pyspark.sql import SparkSession

    # NB: ensure that the driver and workers have the same python version
    os.environ["PYSPARK_PYTHON"] = sys.executable
    os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

    with SparkSession.builder.master("local[*]").getOrCreate() as s:
        yield s


@pytest.fixture(scope="module", autouse=True)
def mock_openai():
    """Run a mock OpenAI-compatible server for the whole module; yield its base URL."""
    with start_mock_openai_server() as base_url:
        yield base_url


#### Settings ####
def _mock_tokenizer(text: str) -> list[str]:
    """Mock tokenizer: split on spaces/newlines and drop empty tokens."""
    return [token.strip() for token in re.split(r"[ \n]", text) if token.strip()]


@pytest.fixture(autouse=True)
def settings(monkeypatch, mock_openai):
    """Set the LLM and Embedding model to the mock OpenAI server."""
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
    monkeypatch.setattr(Settings, "llm", OpenAI())
    monkeypatch.setattr(Settings, "embed_model", OpenAIEmbedding())
    monkeypatch.setattr(Settings, "callback_manager", CallbackManager([LlamaDebugHandler()]))
    monkeypatch.setattr(Settings, "_tokenizer", _mock_tokenizer)  # must bypass setter
    monkeypatch.setattr(Settings, "context_window", 4096)  # this enters the _prompt_helper field
    monkeypatch.setattr(Settings, "node_parser", SentenceSplitter(chunk_size=1024))
    monkeypatch.setattr(Settings, "transformations", [SentenceSplitter(chunk_size=1024)])

    assert all(Settings.__dict__.values())  # ensure the full object is populated

    return Settings


#### Indexes ####
@pytest.fixture
def document():
    """Return the canonical llama-index example Document."""
    return Document.example()


@pytest.fixture
@trace_disabled
def single_index(document):
    """A vector index over a single document node (tracing disabled during build)."""
    return VectorStoreIndex(nodes=[document])


@pytest.fixture
@trace_disabled
def multi_index(document):
    """A vector index over five copies of the example document."""
    return VectorStoreIndex(nodes=[document] * 5)


@pytest.fixture
def single_graph(document):
    """A knowledge-graph index built from the example document."""
    return KnowledgeGraphIndex.from_documents([document])


#### Prompt Templates #####
@pytest.fixture
def qa_prompt_template():
    """A QA prompt template with context_str/tone_name/query_str placeholders."""
    return PromptTemplate(
        template="""
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Please write the answer in the style of {tone_name}
Query: {query_str}
Answer:
"""
    )