/ tests / tools / test_vision_tools.py
test_vision_tools.py
  1  """Tests for tools/vision_tools.py — URL validation, type hints, error logging."""
  2  
  3  import asyncio
  4  import json
  5  import logging
  6  import os
  7  from pathlib import Path
  8  from typing import Awaitable
  9  from unittest.mock import AsyncMock, MagicMock, patch
 10  
 11  import pytest
 12  
 13  from tools.vision_tools import (
 14      _validate_image_url,
 15      _handle_vision_analyze,
 16      _determine_mime_type,
 17      _image_to_base64_data_url,
 18      _resize_image_for_vision,
 19      _is_image_size_error,
 20      _MAX_BASE64_BYTES,
 21      _RESIZE_TARGET_BYTES,
 22      vision_analyze_tool,
 23      check_vision_requirements,
 24  )
 25  
 26  
 27  # ---------------------------------------------------------------------------
 28  # _validate_image_url — urlparse-based validation
 29  # ---------------------------------------------------------------------------
 30  
 31  
 32  class TestValidateImageUrl:
 33      """Tests for URL validation, including urlparse-based netloc check."""
 34  
 35      def test_valid_https_url(self):
 36          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 37              (2, 1, 6, "", ("93.184.216.34", 0)),
 38          ]):
 39              assert _validate_image_url("https://example.com/image.jpg") is True
 40  
 41      def test_valid_http_url(self):
 42          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 43              (2, 1, 6, "", ("93.184.216.34", 0)),
 44          ]):
 45              assert _validate_image_url("http://cdn.example.org/photo.png") is True
 46  
 47      def test_valid_url_without_extension(self):
 48          """CDN endpoints that redirect to images should still pass."""
 49          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 50              (2, 1, 6, "", ("93.184.216.34", 0)),
 51          ]):
 52              assert _validate_image_url("https://cdn.example.com/abcdef123") is True
 53  
 54      def test_valid_url_with_query_params(self):
 55          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 56              (2, 1, 6, "", ("93.184.216.34", 0)),
 57          ]):
 58              assert _validate_image_url("https://img.example.com/pic?w=200&h=200") is True
 59  
 60      def test_localhost_url_blocked_by_ssrf(self):
 61          """localhost URLs are now blocked by SSRF protection."""
 62          assert _validate_image_url("http://localhost:8080/image.png") is False
 63  
 64      def test_valid_url_with_port(self):
 65          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 66              (2, 1, 6, "", ("93.184.216.34", 0)),
 67          ]):
 68              assert _validate_image_url("http://example.com:8080/image.png") is True
 69  
 70      def test_valid_url_with_path_only(self):
 71          with patch("tools.url_safety.socket.getaddrinfo", return_value=[
 72              (2, 1, 6, "", ("93.184.216.34", 0)),
 73          ]):
 74              assert _validate_image_url("https://example.com/") is True
 75  
 76      def test_rejects_empty_string(self):
 77          assert _validate_image_url("") is False
 78  
 79      def test_rejects_none(self):
 80          assert _validate_image_url(None) is False
 81  
 82      def test_rejects_non_string(self):
 83          assert _validate_image_url(12345) is False
 84  
 85      def test_rejects_ftp_scheme(self):
 86          assert _validate_image_url("ftp://files.example.com/image.jpg") is False
 87  
 88      def test_rejects_file_scheme(self):
 89          assert _validate_image_url("file:///etc/passwd") is False
 90  
 91      def test_rejects_no_scheme(self):
 92          assert _validate_image_url("example.com/image.jpg") is False
 93  
 94      def test_rejects_javascript_scheme(self):
 95          assert _validate_image_url("javascript:alert(1)") is False
 96  
 97      def test_rejects_http_without_netloc(self):
 98          """http:// alone has no network location — urlparse catches this."""
 99          assert _validate_image_url("http://") is False
