# tests/genai/utils/test_prompt_cache.py
  1  import threading
  2  import time
  3  
  4  import pytest
  5  
  6  from mlflow.prompt.registry_utils import PromptCache, PromptCacheKey
  7  
  8  
  9  @pytest.fixture(autouse=True)
 10  def reset_cache():
 11      """Reset the prompt cache before and after each test."""
 12      PromptCache._reset_instance()
 13      yield
 14      PromptCache._reset_instance()
 15  
 16  
 17  def test_singleton_pattern():
 18      cache1 = PromptCache.get_instance()
 19      cache2 = PromptCache.get_instance()
 20      assert cache1 is cache2
 21  
 22  
 23  def test_set_and_get():
 24      cache = PromptCache.get_instance()
 25      key = PromptCacheKey.from_parts("test-prompt", version=1)
 26      cache.set(key, {"template": "Hello {{name}}"})
 27      assert cache.get(key) == {"template": "Hello {{name}}"}
 28  
 29  
 30  def test_get_nonexistent():
 31      cache = PromptCache.get_instance()
 32      key = PromptCacheKey.from_parts("nonexistent", version=1)
 33      assert cache.get(key) is None
 34  
 35  
 36  def test_ttl_expiration():
 37      cache = PromptCache.get_instance()
 38      key = PromptCacheKey.from_parts("test-prompt", version=1)
 39      cache.set(key, "value", ttl_seconds=0.01)
 40      time.sleep(0.02)
 41      assert cache.get(key) is None
 42  
 43  
 44  def test_delete_prompt():
 45      cache = PromptCache.get_instance()
 46      key1 = PromptCacheKey.from_parts("my-prompt", version=1)
 47      key2 = PromptCacheKey.from_parts("my-prompt", version=2)
 48      key3 = PromptCacheKey.from_parts("other-prompt", version=1)
 49  
 50      cache.set(key1, "value1")
 51      cache.set(key2, "value2")
 52      cache.set(key3, "value3")
 53  
 54      # Delete only version 1 of my-prompt
 55      cache.delete("my-prompt", version=1)
 56  
 57      assert cache.get(key1) is None
 58      assert cache.get(key2) == "value2"  # version 2 still cached
 59      assert cache.get(key3) == "value3"
 60  
 61  
 62  def test_delete_prompt_by_alias():
 63      cache = PromptCache.get_instance()
 64      key1 = PromptCacheKey.from_parts("my-prompt", alias="production")
 65      key2 = PromptCacheKey.from_parts("my-prompt", alias="staging")
 66  
 67      cache.set(key1, "value1")
 68      cache.set(key2, "value2")
 69  
 70      # Delete only the production alias
 71      cache.delete("my-prompt", alias="production")
 72  
 73      assert cache.get(key1) is None
 74      assert cache.get(key2) == "value2"  # staging still cached
 75  
 76  
 77  def test_delete_all_prompt_entries():
 78      cache = PromptCache.get_instance()
 79      key1 = PromptCacheKey.from_parts("my-prompt", version=1)
 80      key2 = PromptCacheKey.from_parts("my-prompt", version=2)
 81      key3 = PromptCacheKey.from_parts("my-prompt", alias="latest")
 82      key4 = PromptCacheKey.from_parts("other-prompt", version=1)
 83  
 84      cache.set(key1, "value1")
 85      cache.set(key2, "value2")
 86      cache.set(key3, "value3")
 87      cache.set(key4, "value4")
 88  
 89      cache.delete_all("my-prompt")
 90  
 91      assert cache.get(key1) is None
 92      assert cache.get(key2) is None
 93      assert cache.get(key3) is None
 94      assert cache.get(key4) == "value4"
 95  
 96  
 97  def test_clear():
 98      cache = PromptCache.get_instance()
 99      key1 = PromptCacheKey.from_parts("prompt1", version=1)
