# test_prompt_cache.py
"""Tests for the process-wide ``PromptCache`` singleton and ``PromptCacheKey``."""

import threading
import time

import pytest

from mlflow.prompt.registry_utils import PromptCache, PromptCacheKey


@pytest.fixture(autouse=True)
def reset_cache():
    """Reset the prompt cache before and after each test."""
    PromptCache._reset_instance()
    yield
    PromptCache._reset_instance()


def test_singleton_pattern():
    """Repeated ``get_instance`` calls return the same object."""
    cache1 = PromptCache.get_instance()
    cache2 = PromptCache.get_instance()
    assert cache1 is cache2


def test_set_and_get():
    """A value stored under a key is returned by ``get`` for that key."""
    cache = PromptCache.get_instance()
    key = PromptCacheKey.from_parts("test-prompt", version=1)
    cache.set(key, {"template": "Hello {{name}}"})
    assert cache.get(key) == {"template": "Hello {{name}}"}


def test_get_nonexistent():
    """``get`` on a key that was never set returns ``None``."""
    cache = PromptCache.get_instance()
    key = PromptCacheKey.from_parts("nonexistent", version=1)
    assert cache.get(key) is None


def test_ttl_expiration():
    """An entry is no longer returned once its per-entry TTL has elapsed."""
    cache = PromptCache.get_instance()
    key = PromptCacheKey.from_parts("test-prompt", version=1)
    cache.set(key, "value", ttl_seconds=0.01)
    # Sleep for twice the TTL so the entry is past its expiry when read.
    time.sleep(0.02)
    assert cache.get(key) is None


def test_delete_prompt():
    """Deleting one version leaves other versions and other prompts cached."""
    cache = PromptCache.get_instance()
    key1 = PromptCacheKey.from_parts("my-prompt", version=1)
    key2 = PromptCacheKey.from_parts("my-prompt", version=2)
    key3 = PromptCacheKey.from_parts("other-prompt", version=1)

    cache.set(key1, "value1")
    cache.set(key2, "value2")
    cache.set(key3, "value3")

    # Delete only version 1 of my-prompt
    cache.delete("my-prompt", version=1)

    assert cache.get(key1) is None
    assert cache.get(key2) == "value2"  # version 2 still cached
    assert cache.get(key3) == "value3"


def test_delete_prompt_by_alias():
    """Deleting one alias leaves other aliases of the same prompt cached."""
    cache = PromptCache.get_instance()
    key1 = PromptCacheKey.from_parts("my-prompt", alias="production")
    key2 = PromptCacheKey.from_parts("my-prompt", alias="staging")

    cache.set(key1, "value1")
    cache.set(key2, "value2")

    # Delete only the production alias
    cache.delete("my-prompt", alias="production")

    assert cache.get(key1) is None
    assert cache.get(key2) == "value2"  # staging still cached


def test_delete_all_prompt_entries():
    """``delete_all`` drops every version/alias entry for one prompt name only."""
    cache = PromptCache.get_instance()
    key1 = PromptCacheKey.from_parts("my-prompt", version=1)
    key2 = PromptCacheKey.from_parts("my-prompt", version=2)
    key3 = PromptCacheKey.from_parts("my-prompt", alias="latest")
    key4 = PromptCacheKey.from_parts("other-prompt", version=1)

    cache.set(key1, "value1")
    cache.set(key2, "value2")
    cache.set(key3, "value3")
    cache.set(key4, "value4")

    cache.delete_all("my-prompt")

    assert cache.get(key1) is None
    assert cache.get(key2) is None
    assert cache.get(key3) is None
    assert cache.get(key4) == "value4"


def test_clear():
    """``clear`` removes entries for all prompts."""
    cache = PromptCache.get_instance()
    key1 = PromptCacheKey.from_parts("prompt1", version=1)
    key2 = PromptCacheKey.from_parts("prompt2", version=1)

    cache.set(key1, "value1")
    cache.set(key2, "value2")
    cache.clear()

    assert cache.get(key1) is None
    assert cache.get(key2) is None


def test_generate_cache_key_with_version():
    """A version-only key carries the version and no alias."""
    key = PromptCacheKey.from_parts("my-prompt", version=1)
    assert key.name == "my-prompt"
    assert key.version == 1
    assert key.alias is None


def test_generate_cache_key_with_alias():
    """An alias-only key carries the alias and no version."""
    key = PromptCacheKey.from_parts("my-prompt", alias="production")
    assert key.name == "my-prompt"
    assert key.version is None
    assert key.alias == "production"


def test_generate_cache_key_with_neither():
    """A key built from the name alone has neither version nor alias."""
    key = PromptCacheKey.from_parts("my-prompt")
    assert key.name == "my-prompt"
    assert key.version is None
    assert key.alias is None


def test_generate_cache_key_with_both_raises_error():
    """Passing both a version and an alias is rejected with ``ValueError``."""
    with pytest.raises(ValueError, match="Cannot specify both version and alias"):
        PromptCacheKey.from_parts("my-prompt", version=1, alias="production")


def test_generate_cache_key_version_zero():
    """Version ``0`` is preserved rather than treated as "no version"."""
    key = PromptCacheKey.from_parts("my-prompt", version=0)
    assert key.name == "my-prompt"
    assert key.version == 0
    assert key.alias is None


def test_concurrent_get_instance():
    """Threads racing on the first ``get_instance`` call all see one singleton."""
    num_threads = 10
    instances = []
    errors = []
    # Release every thread at the same moment to maximise contention on
    # singleton creation; the timeout turns a stuck thread into a recorded
    # BrokenBarrierError instead of a hung test.
    start = threading.Barrier(num_threads)

    def get_instance():
        try:
            start.wait(timeout=10)
            instances.append(PromptCache.get_instance())
        except Exception as e:
            errors.append(e)

    threads = [
        threading.Thread(name=f"prompt-cache-singleton-{i}", target=get_instance)
        for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert errors == []
    # Every thread must have produced an instance; otherwise the identity check
    # below could pass vacuously on a short list.
    assert len(instances) == num_threads
    assert all(inst is instances[0] for inst in instances)


def test_concurrent_operations():
    """Interleaved ``set``/``get`` calls from several threads raise no errors."""
    cache = PromptCache.get_instance()
    errors = []

    def writer(thread_id):
        try:
            for i in range(50):
                key = PromptCacheKey.from_parts(f"prompt-{thread_id}-{i}", version=1)
                cache.set(key, f"value-{thread_id}-{i}")
        except Exception as e:
            errors.append(e)

    def reader(thread_id):
        try:
            for i in range(50):
                key = PromptCacheKey.from_parts(f"prompt-{thread_id}-{i}", version=1)
                cache.get(key)
        except Exception as e:
            errors.append(e)

    threads = []
    for i in range(5):
        threads.append(threading.Thread(name=f"prompt-cache-writer-{i}", target=writer, args=(i,)))
        threads.append(threading.Thread(name=f"prompt-cache-reader-{i}", target=reader, args=(i,)))

    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert errors == []


def test_set_uses_default_ttl():
    """Without an explicit TTL, a freshly set entry is immediately readable."""
    cache = PromptCache.get_instance()
    key = PromptCacheKey.from_parts("test", version=1)
    cache.set(key, "value")
    assert cache.get(key) == "value"