100  
101      def test_rejects_https_without_netloc(self):
102          assert _validate_image_url("https://") is False
103  
104      def test_rejects_http_colon_only(self):
105          assert _validate_image_url("http:") is False
106  
107      def test_rejects_data_url(self):
108          assert _validate_image_url("data:image/png;base64,iVBOR") is False
109  
110      def test_rejects_whitespace_only(self):
111          assert _validate_image_url("   ") is False
112  
113      def test_rejects_boolean(self):
114          assert _validate_image_url(True) is False
115  
116      def test_rejects_list(self):
117          assert _validate_image_url(["https://example.com"]) is False
118  
119  
120  # ---------------------------------------------------------------------------
121  # _determine_mime_type
122  # ---------------------------------------------------------------------------
123  
124  
class TestDetermineMimeType:
    """Check the MIME string ``_determine_mime_type`` reports per file suffix."""

    def test_jpg(self):
        mime = _determine_mime_type(Path("photo.jpg"))
        assert mime == "image/jpeg"

    def test_jpeg(self):
        mime = _determine_mime_type(Path("photo.jpeg"))
        assert mime == "image/jpeg"

    def test_png(self):
        mime = _determine_mime_type(Path("screenshot.png"))
        assert mime == "image/png"

    def test_gif(self):
        mime = _determine_mime_type(Path("anim.gif"))
        assert mime == "image/gif"

    def test_webp(self):
        mime = _determine_mime_type(Path("modern.webp"))
        assert mime == "image/webp"

    def test_unknown_extension_defaults_to_jpeg(self):
        # An unrecognised suffix is expected to be reported as JPEG.
        mime = _determine_mime_type(Path("file.xyz"))
        assert mime == "image/jpeg"
143  
144  
145  # ---------------------------------------------------------------------------
146  # _image_to_base64_data_url
147  # ---------------------------------------------------------------------------
148  
149  
class TestImageToBase64DataUrl:
    """Check the data-URL prefix and missing-file behaviour of the encoder."""

    def test_returns_data_url(self, tmp_path):
        png_file = tmp_path / "test.png"
        png_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

        data_url = _image_to_base64_data_url(png_file)

        assert data_url.startswith("data:image/png;base64,")

    def test_custom_mime_type(self, tmp_path):
        blob_file = tmp_path / "test.bin"
        blob_file.write_bytes(b"\x00" * 16)

        # An explicit mime_type should be reflected in the URL prefix.
        data_url = _image_to_base64_data_url(blob_file, mime_type="image/webp")

        assert data_url.startswith("data:image/webp;base64,")

    def test_file_not_found_raises(self, tmp_path):
        missing = tmp_path / "nonexistent.png"
        with pytest.raises(FileNotFoundError):
            _image_to_base64_data_url(missing)
166  
167  
168  # ---------------------------------------------------------------------------
169  # _handle_vision_analyze — type signature & behavior
170  # ---------------------------------------------------------------------------
171  
172  
class TestHandleVisionAnalyze:
    """Check what ``_handle_vision_analyze`` returns and forwards to the tool."""

    @staticmethod
    def _stub_tool():
        """Patch ``vision_analyze_tool`` with an AsyncMock returning a JSON stub."""
        return patch(
            "tools.vision_tools.vision_analyze_tool",
            new_callable=AsyncMock,
            return_value=json.dumps({"result": "ok"}),
        )

    def test_returns_awaitable(self):
        """The handler's return value must be an ``Awaitable``."""
        request = {
            "image_url": "https://example.com/img.png",
            "question": "What is this?",
        }
        with self._stub_tool():
            pending = _handle_vision_analyze(request)
            assert isinstance(pending, Awaitable)
            # Close the un-awaited coroutine so no RuntimeWarning is emitted.
            pending.close()

    def test_prompt_contains_question(self):
        """The prompt handed to the tool must embed the caller's question."""
        request = {
            "image_url": "https://example.com/img.png",
            "question": "Describe the cat",
        }
        with self._stub_tool() as mock_tool:
            pending = _handle_vision_analyze(request)
            pending.close()
            # The prompt is the second positional argument of the tool call.
            forwarded_prompt = mock_tool.call_args.args[1]
            assert "Describe the cat" in forwarded_prompt
            assert "Fully describe and explain" in forwarded_prompt

    def test_uses_auxiliary_vision_model_env(self):
        """AUXILIARY_VISION_MODEL, when set, is the model passed to the tool."""
        env_override = {"AUXILIARY_VISION_MODEL": "custom/model-v1"}
        with self._stub_tool() as mock_tool, patch.dict(os.environ, env_override):
            pending = _handle_vision_analyze(
                {"image_url": "https://example.com/img.png", "question": "test"}
            )
            pending.close()
            # The model is the third positional argument of the tool call.
            forwarded_model = mock_tool.call_args.args[2]
            assert forwarded_model == "custom/model-v1"

    def test_falls_back_to_default_model(self):
        """With AUXILIARY_VISION_MODEL unset, the tool receives ``None`` as model."""
        with self._stub_tool() as mock_tool, patch.dict(os.environ, {}, clear=False):
            # patch.dict restores the environment on exit, so removing the
            # variable here does not leak into other tests.
            os.environ.pop("AUXILIARY_VISION_MODEL", None)
            pending = _handle_vision_analyze(
                {"image_url": "https://example.com/img.png", "question": "test"}
            )
            pending.close()
            forwarded_model = mock_tool.call_args.args[2]
            assert forwarded_model is None

    def test_empty_args_graceful(self):
        """An empty argument dict must not raise."""
        with self._stub_tool():
            pending = _handle_vision_analyze({})
            assert isinstance(pending, Awaitable)
            pending.close()
