/ tests / test_helpers.py
test_helpers.py
  1  import os
  2  import secrets
  3  
  4  import numpy as np
  5  import psutil
  6  import pytest
  7  from scipy.stats import linregress
  8  
  9  from khoj.processor.embeddings import EmbeddingsModel
 10  from khoj.processor.tools.online_search import (
 11      read_webpage_at_url,
 12      read_webpage_with_olostep,
 13  )
 14  from khoj.utils import helpers
 15  
 16  
 17  def test_get_from_null_dict():
 18      # null handling
 19      assert helpers.get_from_dict(dict()) == dict()
 20      assert helpers.get_from_dict(dict(), None) == None
 21  
 22      # key present in nested dictionary
 23      # 1-level dictionary
 24      assert helpers.get_from_dict({"a": 1, "b": 2}, "a") == 1
 25      assert helpers.get_from_dict({"a": 1, "b": 2}, "c") == None
 26  
 27      # 2-level dictionary
 28      assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a") == {"a_a": 1}
 29      assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "a", "a_a") == 1
 30  
 31      # key not present in nested dictionary
 32      # 2-level_dictionary
 33      assert helpers.get_from_dict({"a": {"a_a": 1}, "b": 2}, "b", "b_a") == None
 34  
 35  
 36  def test_merge_dicts():
 37      # basic merge of dicts with non-overlapping keys
 38      assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"b": 2}) == {"a": 1, "b": 2}
 39  
 40      # use default dict items when not present in priority dict
 41      assert helpers.merge_dicts(priority_dict={}, default_dict={"b": 2}) == {"b": 2}
 42  
 43      # do not override existing key in priority_dict with default dict
 44      assert helpers.merge_dicts(priority_dict={"a": 1}, default_dict={"a": 2}) == {"a": 1}
 45  
 46  
 47  def test_lru_cache():
 48      # Test initializing cache
 49      cache = helpers.LRU({"a": 1, "b": 2}, capacity=2)
 50      assert cache == {"a": 1, "b": 2}
 51  
 52      # Test capacity overflow
 53      cache["c"] = 3
 54      assert cache == {"b": 2, "c": 3}
 55  
 56      # Test delete least recently used item from LRU cache on capacity overflow
 57      cache["b"]  # accessing 'b' makes it the most recently used item
 58      cache["d"] = 4  # so 'c' is deleted from the cache instead of 'b'
 59      assert cache == {"b": 2, "d": 4}
 60  
 61  
 62  @pytest.mark.skip(reason="Memory leak exists on GPU, MPS devices")
 63  def test_encode_docs_memory_leak():
 64      # Arrange
 65      iterations = 50
 66      batch_size = 20
 67      embeddings_model = EmbeddingsModel()
 68      memory_usage_trend = []
 69      device = f"{helpers.get_device()}".upper()
 70  
 71      # Act
 72      # Encode random strings repeatedly and record memory usage trend
 73      for iteration in range(iterations):
 74          random_docs = [" ".join(secrets.token_hex(5) for _ in range(10)) for _ in range(batch_size)]
 75          a = [embeddings_model.embed_documents(random_docs)]
 76          memory_usage_trend += [psutil.Process().memory_info().rss / (1024 * 1024)]
 77          print(f"{iteration:02d}, {memory_usage_trend[-1]:.2f}", flush=True)
 78  
 79      # Calculate slope of line fitting memory usage history
 80      memory_usage_trend = np.array(memory_usage_trend)
 81      slope, _, _, _, _ = linregress(np.arange(len(memory_usage_trend)), memory_usage_trend)
 82      print(f"Memory usage increased at ~{slope:.2f} MB per iteration on {device}")
 83  
 84      # Assert
 85      # If slope is positive memory utilization is increasing
 86      # Positive threshold of 2, from observing memory usage trend on MPS vs CPU device
 87      assert slope < 2, f"Memory leak suspected on {device}. Memory usage increased at ~{slope:.2f} MB per iteration"
 88  
 89  
 90  @pytest.mark.asyncio
 91  async def test_reading_webpage():
 92      # Arrange
 93      website = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
 94  
 95      # Act
 96      response = await read_webpage_at_url(website)
 97  
 98      # Assert
 99      assert (
100          "An alarm sent from the area near the fire also failed to register at the courthouse where the fire watchmen were"
101          in response
102      )
103  
104  
@pytest.mark.skipif(os.getenv("OLOSTEP_API_KEY") is None, reason="OLOSTEP_API_KEY is not set")
@pytest.mark.asyncio
async def test_reading_webpage_with_olostep():
    """Read a Wikipedia article with `read_webpage_with_olostep` and check a known sentence is in the result.

    Skipped when the OLOSTEP_API_KEY environment variable is not set.
    """
    # Arrange
    url = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
    expected_excerpt = (
        "An alarm sent from the area near the fire also failed to register at the courthouse where the fire watchmen were"
    )

    # Act
    page_content = await read_webpage_with_olostep(url)

    # Assert
    assert expected_excerpt in page_content