# tests/llama_index/conftest.py
  1  import os
  2  import re
  3  import shutil
  4  import sys
  5  
  6  import pytest
  7  from llama_index.core import (
  8      Document,
  9      KnowledgeGraphIndex,
 10      PromptTemplate,
 11      Settings,
 12      VectorStoreIndex,
 13  )
 14  from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
 15  from llama_index.core.node_parser import SentenceSplitter
 16  from llama_index.embeddings.openai import OpenAIEmbedding
 17  from llama_index.llms.openai import OpenAI
 18  
 19  from mlflow.tracing.provider import trace_disabled
 20  
 21  from tests.helper_functions import start_mock_openai_server
 22  
 23  
 24  #### General ####
 25  @pytest.fixture
 26  def model_path(tmp_path):
 27      model_path = tmp_path.joinpath("model")
 28      yield model_path
 29  
 30      if os.environ.get("GITHUB_ACTIONS") == "true":
 31          shutil.rmtree(model_path, ignore_errors=True)
 32  
 33  
 34  @pytest.fixture(scope="module")
 35  def spark():
 36      from pyspark.sql import SparkSession
 37  
 38      # NB: ensure that the driver and workers have the same python version
 39      os.environ["PYSPARK_PYTHON"] = sys.executable
 40      os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
 41  
 42      with SparkSession.builder.master("local[*]").getOrCreate() as s:
 43          yield s
 44  
 45  
 46  @pytest.fixture(scope="module", autouse=True)
 47  def mock_openai():
 48      with start_mock_openai_server() as base_url:
 49          yield base_url
 50  
 51  
 52  #### Settings ####
 53  def _mock_tokenizer(text: str) -> list[str]:
 54      """Mock tokenizer."""
 55      tokens = re.split(r"[ \n]", text)
 56      result = []
 57      for token in tokens:
 58          if token.strip() == "":
 59              continue
 60          result.append(token.strip())
 61      return result
 62  
 63  
 64  @pytest.fixture(autouse=True)
 65  def settings(monkeypatch, mock_openai):
 66      """Set the LLM and Embedding model to the mock OpenAI server."""
 67      monkeypatch.setenv("OPENAI_API_KEY", "test")
 68      monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
 69      monkeypatch.setattr(Settings, "llm", OpenAI())
 70      monkeypatch.setattr(Settings, "embed_model", OpenAIEmbedding())
 71      monkeypatch.setattr(Settings, "callback_manager", CallbackManager([LlamaDebugHandler()]))
 72      monkeypatch.setattr(Settings, "_tokenizer", _mock_tokenizer)  # must bypass setter
 73      monkeypatch.setattr(Settings, "context_window", 4096)  # this enters the _prompt_helper field
 74      monkeypatch.setattr(Settings, "node_parser", SentenceSplitter(chunk_size=1024))
 75      monkeypatch.setattr(Settings, "transformations", [SentenceSplitter(chunk_size=1024)])
 76  
 77      assert all(Settings.__dict__.values())  # ensure the full object is populated
 78  
 79      return Settings
 80  
 81  
 82  #### Indexes ####
@pytest.fixture
def document():
    """Return llama_index's built-in example Document for indexing tests."""
    return Document.example()
 86  
 87  
@pytest.fixture
@trace_disabled
def single_index(document):
    """Build a VectorStoreIndex over one example document.

    Tracing is disabled during construction so the embedding calls made
    while building the index do not emit MLflow traces.
    """
    return VectorStoreIndex(nodes=[document])
 92  
 93  
@pytest.fixture
@trace_disabled
def multi_index(document):
    """Build a VectorStoreIndex over five copies of the example document.

    Tracing is disabled during construction so the embedding calls made
    while building the index do not emit MLflow traces.
    """
    return VectorStoreIndex(nodes=[document] * 5)
 98  
 99  
@pytest.fixture
def single_graph(document):
    """Build a KnowledgeGraphIndex from the single example document.

    NOTE(review): unlike the vector-index fixtures above, tracing is not
    disabled here — confirm whether that asymmetry is intentional.
    """
    return KnowledgeGraphIndex.from_documents([document])
103  
104  
105  #### Prompt Templates ####
@pytest.fixture
def qa_prompt_template():
    """Return a QA PromptTemplate with context/tone/query placeholders.

    The template exposes three variables: ``context_str``, ``tone_name``,
    and ``query_str``.
    """
    template_text = """
    Context information is below.
    ---------------------
    {context_str}
    ---------------------
    Given the context information and not prior knowledge, answer the query.
    Please write the answer in the style of {tone_name}
    Query: {query_str}
    Answer:
    """
    return PromptTemplate(template=template_text)