259  
260  
261  # ---------------------------------------------------------------------------
262  # Error logging with exc_info — verify tracebacks are logged
263  # ---------------------------------------------------------------------------
264  
265  
class TestErrorLoggingExcInfo:
    """Verify that error/warning log records carry exception info."""

    @pytest.mark.asyncio
    async def test_download_failure_logs_exc_info(self, tmp_path, caplog):
        """A failed download must re-raise and log an ERROR record with exc_info."""
        from tools.vision_tools import _download_image

        with patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls:
            # The client is used as an async context manager, so __aenter__
            # must hand back the same mock whose .get() raises.
            mock_client = AsyncMock()
            mock_client.__aenter__ = AsyncMock(return_value=mock_client)
            mock_client.__aexit__ = AsyncMock(return_value=False)
            mock_client.get = AsyncMock(side_effect=ConnectionError("network down"))
            mock_client_cls.return_value = mock_client

            dest = tmp_path / "image.jpg"
            with (
                caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
                pytest.raises(ConnectionError),
            ):
                await _download_image(
                    "https://example.com/img.jpg", dest, max_retries=1
                )

            # At least one ERROR-or-higher record, and it carries exc_info.
            error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
            assert len(error_records) >= 1
            assert error_records[0].exc_info is not None

    @pytest.mark.asyncio
    async def test_analysis_error_logs_exc_info(self, caplog):
        """A download exception inside the tool yields a failure JSON and a logged traceback."""
        with (
            patch("tools.vision_tools._validate_image_url", return_value=True),
            patch(
                "tools.vision_tools._download_image",
                new_callable=AsyncMock,
                side_effect=Exception("download boom"),
            ),
            caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
        ):
            result = await vision_analyze_tool(
                "https://example.com/img.jpg", "describe this", "test/model"
            )
            result_data = json.loads(result)
            # The failure is signalled through "success": False.
            assert result_data["success"] is False

            error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
            assert any(r.exc_info and r.exc_info[0] is not None for r in error_records)

    @pytest.mark.asyncio
    async def test_cleanup_error_logs_exc_info(self, tmp_path, caplog):
        """A failing temp-file ``unlink`` must produce a WARNING with exc_info."""
        # NOTE(review): this directory is not referenced again in this test;
        # confirm whether it is still needed.
        temp_dir = tmp_path / "temp_vision_images"
        temp_dir.mkdir()

        async def fake_download(url, dest, max_retries=3):
            """Stand in for the download by writing bytes to ``dest``."""
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
            return dest

        def failing_unlink(self, *args, **kwargs):
            """Replacement for ``Path.unlink`` that always fails."""
            raise PermissionError("no permission")

        # Canned LLM response exposing ``choices[0].message.content``.
        mock_response = MagicMock()
        mock_choice = MagicMock()
        mock_choice.message.content = "A test image description"
        mock_response.choices = [mock_choice]

        with (
            patch("tools.vision_tools._validate_image_url", return_value=True),
            patch("tools.vision_tools._download_image", side_effect=fake_download),
            patch(
                "tools.vision_tools._image_to_base64_data_url",
                return_value="data:image/jpeg;base64,abc",
            ),
            patch(
                "tools.vision_tools.async_call_llm",
                new_callable=AsyncMock,
                return_value=mock_response,
            ),
            caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
        ):
            # Keep the Path.unlink override as narrow as possible: it is
            # process-wide while active.
            with patch.object(Path, "unlink", failing_unlink):
                await vision_analyze_tool(
                    "https://example.com/tempimg.jpg", "describe", "test/model"
                )

            warning_records = [
                r
                for r in caplog.records
                if r.levelno == logging.WARNING
                and "temporary file" in r.getMessage().lower()
            ]
            assert len(warning_records) >= 1
            assert warning_records[0].exc_info is not None