100      key2 = PromptCacheKey.from_parts("prompt2", version=1)
101  
102      cache.set(key1, "value1")
103      cache.set(key2, "value2")
104      cache.clear()
105  
106      assert cache.get(key1) is None
107      assert cache.get(key2) is None
108  
109  
def test_generate_cache_key_with_version():
    # Only a version is supplied, so the alias field stays None.
    built = PromptCacheKey.from_parts("my-prompt", version=1)
    assert built.name == "my-prompt"
    assert built.version == 1
    assert built.alias is None
115  
116  
def test_generate_cache_key_with_alias():
    # Only an alias is supplied, so the version field stays None.
    built = PromptCacheKey.from_parts("my-prompt", alias="production")
    assert built.name == "my-prompt"
    assert built.version is None
    assert built.alias == "production"
122  
123  
def test_generate_cache_key_with_neither():
    # A bare name is accepted; version and alias are both None.
    built = PromptCacheKey.from_parts("my-prompt")
    assert built.name == "my-prompt"
    assert built.version is None
    assert built.alias is None
129  
130  
def test_generate_cache_key_with_both_raises_error():
    # Supplying version and alias together must be rejected with ValueError.
    expected_message = "Cannot specify both version and alias"
    with pytest.raises(ValueError, match=expected_message):
        PromptCacheKey.from_parts("my-prompt", version=1, alias="production")
134  
135  
def test_generate_cache_key_version_zero():
    # Version 0 is falsy but must still be kept as the version, not dropped.
    built = PromptCacheKey.from_parts("my-prompt", version=0)
    assert built.name == "my-prompt"
    assert built.version == 0
    assert built.alias is None
141  
142  
def test_concurrent_get_instance():
    # Ten threads request the singleton at the same moment; none may raise and
    # every one of them must receive the same object.
    num_threads = 10
    # Without a barrier each thread tends to finish before the next one is
    # started, so the get_instance() calls would rarely overlap.
    start_barrier = threading.Barrier(num_threads)
    instances = []
    errors = []

    def get_instance():
        try:
            # The timeout turns a stuck barrier into BrokenBarrierError, which
            # is recorded in `errors` instead of hanging the test run.
            start_barrier.wait(timeout=5)
            instances.append(PromptCache.get_instance())
        except Exception as e:
            errors.append(e)

    threads = [
        threading.Thread(name=f"prompt-cache-singleton-{i}", target=get_instance)
        for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert errors == []
    # all() over an empty list is True, so check the count explicitly.
    assert len(instances) == num_threads
    assert all(inst is instances[0] for inst in instances)
164  
165  
def test_concurrent_operations():
    # Five writer threads and five reader threads use the same cache at once;
    # the only expectation is that no set()/get() call raises.
    cache = PromptCache.get_instance()
    errors = []

    def make_key(thread_id, i):
        return PromptCacheKey.from_parts(f"prompt-{thread_id}-{i}", version=1)

    def writer(thread_id):
        try:
            for i in range(50):
                cache.set(make_key(thread_id, i), f"value-{thread_id}-{i}")
        except Exception as e:
            errors.append(e)

    def reader(thread_id):
        try:
            for i in range(50):
                cache.get(make_key(thread_id, i))
        except Exception as e:
            errors.append(e)

    # Writer and reader threads are interleaved: writer-0, reader-0, writer-1, ...
    threads = []
    for n in range(5):
        write_thread = threading.Thread(name=f"prompt-cache-writer-{n}", target=writer, args=(n,))
        read_thread = threading.Thread(name=f"prompt-cache-reader-{n}", target=reader, args=(n,))
        threads += [write_thread, read_thread]

    for worker in threads:
        worker.start()
    for worker in threads:
        worker.join()

    assert not errors
197  
198  
def test_set_uses_default_ttl():
    # With no ttl_seconds argument the entry is readable right after set().
    store = PromptCache.get_instance()
    entry_key = PromptCacheKey.from_parts("test", version=1)
    store.set(entry_key, "value")
    fetched = store.get(entry_key)
    assert fetched == "value"