  1  """Tests for _query_local_context_length and the local server fallback in
  2  get_model_context_length.
  3  
  4  All tests use synthetic inputs — no filesystem or live server required.
  5  """
  6  
  7  import sys
  8  import os
  9  import json
 10  from unittest.mock import MagicMock, patch
 11  
 12  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
 13  
 14  import pytest
 15  
 16  
 17  # ---------------------------------------------------------------------------
 18  # _query_local_context_length — unit tests with mocked httpx
 19  # ---------------------------------------------------------------------------
 20  
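# Illustrative only: the httpx.Client context-manager mock pattern repeated
# throughout this file, factored out here as a sketch. Each test builds a
# MagicMock whose __enter__/__exit__ let it stand in for
# `with httpx.Client(...) as client:` inside agent.model_metadata, with
# canned responses on .post/.get. The classes below keep their own local
# copies so each test class stays self-contained.
def _sketch_mock_httpx_client(post_resp, get_resp):
    client_mock = MagicMock()
    client_mock.__enter__ = lambda s: client_mock
    client_mock.__exit__ = MagicMock(return_value=False)
    client_mock.post.return_value = post_resp
    client_mock.get.return_value = get_resp
    return client_mock

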
class TestQueryLocalContextLengthOllama:
    """_query_local_context_length with server_type == 'ollama'."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_ollama_model_info_context_length(self):
        """Reads context length from model_info dict in /api/show response."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(200, {
            "model_info": {"llama.context_length": 131072}
        })
        models_resp = self._make_resp(404, {})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = models_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 131072

    def test_ollama_parameters_num_ctx(self):
        """Falls back to num_ctx in the parameters string when model_info lacks context_length."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(200, {
            "model_info": {},
            "parameters": "num_ctx 32768\ntemperature 0.7\n"
        })
        models_resp = self._make_resp(404, {})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = models_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:11434/v1")

        assert result == 32768

    def test_ollama_num_ctx_wins_over_model_info(self):
        """When both num_ctx (Modelfile) and model_info (GGUF) are present,
        num_ctx wins because it's the *runtime* context Ollama actually
        allocates KV cache for. The GGUF model_info.context_length is the
        training max — using it would let Hermes grow conversations past
        the runtime limit, and Ollama would silently truncate.

        Concrete example: hermes-brain:qwen3-14b-ctx32k is a Modelfile
        derived from qwen3:14b with `num_ctx 32768`, but the underlying
        GGUF reports `qwen3.context_length: 40960` (training max). If
        Hermes used 40960 it would let the conversation grow past 32768
        before compressing, and Ollama would truncate the prefix.
        """
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(200, {
            "model_info": {"qwen3.context_length": 40960},
            "parameters": "num_ctx                        32768\ntemperature                    0.6\n",
        })
        models_resp = self._make_resp(404, {})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = models_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "hermes-brain:qwen3-14b-ctx32k", "http://100.77.243.5:11434/v1"
            )

        assert result == 32768, (
            f"Expected num_ctx (32768) to win over model_info (40960), got {result}. "
            "If Hermes uses the GGUF training max, conversations will silently truncate."
        )

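    # Illustrative only: a minimal sketch of how num_ctx could be recovered
    # from the whitespace-separated "parameters" string that /api/show
    # returns. The real parser lives in agent.model_metadata and may differ.
    @staticmethod
    def _sketch_parse_num_ctx(parameters):
        for line in parameters.splitlines():
            parts = line.split()
            if len(parts) >= 2 and parts[0] == "num_ctx":
                try:
                    return int(parts[1])
                except ValueError:
                    return None
        return None
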
    def test_ollama_show_404_falls_through(self):
        """When /api/show returns 404, falls through to /v1/models/{model}."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(404, {})
        model_detail_resp = self._make_resp(200, {"max_model_len": 65536})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = model_detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:11434/v1")

        assert result == 65536


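# Note: the Ollama tests above pass the OpenAI-compatible base URL
# (".../v1"), while the mocked POST stands in for Ollama's native /api/show
# endpoint. This implies _query_local_context_length derives the server root
# by stripping the "/v1" suffix; that derivation is inferred from these
# tests rather than asserted directly.

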
class TestQueryLocalContextLengthVllm:
    """_query_local_context_length with vLLM-style /v1/models/{model} response."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_vllm_max_model_len(self):
        """Reads max_model_len from /v1/models/{model} response."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.return_value = detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1")

        assert result == 100000

    def test_vllm_context_length_key(self):
        """Reads context_length from /v1/models/{model} response."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768})

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.return_value = detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:8000/v1")

        assert result == 32768


class TestQueryLocalContextLengthModelsList:
    """_query_local_context_length: falls back to /v1/models list."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_models_list_max_model_len(self):
        """Finds context length for model in /v1/models list."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(404, {})
        list_resp = self._make_resp(200, {
            "data": [
                {"id": "other-model", "max_model_len": 4096},
                {"id": "omnicoder-9b", "max_model_len": 131072},
            ]
        })

        call_count = [0]

        def side_effect(url, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return detail_resp  # /v1/models/omnicoder-9b
            return list_resp  # /v1/models

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.side_effect = side_effect

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")

        assert result == 131072

    def test_models_list_model_not_found_returns_none(self):
        """Returns None when model is not in the /v1/models list."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(404, {})
        list_resp = self._make_resp(200, {
            "data": [{"id": "other-model", "max_model_len": 4096}]
        })

        call_count = [0]

        def side_effect(url, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return detail_resp
            return list_resp

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.side_effect = side_effect

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")

        assert result is None


class TestQueryLocalContextLengthLmStudio:
    """_query_local_context_length with LM Studio native /api/v1/models response."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def _make_client(self, native_resp, detail_resp, list_resp):
        """Build a mock httpx.Client with sequenced GET responses."""
        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.return_value = self._make_resp(404, {})

        responses = [native_resp, detail_resp, list_resp]
        call_idx = [0]

        def get_side_effect(url, **kwargs):
            idx = call_idx[0]
            call_idx[0] += 1
            if idx < len(responses):
                return responses[idx]
            return self._make_resp(404, {})

        client_mock.get.side_effect = get_side_effect
        return client_mock

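    # The sequencing above encodes the implied probe order for LM Studio:
    # native /api/v1/models first, then /v1/models/{model}, then the
    # /v1/models list. That order is inferred from how _make_client is used
    # in the tests below; the URLs themselves are not asserted.
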
    def test_lmstudio_exact_key_match(self):
        """Resolves the loaded context length when the key matches exactly."""
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 1_048_576,
                 "loaded_instances": [{"config": {"context_length": 131072}}]},
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

    def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self):
        """Fuzzy match: a bare model slug matches a key that includes the publisher prefix.

        When the user configures the model as "local:nvidia-nemotron-super-49b-v1"
        (slug only, no publisher), but LM Studio's native API stores it as
        "nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed.
        """
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 1_048_576,
                 "loaded_instances": [{"config": {"context_length": 131072}}]},
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            # Model passed in is just the slug after stripping the "local:" prefix
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

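    # Illustrative only: one fuzzy-match rule that would satisfy these tests
    # is comparing the bare slug against the last "/"-separated component of
    # the server-side key. The real matcher in agent.model_metadata may be
    # more elaborate.
    @staticmethod
    def _sketch_slug_matches(requested, server_key):
        return requested == server_key or requested == server_key.split("/")[-1]
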
    def test_lmstudio_v1_models_list_slug_fuzzy_match(self):
        """Fuzzy match also works for /v1/models list when exact match fails.

        LM Studio's OpenAI-compat /v1/models returns id like
        "nvidia/nvidia-nemotron-super-49b-v1" — must match bare slug.
        """
        from agent.model_metadata import _query_local_context_length

        # native /api/v1/models: no match
        native_resp = self._make_resp(404, {})
        # /v1/models/{model}: no match
        detail_resp = self._make_resp(404, {})
        # /v1/models list: model found with publisher prefix, includes context_length
        list_resp = self._make_resp(200, {
            "data": [
                {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072},
            ]
        })
        client_mock = self._make_client(native_resp, detail_resp, list_resp)

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

    def test_lmstudio_loaded_instances_context_length(self):
        """Reads active context_length from loaded_instances when max_context_length absent."""
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-super-49b-v1",
                    "id": "nvidia/nvidia-nemotron-super-49b-v1",
                    "loaded_instances": [
                        {"config": {"context_length": 65536}},
                    ],
                },
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 65536

    def test_lmstudio_loaded_instance_beats_max_context_length(self):
        """loaded_instances context_length takes priority over max_context_length.

        LM Studio may show max_context_length=1_048_576 (theoretical model max)
        while the actual loaded context is 122_651 (runtime setting). The loaded
        value is the real constraint and must be preferred.
        """
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-3-nano-4b",
                    "id": "nvidia/nvidia-nemotron-3-nano-4b",
                    "max_context_length": 1_048_576,
                    "loaded_instances": [
                        {"config": {"context_length": 122_651}},
                    ],
                },
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1"
            )

        assert result == 122_651, (
            f"Expected loaded instance context (122651) but got {result}. "
            "max_context_length (1048576) must not win over loaded_instances."
        )
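
    # Illustrative only: the priority rule these tests encode, written as a
    # tiny helper. The runtime context_length from loaded_instances wins over
    # the theoretical max_context_length; the real logic may handle more
    # shapes than this.
    @staticmethod
    def _sketch_pick_lmstudio_ctx(entry):
        for inst in entry.get("loaded_instances", []):
            ctx = inst.get("config", {}).get("context_length")
            if ctx:
                return ctx
        return entry.get("max_context_length")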


class TestDetectLocalServerTypeAuth:
    """detect_local_server_type forwards the api_key as a Bearer token on probe requests."""

    def test_passes_bearer_token_to_probe_requests(self):
        from agent.model_metadata import detect_local_server_type

        resp = MagicMock()
        resp.status_code = 200

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.get.return_value = resp

        with patch("httpx.Client", return_value=client_mock) as mock_client:
            result = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token")

        assert result == "lm-studio"
        assert mock_client.call_args.kwargs["headers"] == {
            "Authorization": "Bearer lm-token"
        }

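# Inferred from the test above: detect_local_server_type classifies the
# server by issuing authenticated GET probes, and a 200 from the LM Studio
# probe yields "lm-studio". The exact probe order is not asserted here and
# lives in agent.model_metadata.
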
class TestFetchEndpointModelMetadataLmStudio:
    """fetch_endpoint_model_metadata should use LM Studio's native models endpoint."""

    def _make_resp(self, body):
        resp = MagicMock()
        resp.raise_for_status.return_value = None
        resp.json.return_value = body
        return resp

    def test_uses_native_models_endpoint_only(self):
        from agent.model_metadata import fetch_endpoint_model_metadata

        native_resp = self._make_resp(
            {
                "models": [
                    {
                        "key": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
                        "id": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
                        "max_context_length": 1_048_576,
                        "loaded_instances": [
                            {"config": {"context_length": 131072}}
                        ],
                    }
                ]
            }
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get:
            result = fetch_endpoint_model_metadata(
                "http://localhost:1234/v1",
                api_key="lm-token",
                force_refresh=True,
            )

        assert mock_get.call_count == 1
        assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models"
        assert mock_get.call_args.kwargs["headers"] == {
            "Authorization": "Bearer lm-token"
        }
        assert result["lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
        assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
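
    # Note: the last two assertions imply the returned mapping is keyed both
    # by the full model key and by the key minus its publisher prefix, e.g.
    # "lmstudio-community/X/Y" also appearing as "X/Y". That aliasing is
    # inferred from the assertions here, not from the implementation.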


class TestQueryLocalContextLengthNetworkError:
    """_query_local_context_length handles network failures gracefully."""

    def test_connection_error_returns_none(self):
        """Returns None when the server is unreachable."""
        from agent.model_metadata import _query_local_context_length

        client_mock = MagicMock()
        client_mock.__enter__ = lambda s: client_mock
        client_mock.__exit__ = MagicMock(return_value=False)
        client_mock.post.side_effect = Exception("Connection refused")
        client_mock.get.side_effect = Exception("Connection refused")

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result is None


# ---------------------------------------------------------------------------
# get_model_context_length — integration-style tests with mocked helpers
# ---------------------------------------------------------------------------

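# The tests below encode this resolution order for get_model_context_length,
# inferred from the mocks (the real function may interleave these steps):
#   1. get_cached_context_length(model, base_url)
#   2. fetch_endpoint_model_metadata / fetch_model_metadata lookups
#   3. _query_local_context_length, only when base_url is non-empty and
#      is_local_endpoint(base_url) is true
#   4. CONTEXT_PROBE_TIERS[0] (the 2M tier) as the final fallback
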
class TestGetModelContextLengthLocalFallback:
    """get_model_context_length uses local server query before falling back to 2M."""

    def test_local_endpoint_unknown_model_queries_server(self):
        """Unknown model on local endpoint gets ctx from server, not 2M default."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length"):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 131072

    def test_local_endpoint_unknown_model_result_is_cached(self):
        """Context length returned from the local server is persisted to cache."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length") as mock_save:
            get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072)

    def test_local_endpoint_server_returns_none_falls_back_to_2m(self):
        """When the local server returns None, still falls back to the 2M probe tier."""
        from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=None):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == CONTEXT_PROBE_TIERS[0]

    def test_non_local_endpoint_does_not_query_local_server(self):
        """For non-local endpoints, _query_local_context_length is not called."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=False), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            get_model_context_length(
                "unknown-model", "https://some-cloud-api.example.com/v1"
            )

        mock_query.assert_not_called()

    def test_cached_result_skips_local_query(self):
        """Cached context length is returned without querying the local server."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 65536
        mock_query.assert_not_called()

    def test_no_base_url_does_not_query_local_server(self):
        """When base_url is empty, the local server is not queried."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            get_model_context_length("unknown-xyz-model", "")

        mock_query.assert_not_called()