367  
368  
class TestVisionConfig:
    """Check the temperature/timeout keyword arguments passed to ``async_call_llm``."""

    @staticmethod
    def _llm_reply(text):
        """Build a mock response whose first choice carries ``text``."""
        reply = MagicMock()
        choice = MagicMock()
        choice.message.content = text
        reply.choices = [choice]
        return reply

    @pytest.mark.asyncio
    async def test_vision_uses_configured_temperature_and_timeout(self, tmp_path):
        png_file = tmp_path / "test.png"
        png_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

        configured = {"auxiliary": {"vision": {"temperature": 1, "timeout": 77}}}
        reply = self._llm_reply("Configured image analysis")

        with (
            patch("hermes_cli.config.load_config", return_value=configured),
            patch(
                "tools.vision_tools._image_to_base64_data_url",
                return_value="data:image/png;base64,abc",
            ),
            patch(
                "tools.vision_tools.async_call_llm",
                new_callable=AsyncMock,
                return_value=reply,
            ) as mock_llm,
        ):
            raw = await vision_analyze_tool(str(png_file), "describe this", "test/model")
            outcome = json.loads(raw)

        assert outcome["success"] is True
        llm_kwargs = mock_llm.await_args.kwargs
        assert llm_kwargs["temperature"] == 1.0
        assert llm_kwargs["timeout"] == 77.0

    @pytest.mark.asyncio
    async def test_vision_defaults_temperature_when_config_omits_it(self, tmp_path):
        png_file = tmp_path / "test.png"
        png_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

        # An empty vision section: the expected fallbacks are asserted below.
        unconfigured = {"auxiliary": {"vision": {}}}
        reply = self._llm_reply("Default image analysis")

        with (
            patch("hermes_cli.config.load_config", return_value=unconfigured),
            patch(
                "tools.vision_tools._image_to_base64_data_url",
                return_value="data:image/png;base64,abc",
            ),
            patch(
                "tools.vision_tools.async_call_llm",
                new_callable=AsyncMock,
                return_value=reply,
            ) as mock_llm,
        ):
            raw = await vision_analyze_tool(str(png_file), "describe this", "test/model")
            outcome = json.loads(raw)

        assert outcome["success"] is True
        llm_kwargs = mock_llm.await_args.kwargs
        assert llm_kwargs["temperature"] == 0.1
        assert llm_kwargs["timeout"] == 120.0
427  
428  
class TestVisionSafetyGuards:
    """Check that unsafe inputs are refused before any LLM call or file write."""

    @pytest.mark.asyncio
    async def test_local_non_image_file_rejected_before_llm_call(self, tmp_path):
        decoy = tmp_path / "secret.txt"
        decoy.write_text("TOP-SECRET=1\n", encoding="utf-8")

        llm_patch = patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock)
        with llm_patch as mock_llm:
            raw = await vision_analyze_tool(str(decoy), "extract text")
            outcome = json.loads(raw)

        assert outcome["success"] is False
        assert "Only real image files are supported" in outcome["error"]
        mock_llm.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_blocked_remote_url_short_circuits_before_download(self):
        policy_hit = {
            "host": "blocked.test",
            "rule": "blocked.test",
            "source": "config",
            "message": "Blocked by website policy",
        }

        with (
            patch("tools.vision_tools.check_website_access", return_value=policy_hit),
            patch("tools.vision_tools._validate_image_url", return_value=True),
            patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download,
        ):
            raw = await vision_analyze_tool("https://blocked.test/cat.png", "describe")
            outcome = json.loads(raw)

        assert outcome["success"] is False
        assert "Blocked by website policy" in outcome["error"]
        mock_download.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_download_blocks_redirected_final_url(self, tmp_path):
        from tools.vision_tools import _download_image

        def fake_check(url):
            """Allow the requested URL, block the final one, fail on anything else."""
            if url == "https://blocked.test/final.png":
                return {
                    "host": "blocked.test",
                    "rule": "blocked.test",
                    "source": "config",
                    "message": "Blocked by website policy",
                }
            if url == "https://allowed.test/cat.png":
                return None
            raise AssertionError(f"unexpected URL checked: {url}")

        class FakeResponse:
            # Response whose final URL differs from the URL that was requested.
            url = "https://blocked.test/final.png"
            headers = {"content-length": "24"}
            content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16

            def raise_for_status(self):
                return None

        # Async-context-manager client whose .get() yields the fake response.
        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)
        mock_client.get = AsyncMock(return_value=FakeResponse())

        target = tmp_path / "cat.png"
        with (
            patch("tools.vision_tools.check_website_access", side_effect=fake_check),
            patch("tools.vision_tools.httpx.AsyncClient", return_value=mock_client),
            pytest.raises(PermissionError, match="Blocked by website policy"),
        ):
            await _download_image("https://allowed.test/cat.png", target, max_retries=1)

        # Nothing may have been written for the blocked final URL.
        assert not (tmp_path / "cat.png").exists()
