test_model_metadata_local_ctx.py
1 """Tests for _query_local_context_length and the local server fallback in 2 get_model_context_length. 3 4 All tests use synthetic inputs — no filesystem or live server required. 5 """ 6 7 import sys 8 import os 9 import json 10 from unittest.mock import MagicMock, patch 11 12 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) 13 14 import pytest 15 16 17 # --------------------------------------------------------------------------- 18 # _query_local_context_length — unit tests with mocked httpx 19 # --------------------------------------------------------------------------- 20 21 class TestQueryLocalContextLengthOllama: 22 """_query_local_context_length with server_type == 'ollama'.""" 23 24 def _make_resp(self, status_code, body): 25 resp = MagicMock() 26 resp.status_code = status_code 27 resp.json.return_value = body 28 return resp 29 30 def test_ollama_model_info_context_length(self): 31 """Reads context length from model_info dict in /api/show response.""" 32 from agent.model_metadata import _query_local_context_length 33 34 show_resp = self._make_resp(200, { 35 "model_info": {"llama.context_length": 131072} 36 }) 37 models_resp = self._make_resp(404, {}) 38 39 client_mock = MagicMock() 40 client_mock.__enter__ = lambda s: client_mock 41 client_mock.__exit__ = MagicMock(return_value=False) 42 client_mock.post.return_value = show_resp 43 client_mock.get.return_value = models_resp 44 45 with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ 46 patch("httpx.Client", return_value=client_mock): 47 result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1") 48 49 assert result == 131072 50 51 def test_ollama_parameters_num_ctx(self): 52 """Falls back to num_ctx in parameters string when model_info lacks context_length.""" 53 from agent.model_metadata import _query_local_context_length 54 55 show_resp = self._make_resp(200, { 56 "model_info": {}, 57 "parameters": "num_ctx 32768\ntemperature 0.7\n" 58 }) 59 models_resp = self._make_resp(404, {}) 60 61 client_mock = MagicMock() 62 client_mock.__enter__ = lambda s: client_mock 63 client_mock.__exit__ = MagicMock(return_value=False) 64 client_mock.post.return_value = show_resp 65 client_mock.get.return_value = models_resp 66 67 with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ 68 patch("httpx.Client", return_value=client_mock): 69 result = _query_local_context_length("some-model", "http://localhost:11434/v1") 70 71 assert result == 32768 72 73 def test_ollama_num_ctx_wins_over_model_info(self): 74 """When both num_ctx (Modelfile) and model_info (GGUF) are present, 75 num_ctx wins because it's the *runtime* context Ollama actually 76 allocates KV cache for. The GGUF model_info.context_length is the 77 training max — using it would let Hermes grow conversations past 78 the runtime limit and Ollama would silently truncate. 79 80 Concrete example: hermes-brain:qwen3-14b-ctx32k is a Modelfile 81 derived from qwen3:14b with `num_ctx 32768`, but the underlying 82 GGUF reports `qwen3.context_length: 40960` (training max). If 83 Hermes used 40960 it would let the conversation grow past 32768 84 before compressing, and Ollama would truncate the prefix. 85 """ 86 from agent.model_metadata import _query_local_context_length 87 88 show_resp = self._make_resp(200, { 89 "model_info": {"qwen3.context_length": 40960}, 90 "parameters": "num_ctx 32768\ntemperature 0.6\n", 91 }) 92 models_resp = self._make_resp(404, {}) 93 94 client_mock = MagicMock() 95 client_mock.__enter__ = lambda s: client_mock 96 client_mock.__exit__ = MagicMock(return_value=False) 97 client_mock.post.return_value = show_resp 98 client_mock.get.return_value = models_resp 99 100 with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ 101 patch("httpx.Client", return_value=client_mock): 102 result = _query_local_context_length( 103 "hermes-brain:qwen3-14b-ctx32k", "http://100.77.243.5:11434/v1" 104 ) 105 106 assert result == 32768, ( 107 f"Expected num_ctx (32768) to win over model_info (40960), got {result}. " 108 "If Hermes uses the GGUF training max, conversations will silently truncate." 109 ) 110 111 def test_ollama_show_404_falls_through(self): 112 """When /api/show returns 404, falls through to /v1/models/{model}.""" 113 from agent.model_metadata import _query_local_context_length 114 115 show_resp = self._make_resp(404, {}) 116 model_detail_resp = self._make_resp(200, {"max_model_len": 65536}) 117 118 client_mock = MagicMock() 119 client_mock.__enter__ = lambda s: client_mock 120 client_mock.__exit__ = MagicMock(return_value=False) 121 client_mock.post.return_value = show_resp 122 client_mock.get.return_value = model_detail_resp 123 124 with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ 125 patch("httpx.Client", return_value=client_mock): 126 result = _query_local_context_length("some-model", "http://localhost:11434/v1") 127 128 assert result == 65536 129 130 131 class TestQueryLocalContextLengthVllm: 132 """_query_local_context_length with vLLM-style /v1/models/{model} response.""" 133 134 def _make_resp(self, status_code, body): 135 resp = MagicMock() 136 resp.status_code = status_code 137 resp.json.return_value = body 138 return resp 139 140 def test_vllm_max_model_len(self): 141 """Reads max_model_len from /v1/models/{model} response.""" 142 from agent.model_metadata import _query_local_context_length 143 144 detail_resp = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000}) 145 list_resp = self._make_resp(404, {}) 146 147 client_mock = MagicMock() 148 client_mock.__enter__ = lambda s: client_mock 149 client_mock.__exit__ = MagicMock(return_value=False) 150 client_mock.post.return_value = self._make_resp(404, {}) 151 client_mock.get.return_value = detail_resp 152 153 with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \ 154 patch("httpx.Client", return_value=client_mock): 155 result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1") 156 157 assert result == 100000 158 159 def test_vllm_context_length_key(self): 160 """Reads context_length from /v1/models/{model} response.""" 161 from agent.model_metadata import _query_local_context_length 162 163 detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768}) 164 165 client_mock = MagicMock() 166 client_mock.__enter__ = lambda s: client_mock 167 client_mock.__exit__ = MagicMock(return_value=False) 168 client_mock.post.return_value = self._make_resp(404, {}) 169 client_mock.get.return_value = detail_resp 170 171 with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \ 172 patch("httpx.Client", return_value=client_mock): 173 result = _query_local_context_length("some-model", "http://localhost:8000/v1") 174 175 assert result == 32768 176 177 178 class TestQueryLocalContextLengthModelsList: 179 """_query_local_context_length: falls back to /v1/models list.""" 180 181 def _make_resp(self, status_code, body): 182 resp = MagicMock() 183 resp.status_code = status_code 184 resp.json.return_value = body 185 return resp 186 187 def test_models_list_max_model_len(self): 188 """Finds context length for model in /v1/models list.""" 189 from agent.model_metadata import _query_local_context_length 190 191 detail_resp = self._make_resp(404, {}) 192 list_resp = self._make_resp(200, { 193 "data": [ 194 {"id": "other-model", "max_model_len": 4096}, 195 {"id": "omnicoder-9b", "max_model_len": 131072}, 196 ] 197 }) 198 199 call_count = [0] 200 def side_effect(url, **kwargs): 201 call_count[0] += 1 202 if call_count[0] == 1: 203 return detail_resp # /v1/models/omnicoder-9b 204 return list_resp # /v1/models 205 206 client_mock = MagicMock() 207 client_mock.__enter__ = lambda s: client_mock 208 client_mock.__exit__ = MagicMock(return_value=False) 209 client_mock.post.return_value = self._make_resp(404, {}) 210 client_mock.get.side_effect = side_effect 211 212 with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ 213 patch("httpx.Client", return_value=client_mock): 214 result = _query_local_context_length("omnicoder-9b", "http://localhost:1234") 215 216 assert result == 131072 217 218 def test_models_list_model_not_found_returns_none(self): 219 """Returns None when model is not in the /v1/models list.""" 220 from agent.model_metadata import _query_local_context_length 221 222 detail_resp = self._make_resp(404, {}) 223 list_resp = self._make_resp(200, { 224 "data": [{"id": "other-model", "max_model_len": 4096}] 225 }) 226 227 call_count = [0] 228 def side_effect(url, **kwargs): 229 call_count[0] += 1 230 if call_count[0] == 1: 231 return detail_resp 232 return list_resp 233 234 client_mock = MagicMock() 235 client_mock.__enter__ = lambda s: client_mock 236 client_mock.__exit__ = MagicMock(return_value=False) 237 client_mock.post.return_value = self._make_resp(404, {}) 238 client_mock.get.side_effect = side_effect 239 240 with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ 241 patch("httpx.Client", return_value=client_mock): 242 result = _query_local_context_length("omnicoder-9b", "http://localhost:1234") 243 244 assert result is None 245 246 247 class TestQueryLocalContextLengthLmStudio: 248 """_query_local_context_length with LM Studio native /api/v1/models response.""" 249 250 def _make_resp(self, status_code, body): 251 resp = MagicMock() 252 resp.status_code = status_code 253 resp.json.return_value = body 254 return resp 255 256 def _make_client(self, native_resp, detail_resp, list_resp): 257 """Build a mock httpx.Client with sequenced GET responses.""" 258 client_mock = MagicMock() 259 client_mock.__enter__ = lambda s: client_mock 260 client_mock.__exit__ = MagicMock(return_value=False) 261 client_mock.post.return_value = self._make_resp(404, {}) 262 263 responses = [native_resp, detail_resp, list_resp] 264 call_idx = [0] 265 266 def get_side_effect(url, **kwargs): 267 idx = call_idx[0] 268 call_idx[0] += 1 269 if idx < len(responses): 270 return responses[idx] 271 return self._make_resp(404, {}) 272 273 client_mock.get.side_effect = get_side_effect 274 return client_mock 275 276 def test_lmstudio_exact_key_match(self): 277 """Resolves loaded ctx when key matches exactly.""" 278 from agent.model_metadata import _query_local_context_length 279 280 native_resp = self._make_resp(200, { 281 "models": [ 282 {"key": "nvidia/nvidia-nemotron-super-49b-v1", 283 "id": "nvidia/nvidia-nemotron-super-49b-v1", 284 "max_context_length": 1_048_576, 285 "loaded_instances": [{"config": {"context_length": 131072}}]}, 286 ] 287 }) 288 client_mock = self._make_client( 289 native_resp, 290 self._make_resp(404, {}), 291 self._make_resp(404, {}), 292 ) 293 294 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 295 patch("httpx.Client", return_value=client_mock): 296 result = _query_local_context_length( 297 "nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" 298 ) 299 300 assert result == 131072 301 302 def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self): 303 """Fuzzy match: bare model slug matches key that includes publisher prefix. 304 305 When the user configures the model as "local:nvidia-nemotron-super-49b-v1" 306 (slug only, no publisher), but LM Studio's native API stores it as 307 "nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed. 308 """ 309 from agent.model_metadata import _query_local_context_length 310 311 native_resp = self._make_resp(200, { 312 "models": [ 313 {"key": "nvidia/nvidia-nemotron-super-49b-v1", 314 "id": "nvidia/nvidia-nemotron-super-49b-v1", 315 "max_context_length": 1_048_576, 316 "loaded_instances": [{"config": {"context_length": 131072}}]}, 317 ] 318 }) 319 client_mock = self._make_client( 320 native_resp, 321 self._make_resp(404, {}), 322 self._make_resp(404, {}), 323 ) 324 325 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 326 patch("httpx.Client", return_value=client_mock): 327 # Model passed in is just the slug after stripping "local:" prefix 328 result = _query_local_context_length( 329 "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" 330 ) 331 332 assert result == 131072 333 334 def test_lmstudio_v1_models_list_slug_fuzzy_match(self): 335 """Fuzzy match also works for /v1/models list when exact match fails. 336 337 LM Studio's OpenAI-compat /v1/models returns id like 338 "nvidia/nvidia-nemotron-super-49b-v1" — must match bare slug. 339 """ 340 from agent.model_metadata import _query_local_context_length 341 342 # native /api/v1/models: no match 343 native_resp = self._make_resp(404, {}) 344 # /v1/models/{model}: no match 345 detail_resp = self._make_resp(404, {}) 346 # /v1/models list: model found with publisher prefix, includes context_length 347 list_resp = self._make_resp(200, { 348 "data": [ 349 {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072}, 350 ] 351 }) 352 client_mock = self._make_client(native_resp, detail_resp, list_resp) 353 354 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 355 patch("httpx.Client", return_value=client_mock): 356 result = _query_local_context_length( 357 "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" 358 ) 359 360 assert result == 131072 361 362 def test_lmstudio_loaded_instances_context_length(self): 363 """Reads active context_length from loaded_instances when max_context_length absent.""" 364 from agent.model_metadata import _query_local_context_length 365 366 native_resp = self._make_resp(200, { 367 "models": [ 368 { 369 "key": "nvidia/nvidia-nemotron-super-49b-v1", 370 "id": "nvidia/nvidia-nemotron-super-49b-v1", 371 "loaded_instances": [ 372 {"config": {"context_length": 65536}}, 373 ], 374 }, 375 ] 376 }) 377 client_mock = self._make_client( 378 native_resp, 379 self._make_resp(404, {}), 380 self._make_resp(404, {}), 381 ) 382 383 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 384 patch("httpx.Client", return_value=client_mock): 385 result = _query_local_context_length( 386 "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" 387 ) 388 389 assert result == 65536 390 391 def test_lmstudio_loaded_instance_beats_max_context_length(self): 392 """loaded_instances context_length takes priority over max_context_length. 393 394 LM Studio may show max_context_length=1_048_576 (theoretical model max) 395 while the actual loaded context is 122_651 (runtime setting). The loaded 396 value is the real constraint and must be preferred. 397 """ 398 from agent.model_metadata import _query_local_context_length 399 400 native_resp = self._make_resp(200, { 401 "models": [ 402 { 403 "key": "nvidia/nvidia-nemotron-3-nano-4b", 404 "id": "nvidia/nvidia-nemotron-3-nano-4b", 405 "max_context_length": 1_048_576, 406 "loaded_instances": [ 407 {"config": {"context_length": 122_651}}, 408 ], 409 }, 410 ] 411 }) 412 client_mock = self._make_client( 413 native_resp, 414 self._make_resp(404, {}), 415 self._make_resp(404, {}), 416 ) 417 418 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 419 patch("httpx.Client", return_value=client_mock): 420 result = _query_local_context_length( 421 "nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1" 422 ) 423 424 assert result == 122_651, ( 425 f"Expected loaded instance context (122651) but got {result}. " 426 "max_context_length (1048576) must not win over loaded_instances." 427 ) 428 429 430 class TestDetectLocalServerTypeAuth: 431 def test_passes_bearer_token_to_probe_requests(self): 432 from agent.model_metadata import detect_local_server_type 433 434 resp = MagicMock() 435 resp.status_code = 200 436 437 client_mock = MagicMock() 438 client_mock.__enter__ = lambda s: client_mock 439 client_mock.__exit__ = MagicMock(return_value=False) 440 client_mock.get.return_value = resp 441 442 with patch("httpx.Client", return_value=client_mock) as mock_client: 443 result = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token") 444 445 assert result == "lm-studio" 446 assert mock_client.call_args.kwargs["headers"] == { 447 "Authorization": "Bearer lm-token" 448 } 449 450 451 class TestFetchEndpointModelMetadataLmStudio: 452 """fetch_endpoint_model_metadata should use LM Studio's native models endpoint.""" 453 454 def _make_resp(self, body): 455 resp = MagicMock() 456 resp.raise_for_status.return_value = None 457 resp.json.return_value = body 458 return resp 459 460 def test_uses_native_models_endpoint_only(self): 461 from agent.model_metadata import fetch_endpoint_model_metadata 462 463 native_resp = self._make_resp( 464 { 465 "models": [ 466 { 467 "key": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf", 468 "id": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf", 469 "max_context_length": 1_048_576, 470 "loaded_instances": [ 471 {"config": {"context_length": 131072}} 472 ], 473 } 474 ] 475 } 476 ) 477 478 with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ 479 patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get: 480 result = fetch_endpoint_model_metadata( 481 "http://localhost:1234/v1", 482 api_key="lm-token", 483 force_refresh=True, 484 ) 485 486 assert mock_get.call_count == 1 487 assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models" 488 assert mock_get.call_args.kwargs["headers"] == { 489 "Authorization": "Bearer lm-token" 490 } 491 assert result["lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072 492 assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072 493 494 495 class TestQueryLocalContextLengthNetworkError: 496 """_query_local_context_length handles network failures gracefully.""" 497 498 def test_connection_error_returns_none(self): 499 """Returns None when the server is unreachable.""" 500 from agent.model_metadata import _query_local_context_length 501 502 client_mock = MagicMock() 503 client_mock.__enter__ = lambda s: client_mock 504 client_mock.__exit__ = MagicMock(return_value=False) 505 client_mock.post.side_effect = Exception("Connection refused") 506 client_mock.get.side_effect = Exception("Connection refused") 507 508 with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ 509 patch("httpx.Client", return_value=client_mock): 510 result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1") 511 512 assert result is None 513 514 515 # --------------------------------------------------------------------------- 516 # get_model_context_length — integration-style tests with mocked helpers 517 # --------------------------------------------------------------------------- 518 519 class TestGetModelContextLengthLocalFallback: 520 """get_model_context_length uses local server query before falling back to 2M.""" 521 522 def test_local_endpoint_unknown_model_queries_server(self): 523 """Unknown model on local endpoint gets ctx from server, not 2M default.""" 524 from agent.model_metadata import get_model_context_length 525 526 with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ 527 patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ 528 patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ 529 patch("agent.model_metadata.is_local_endpoint", return_value=True), \ 530 patch("agent.model_metadata._query_local_context_length", return_value=131072), \ 531 patch("agent.model_metadata.save_context_length") as mock_save: 532 result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") 533 534 assert result == 131072 535 536 def test_local_endpoint_unknown_model_result_is_cached(self): 537 """Context length returned from local server is persisted to cache.""" 538 from agent.model_metadata import get_model_context_length 539 540 with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ 541 patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ 542 patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ 543 patch("agent.model_metadata.is_local_endpoint", return_value=True), \ 544 patch("agent.model_metadata._query_local_context_length", return_value=131072), \ 545 patch("agent.model_metadata.save_context_length") as mock_save: 546 get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") 547 548 mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072) 549 550 def test_local_endpoint_server_returns_none_falls_back_to_2m(self): 551 """When local server returns None, still falls back to 2M probe tier.""" 552 from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS 553 554 with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ 555 patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ 556 patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ 557 patch("agent.model_metadata.is_local_endpoint", return_value=True), \ 558 patch("agent.model_metadata._query_local_context_length", return_value=None): 559 result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") 560 561 assert result == CONTEXT_PROBE_TIERS[0] 562 563 def test_non_local_endpoint_does_not_query_local_server(self): 564 """For non-local endpoints, _query_local_context_length is not called.""" 565 from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS 566 567 with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ 568 patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ 569 patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ 570 patch("agent.model_metadata.is_local_endpoint", return_value=False), \ 571 patch("agent.model_metadata._query_local_context_length") as mock_query: 572 result = get_model_context_length( 573 "unknown-model", "https://some-cloud-api.example.com/v1" 574 ) 575 576 mock_query.assert_not_called() 577 578 def test_cached_result_skips_local_query(self): 579 """Cached context length is returned without querying the local server.""" 580 from agent.model_metadata import get_model_context_length 581 582 with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \ 583 patch("agent.model_metadata._query_local_context_length") as mock_query: 584 result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") 585 586 assert result == 65536 587 mock_query.assert_not_called() 588 589 def test_no_base_url_does_not_query_local_server(self): 590 """When base_url is empty, local server is not queried.""" 591 from agent.model_metadata import get_model_context_length 592 593 with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ 594 patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ 595 patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ 596 patch("agent.model_metadata._query_local_context_length") as mock_query: 597 result = get_model_context_length("unknown-xyz-model", "") 598 599 mock_query.assert_not_called()