# tests/langchain/conftest.py
 1  import importlib
 2  from unittest import mock
 3  
 4  import openai
 5  import pytest
 6  from langchain.embeddings.base import Embeddings
 7  from pydantic import BaseModel
 8  
 9  from tests.helper_functions import start_mock_openai_server
10  from tests.tracing.helper import reset_autolog_state  # noqa: F401
11  
12  
@pytest.fixture(autouse=True)
def set_envs(monkeypatch, mock_openai):
    """Point OpenAI/SerpAPI env vars at test values for every test."""
    test_env = {
        "OPENAI_API_KEY": "test",
        "OPENAI_API_BASE": mock_openai,
        "SERPAPI_API_KEY": "test",
    }
    for key, value in test_env.items():
        monkeypatch.setenv(key, value)
    # Reload the openai module after changing the env — presumably so it
    # re-reads configuration it captured at import time (verify if upgraded).
    importlib.reload(openai)
19  
20  
@pytest.fixture(scope="module", autouse=True)
def mock_openai():
    """Start one shared mock OpenAI server per test module and yield its URL."""
    with start_mock_openai_server() as url:
        yield url
25  
26  
@pytest.fixture(autouse=True)
def reset_autolog(reset_autolog_state):
    """Reset autologging state around every LangChain test.

    The actual work happens inside the requested ``reset_autolog_state``
    fixture; this wrapper only exists to make it autouse for this package.
    """
31  
32  
@pytest.fixture(autouse=True)
def mock_init_auth():
    """Stub out Databricks SDK auth so no test contacts a real workspace."""

    def _fake_init_auth(config_instance):
        # Supply a dummy host and an empty header factory instead of
        # running real credential resolution.
        config_instance.host = "https://databricks.com/"
        config_instance._header_factory = lambda: {}

    patcher = mock.patch("databricks.sdk.config.Config.init_auth", new=_fake_init_auth)
    with patcher:
        yield
41  
42  
43  # Define a special embedding for testing
# Define a special embedding for testing
class DeterministicDummyEmbeddings(Embeddings, BaseModel):
    """Test embedding that maps each text to a reproducible pseudo-random vector.

    Within one interpreter process the same text always yields the same
    vector. NOTE(review): the seed comes from ``hash(text)``, which is
    randomized across interpreter runs unless PYTHONHASHSEED is fixed, so
    vectors are only deterministic *within* a process.
    """

    # Dimensionality of every embedding vector produced.
    size: int

    def _get_embedding(self, text: str) -> list[float]:
        import numpy as np

        seed = abs(hash(text)) % (10**8)
        # Use a dedicated RandomState instead of np.random.seed() so we do
        # not clobber numpy's global RNG state for other code under test.
        # RandomState(seed).normal(...) produces exactly the same values as
        # np.random.seed(seed); np.random.normal(...), since the global
        # np.random functions delegate to a module-level RandomState.
        return list(np.random.RandomState(seed).normal(size=self.size))

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each document independently via `_get_embedding`."""
        return [self._get_embedding(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        """Embed a query exactly like a document."""
        return self._get_embedding(text)