500  
501  
502  # ---------------------------------------------------------------------------
503  # check_vision_requirements
504  # ---------------------------------------------------------------------------
505  
506  
class TestVisionRequirements:
    """Check the return value of ``check_vision_requirements``."""

    def test_check_requirements_returns_bool(self):
        outcome = check_vision_requirements()
        assert isinstance(outcome, bool)

    def test_check_requirements_accepts_codex_auth(self, monkeypatch, tmp_path):
        # Point HERMES_HOME at a scratch directory holding the two files below.
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        auth_file = tmp_path / "auth.json"
        auth_file.write_text(
            '{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token":"codex-access-token","refresh_token":"codex-refresh-token"}}}}'
        )

        # config.yaml names the codex provider as the main provider.
        config_file = tmp_path / "config.yaml"
        config_file.write_text(
            'model:\n  default: gpt-4o\n  provider: openai-codex\n'
        )

        # Remove the other credential-related variables for this test.
        for var_name in ("OPENROUTER_API_KEY", "OPENAI_BASE_URL", "OPENAI_API_KEY"):
            monkeypatch.delenv(var_name, raising=False)

        outcome = check_vision_requirements()
        assert outcome is True
527  
528  
529  # ---------------------------------------------------------------------------
530  # Integration: registry entry
531  # ---------------------------------------------------------------------------
532  
533  
534  # ---------------------------------------------------------------------------
535  # Tilde expansion in local file paths
536  # ---------------------------------------------------------------------------
537  
538  
class TestTildeExpansion:
    """Verify that ``~/path`` style paths are expanded correctly."""

    @pytest.mark.asyncio
    async def test_tilde_path_expanded_to_local_file(self, tmp_path, monkeypatch):
        """``vision_analyze_tool`` should expand ``~`` in file paths."""
        # Put a small PNG-signature file under a fake home directory.
        fake_home = tmp_path / "fakehome"
        fake_home.mkdir()
        img = fake_home / "test_image.png"
        img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

        # expanduser reads HOME on POSIX but USERPROFILE on Windows; set both
        # so the fake home is honoured on either platform.
        monkeypatch.setenv("HOME", str(fake_home))
        monkeypatch.setenv("USERPROFILE", str(fake_home))

        # Canned LLM response exposing ``choices[0].message.content``.
        mock_response = MagicMock()
        mock_choice = MagicMock()
        mock_choice.message.content = "A test image"
        mock_response.choices = [mock_choice]

        with (
            patch(
                "tools.vision_tools._image_to_base64_data_url",
                return_value="data:image/png;base64,abc",
            ),
            patch(
                "tools.vision_tools.async_call_llm",
                new_callable=AsyncMock,
                return_value=mock_response,
            ),
        ):
            result = await vision_analyze_tool(
                "~/test_image.png", "describe this", "test/model"
            )
            data = json.loads(result)
            assert data["success"] is True
            assert data["analysis"] == "A test image"

    @pytest.mark.asyncio
    async def test_tilde_path_nonexistent_file_gives_error(self, tmp_path, monkeypatch):
        """A tilde path that does not resolve to a real file should fail gracefully."""
        fake_home = tmp_path / "fakehome"
        fake_home.mkdir()
        # Same dual override as above so ``~`` resolves to the empty fake home.
        monkeypatch.setenv("HOME", str(fake_home))
        monkeypatch.setenv("USERPROFILE", str(fake_home))

        result = await vision_analyze_tool(
            "~/nonexistent.png", "describe this", "test/model"
        )
        data = json.loads(result)
        assert data["success"] is False
588  
589  
590  # ---------------------------------------------------------------------------
591  # file:// URI support
592  # ---------------------------------------------------------------------------
593  
594  
class TestFileUriSupport:
    """Check how ``vision_analyze_tool`` handles ``file://`` inputs."""

    @pytest.mark.asyncio
    async def test_file_uri_resolved_as_local_path(self, tmp_path):
        """A ``file://`` URI naming an existing image should succeed."""
        png_file = tmp_path / "photo.png"
        png_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
        uri = f"file://{png_file}"

        # Canned LLM response exposing ``choices[0].message.content``.
        reply = MagicMock()
        choice = MagicMock()
        choice.message.content = "A test image"
        reply.choices = [choice]

        encoder_patch = patch(
            "tools.vision_tools._image_to_base64_data_url",
            return_value="data:image/png;base64,abc",
        )
        llm_patch = patch(
            "tools.vision_tools.async_call_llm",
            new_callable=AsyncMock,
            return_value=reply,
        )
        with encoder_patch, llm_patch:
            raw = await vision_analyze_tool(uri, "describe this", "test/model")
            outcome = json.loads(raw)
            assert outcome["success"] is True

    @pytest.mark.asyncio
    async def test_file_uri_nonexistent_gives_error(self, tmp_path):
        """A ``file://`` URI naming a missing file should report failure."""
        uri = f"file://{tmp_path}/nonexistent.png"
        raw = await vision_analyze_tool(uri, "describe this", "test/model")
        outcome = json.loads(raw)
        assert outcome["success"] is False
634  
635  
636  # ---------------------------------------------------------------------------
637  # Base64 size pre-flight check
638  # ---------------------------------------------------------------------------
639  
640  
class TestBase64SizeLimit:
    """Check the size gate applied before the LLM is called."""

    @pytest.mark.asyncio
    async def test_oversized_image_rejected_before_api_call(self, tmp_path):
        """An image above the (patched) hard limit fails with a "too large" error."""
        big_file = tmp_path / "huge.png"
        big_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * (4 * 1024 * 1024))

        # The hard limit is lowered to 1000 for the duration of the call.
        with (
            patch("tools.vision_tools._MAX_BASE64_BYTES", 1000),
            patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm,
        ):
            raw = await vision_analyze_tool(str(big_file), "describe this")
            outcome = json.loads(raw)

        assert outcome["success"] is False
        assert "too large" in outcome["error"].lower()
        mock_llm.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_small_image_not_rejected(self, tmp_path):
        """A tiny image must get through the size check."""
        small_file = tmp_path / "small.png"
        small_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 64)

        # Canned LLM response exposing ``choices[0].message.content``.
        reply = MagicMock()
        choice = MagicMock()
        choice.message.content = "Small image"
        reply.choices = [choice]

        llm_patch = patch(
            "tools.vision_tools.async_call_llm",
            new_callable=AsyncMock,
            return_value=reply,
        )
        with llm_patch:
            raw = await vision_analyze_tool(str(small_file), "describe this", "test/model")
            outcome = json.loads(raw)

        assert outcome["success"] is True
680  
681  
682  # ---------------------------------------------------------------------------
683  # Error classification for 400 responses
684  # ---------------------------------------------------------------------------
685  
686  
class TestErrorClassification:
    """Check the guidance text returned when the LLM call raises a 400-style error."""

    @pytest.mark.asyncio
    async def test_invalid_request_error_gives_image_guidance(self, tmp_path):
        """An ``invalid_request_error`` should yield advice to use a smaller image."""
        png_file = tmp_path / "test.png"
        png_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)

        # Exception whose message imitates a provider 400 payload.
        simulated_400 = Exception(
            "Error code: 400 - {'type': 'error', 'error': "
            "{'type': 'invalid_request_error', 'message': 'Invalid request data'}}"
        )

        encoder_patch = patch(
            "tools.vision_tools._image_to_base64_data_url",
            return_value="data:image/png;base64,abc",
        )
        llm_patch = patch(
            "tools.vision_tools.async_call_llm",
            new_callable=AsyncMock,
            side_effect=simulated_400,
        )
        with encoder_patch, llm_patch:
            raw = await vision_analyze_tool(str(png_file), "describe", "test/model")
            outcome = json.loads(raw)

        assert outcome["success"] is False
        guidance = outcome["analysis"].lower()
        assert "rejected the image" in guidance
        assert "smaller" in guidance
717  
718  
class TestVisionRegistration:
    """Check the ``vision_analyze`` entry held by the tool registry."""

    def test_vision_analyze_registered(self):
        """The entry exists, belongs to the "vision" toolset and is async."""
        from tools.registry import registry

        entry = registry._tools.get("vision_analyze")
        assert entry is not None
        assert entry.toolset == "vision"
        assert entry.is_async is True

    def test_schema_has_required_fields(self):
        """The entry's schema names the tool and declares both parameters."""
        from tools.registry import registry

        entry = registry._tools.get("vision_analyze")
        # Fail with a clear message instead of AttributeError on None.
        assert entry is not None, "vision_analyze is not registered"
        schema = entry.schema
        assert schema["name"] == "vision_analyze"
        params = schema.get("parameters", {})
        props = params.get("properties", {})
        assert "image_url" in props
        assert "question" in props

    def test_handler_is_callable(self):
        """The entry's handler can be called."""
        from tools.registry import registry

        entry = registry._tools.get("vision_analyze")
        # Fail with a clear message instead of AttributeError on None.
        assert entry is not None, "vision_analyze is not registered"
        assert callable(entry.handler)
744  
745  
746  # ---------------------------------------------------------------------------
747  # _resize_image_for_vision — auto-resize oversized images
748  # ---------------------------------------------------------------------------
749  
750  
class TestResizeImageForVision:
    """Tests for the auto-resize function."""

    # NOTE(review): several tests below build single-colour images and save
    # them as PNG. Pillow writes DEFLATE-compressed PNGs, so such files are a
    # small fraction of their raw pixel size. The affected tests
    # (test_large_image_is_resized and the two aspect-ratio tests) therefore
    # presumably stay under their size limits and may never reach the resize
    # path. Confirm against the implementation, and consider noisy pixel data
    # if the resize path is meant to be exercised.

    @staticmethod
    def _require_pillow():
        """Return the ``PIL.Image`` module, skipping the calling test if Pillow is missing."""
        try:
            from PIL import Image
        except ImportError:
            pytest.skip("Pillow not installed")
        else:
            return Image

    @staticmethod
    def _open_data_url(data_url, image_module):
        """Decode the base64 payload of *data_url* and open it with *image_module*."""
        import base64
        from io import BytesIO

        # Everything after the first comma is the base64-encoded image.
        _, encoded = data_url.split(",", 1)
        return image_module.open(BytesIO(base64.b64decode(encoded)))

    def test_small_image_returned_as_is(self, tmp_path):
        """Images under the limit should be returned unchanged."""
        Image = self._require_pillow()
        # Create a small 10x10 red PNG
        img = Image.new("RGB", (10, 10), (255, 0, 0))
        path = tmp_path / "small.png"
        img.save(path, "PNG")

        result = _resize_image_for_vision(path, mime_type="image/png")
        assert result.startswith("data:image/png;base64,")
        assert len(result) < _MAX_BASE64_BYTES

    def test_large_image_is_resized(self, tmp_path):
        """Images over the default target should be auto-resized to fit."""
        Image = self._require_pillow()
        # 4000x4000 RGB is ~48 MB of raw pixels.
        # NOTE(review): the PNG on disk is compressed and the image is a single
        # colour, so the encoded size is likely far below 5 MB — verify that
        # this input really exceeds the target.
        img = Image.new("RGB", (4000, 4000), (128, 200, 50))
        path = tmp_path / "large.png"
        img.save(path, "PNG")

        result = _resize_image_for_vision(path, mime_type="image/png")
        assert result.startswith("data:image/png;base64,")
        # Default target is _RESIZE_TARGET_BYTES (5 MB), not _MAX_BASE64_BYTES (20 MB)
        assert len(result) <= _RESIZE_TARGET_BYTES

    def test_custom_max_bytes(self, tmp_path):
        """The max_base64_bytes parameter should be respected."""
        Image = self._require_pillow()
        img = Image.new("RGB", (200, 200), (0, 128, 255))
        path = tmp_path / "medium.png"
        img.save(path, "PNG")

        # Set a very low limit to force resizing
        result = _resize_image_for_vision(path, max_base64_bytes=500)
        # Should still return a valid data URL
        assert result.startswith("data:image/")

    def test_jpeg_output_for_non_png(self, tmp_path):
        """Non-PNG images should be resized as JPEG."""
        Image = self._require_pillow()
        img = Image.new("RGB", (2000, 2000), (255, 128, 0))
        path = tmp_path / "photo.jpg"
        img.save(path, "JPEG", quality=95)

        result = _resize_image_for_vision(path, mime_type="image/jpeg",
                                           max_base64_bytes=50_000)
        assert result.startswith("data:image/jpeg;base64,")

    def test_constants_sane(self):
        """Hard limit should be larger than resize target."""
        assert _MAX_BASE64_BYTES == 20 * 1024 * 1024
        assert _RESIZE_TARGET_BYTES == 5 * 1024 * 1024
        assert _MAX_BASE64_BYTES > _RESIZE_TARGET_BYTES

    def test_extreme_aspect_ratio_preserved(self, tmp_path):
        """Extreme aspect ratios should be preserved during resize."""
        Image = self._require_pillow()
        # Very wide panorama: 8000x200 (40:1)
        img = Image.new("RGB", (8000, 200), (100, 150, 200))
        path = tmp_path / "panorama.png"
        img.save(path, "PNG")

        result = _resize_image_for_vision(path, mime_type="image/png",
                                           max_base64_bytes=50_000)
        assert result.startswith("data:image/")
        # Decode and check aspect ratio is roughly preserved
        resized = self._open_data_url(result, Image)
        resized_ratio = resized.width / resized.height if resized.height > 0 else 0
        # Allow some tolerance (floor clamping), but ratio should stay above 10:1
        # With independent halving, ratio would collapse to ~1:1. Proportional
        # scaling should keep it well above 10.
        assert resized_ratio > 10, (
            f"Aspect ratio collapsed: {resized.width}x{resized.height} "
            f"(ratio {resized_ratio:.1f}, expected >10)"
        )

    def test_tall_narrow_image_preserved(self, tmp_path):
        """Tall narrow images should also preserve aspect ratio."""
        Image = self._require_pillow()
        # Very tall: 200x6000 (30:1 height/width)
        img = Image.new("RGB", (200, 6000), (200, 100, 50))
        path = tmp_path / "tall.png"
        img.save(path, "PNG")

        result = _resize_image_for_vision(path, mime_type="image/png",
                                           max_base64_bytes=50_000)
        assert result.startswith("data:image/")
        resized = self._open_data_url(result, Image)
        resized_ratio = resized.height / resized.width if resized.width > 0 else 0
        assert resized_ratio > 5, (
            f"Aspect ratio collapsed: {resized.width}x{resized.height} "
            f"(h/w ratio {resized_ratio:.1f}, expected >5)"
        )

    def test_no_pillow_returns_original(self, tmp_path):
        """Without Pillow, oversized images should be returned as-is."""
        # Create a dummy file
        path = tmp_path / "test.png"
        # Write enough bytes to exceed a tiny limit
        path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 1000)

        with patch("tools.vision_tools._image_to_base64_data_url") as mock_b64:
            # Simulate a large base64 result
            mock_b64.return_value = "data:image/png;base64," + "A" * 200
            # A None entry in sys.modules makes ``import PIL`` raise ImportError.
            with patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}):
                result = _resize_image_for_vision(path, max_base64_bytes=100)
                # Should return the original (oversized) data url
                assert len(result) > 100
891  
892  
893  # ---------------------------------------------------------------------------
894  # _is_image_size_error — detect size-related API errors
895  # ---------------------------------------------------------------------------
896  
897  
class TestIsImageSizeError:
    """Exercise the helper that recognises size-related API failures."""

    def test_too_large_message(self):
        """A "too large" payload message is treated as a size error."""
        error = Exception("Request payload too large")
        assert _is_image_size_error(error)

    def test_413_status(self):
        """An HTTP 413 message is treated as a size error."""
        error = Exception("HTTP 413 Payload Too Large")
        assert _is_image_size_error(error)

    def test_invalid_request(self):
        """An ``invalid_request_error`` message is treated as a size error."""
        error = Exception("invalid_request_error: image too big")
        assert _is_image_size_error(error)

    def test_exceeds_limit(self):
        """An "exceeds maximum size" message is treated as a size error."""
        error = Exception("Image exceeds maximum size")
        assert _is_image_size_error(error)

    def test_unrelated_error(self):
        """A connection failure is not a size error."""
        error = Exception("Connection refused")
        assert not _is_image_size_error(error)

    def test_auth_error(self):
        """An authentication failure is not a size error."""
        error = Exception("401 Unauthorized")
        assert not _is_image_size_error(error)

    def test_empty_message(self):
        """An exception with an empty message is not a size error."""
        error = Exception("")
        assert not _is_image_size_error(error)