Fix test_analysis_error_logs_exc_info: mock _aux_async_client so download path is reached
This commit is contained in:
parent
c358af7861
commit
0229e6b407
1 changed files with 82 additions and 30 deletions
|
|
@ -25,6 +25,7 @@ from tools.vision_tools import (
|
||||||
# _validate_image_url — urlparse-based validation
|
# _validate_image_url — urlparse-based validation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestValidateImageUrl:
|
class TestValidateImageUrl:
|
||||||
"""Tests for URL validation, including urlparse-based netloc check."""
|
"""Tests for URL validation, including urlparse-based netloc check."""
|
||||||
|
|
||||||
|
|
@ -95,6 +96,7 @@ class TestValidateImageUrl:
|
||||||
# _determine_mime_type
|
# _determine_mime_type
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestDetermineMimeType:
|
class TestDetermineMimeType:
|
||||||
def test_jpg(self):
|
def test_jpg(self):
|
||||||
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
assert _determine_mime_type(Path("photo.jpg")) == "image/jpeg"
|
||||||
|
|
@ -119,6 +121,7 @@ class TestDetermineMimeType:
|
||||||
# _image_to_base64_data_url
|
# _image_to_base64_data_url
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestImageToBase64DataUrl:
|
class TestImageToBase64DataUrl:
|
||||||
def test_returns_data_url(self, tmp_path):
|
def test_returns_data_url(self, tmp_path):
|
||||||
img = tmp_path / "test.png"
|
img = tmp_path / "test.png"
|
||||||
|
|
@ -141,15 +144,21 @@ class TestImageToBase64DataUrl:
|
||||||
# _handle_vision_analyze — type signature & behavior
|
# _handle_vision_analyze — type signature & behavior
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestHandleVisionAnalyze:
|
class TestHandleVisionAnalyze:
|
||||||
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
"""Verify _handle_vision_analyze returns an Awaitable and builds correct prompt."""
|
||||||
|
|
||||||
def test_returns_awaitable(self):
|
def test_returns_awaitable(self):
|
||||||
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
"""The handler must return an Awaitable (coroutine) since it's registered as async."""
|
||||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
with patch(
|
||||||
|
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||||
|
) as mock_tool:
|
||||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||||
result = _handle_vision_analyze(
|
result = _handle_vision_analyze(
|
||||||
{"image_url": "https://example.com/img.png", "question": "What is this?"}
|
{
|
||||||
|
"image_url": "https://example.com/img.png",
|
||||||
|
"question": "What is this?",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
# It should be an Awaitable (coroutine)
|
# It should be an Awaitable (coroutine)
|
||||||
assert isinstance(result, Awaitable)
|
assert isinstance(result, Awaitable)
|
||||||
|
|
@ -158,10 +167,15 @@ class TestHandleVisionAnalyze:
|
||||||
|
|
||||||
def test_prompt_contains_question(self):
|
def test_prompt_contains_question(self):
|
||||||
"""The full prompt should incorporate the user's question."""
|
"""The full prompt should incorporate the user's question."""
|
||||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
with patch(
|
||||||
|
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||||
|
) as mock_tool:
|
||||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||||
coro = _handle_vision_analyze(
|
coro = _handle_vision_analyze(
|
||||||
{"image_url": "https://example.com/img.png", "question": "Describe the cat"}
|
{
|
||||||
|
"image_url": "https://example.com/img.png",
|
||||||
|
"question": "Describe the cat",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
# Clean up coroutine
|
# Clean up coroutine
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
@ -172,8 +186,12 @@ class TestHandleVisionAnalyze:
|
||||||
|
|
||||||
def test_uses_auxiliary_vision_model_env(self):
|
def test_uses_auxiliary_vision_model_env(self):
|
||||||
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
"""AUXILIARY_VISION_MODEL env var should override DEFAULT_VISION_MODEL."""
|
||||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
with (
|
||||||
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}):
|
patch(
|
||||||
|
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||||
|
) as mock_tool,
|
||||||
|
patch.dict(os.environ, {"AUXILIARY_VISION_MODEL": "custom/model-v1"}),
|
||||||
|
):
|
||||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||||
coro = _handle_vision_analyze(
|
coro = _handle_vision_analyze(
|
||||||
{"image_url": "https://example.com/img.png", "question": "test"}
|
{"image_url": "https://example.com/img.png", "question": "test"}
|
||||||
|
|
@ -185,8 +203,12 @@ class TestHandleVisionAnalyze:
|
||||||
|
|
||||||
def test_falls_back_to_default_model(self):
|
def test_falls_back_to_default_model(self):
|
||||||
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
"""Without AUXILIARY_VISION_MODEL, should use DEFAULT_VISION_MODEL or fallback."""
|
||||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool, \
|
with (
|
||||||
patch.dict(os.environ, {}, clear=False):
|
patch(
|
||||||
|
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||||
|
) as mock_tool,
|
||||||
|
patch.dict(os.environ, {}, clear=False),
|
||||||
|
):
|
||||||
# Ensure AUXILIARY_VISION_MODEL is not set
|
# Ensure AUXILIARY_VISION_MODEL is not set
|
||||||
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
os.environ.pop("AUXILIARY_VISION_MODEL", None)
|
||||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||||
|
|
@ -202,7 +224,9 @@ class TestHandleVisionAnalyze:
|
||||||
|
|
||||||
def test_empty_args_graceful(self):
|
def test_empty_args_graceful(self):
|
||||||
"""Missing keys should default to empty strings, not raise."""
|
"""Missing keys should default to empty strings, not raise."""
|
||||||
with patch("tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock) as mock_tool:
|
with patch(
|
||||||
|
"tools.vision_tools.vision_analyze_tool", new_callable=AsyncMock
|
||||||
|
) as mock_tool:
|
||||||
mock_tool.return_value = json.dumps({"result": "ok"})
|
mock_tool.return_value = json.dumps({"result": "ok"})
|
||||||
result = _handle_vision_analyze({})
|
result = _handle_vision_analyze({})
|
||||||
assert isinstance(result, Awaitable)
|
assert isinstance(result, Awaitable)
|
||||||
|
|
@ -213,6 +237,7 @@ class TestHandleVisionAnalyze:
|
||||||
# Error logging with exc_info — verify tracebacks are logged
|
# Error logging with exc_info — verify tracebacks are logged
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestErrorLoggingExcInfo:
|
class TestErrorLoggingExcInfo:
|
||||||
"""Verify that exc_info=True is used in error/warning log calls."""
|
"""Verify that exc_info=True is used in error/warning log calls."""
|
||||||
|
|
||||||
|
|
@ -229,9 +254,13 @@ class TestErrorLoggingExcInfo:
|
||||||
mock_client_cls.return_value = mock_client
|
mock_client_cls.return_value = mock_client
|
||||||
|
|
||||||
dest = tmp_path / "image.jpg"
|
dest = tmp_path / "image.jpg"
|
||||||
with caplog.at_level(logging.ERROR, logger="tools.vision_tools"), \
|
with (
|
||||||
pytest.raises(ConnectionError):
|
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
|
||||||
await _download_image("https://example.com/img.jpg", dest, max_retries=1)
|
pytest.raises(ConnectionError),
|
||||||
|
):
|
||||||
|
await _download_image(
|
||||||
|
"https://example.com/img.jpg", dest, max_retries=1
|
||||||
|
)
|
||||||
|
|
||||||
# Should have logged with exc_info (traceback present)
|
# Should have logged with exc_info (traceback present)
|
||||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||||
|
|
@ -241,11 +270,17 @@ class TestErrorLoggingExcInfo:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
with (
|
||||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock,
|
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||||
side_effect=Exception("download boom")), \
|
patch(
|
||||||
caplog.at_level(logging.ERROR, logger="tools.vision_tools"):
|
"tools.vision_tools._download_image",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
side_effect=Exception("download boom"),
|
||||||
|
),
|
||||||
|
patch("tools.vision_tools._aux_async_client", MagicMock()),
|
||||||
|
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
|
||||||
|
caplog.at_level(logging.ERROR, logger="tools.vision_tools"),
|
||||||
|
):
|
||||||
result = await vision_analyze_tool(
|
result = await vision_analyze_tool(
|
||||||
"https://example.com/img.jpg", "describe this", "test/model"
|
"https://example.com/img.jpg", "describe this", "test/model"
|
||||||
)
|
)
|
||||||
|
|
@ -269,14 +304,20 @@ class TestErrorLoggingExcInfo:
|
||||||
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
dest.write_bytes(b"\xff\xd8\xff" + b"\x00" * 16)
|
||||||
return dest
|
return dest
|
||||||
|
|
||||||
with patch("tools.vision_tools._validate_image_url", return_value=True), \
|
with (
|
||||||
patch("tools.vision_tools._download_image", side_effect=fake_download), \
|
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||||
patch("tools.vision_tools._image_to_base64_data_url",
|
patch("tools.vision_tools._download_image", side_effect=fake_download),
|
||||||
return_value="data:image/jpeg;base64,abc"), \
|
patch(
|
||||||
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None), \
|
"tools.vision_tools._image_to_base64_data_url",
|
||||||
patch("agent.auxiliary_client.auxiliary_max_tokens_param", return_value={"max_tokens": 2000}), \
|
return_value="data:image/jpeg;base64,abc",
|
||||||
caplog.at_level(logging.WARNING, logger="tools.vision_tools"):
|
),
|
||||||
|
patch("agent.auxiliary_client.get_auxiliary_extra_body", return_value=None),
|
||||||
|
patch(
|
||||||
|
"agent.auxiliary_client.auxiliary_max_tokens_param",
|
||||||
|
return_value={"max_tokens": 2000},
|
||||||
|
),
|
||||||
|
caplog.at_level(logging.WARNING, logger="tools.vision_tools"),
|
||||||
|
):
|
||||||
# Mock the vision client
|
# Mock the vision client
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
|
@ -286,11 +327,13 @@ class TestErrorLoggingExcInfo:
|
||||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
# Patch module-level _aux_async_client so the tool doesn't bail early
|
# Patch module-level _aux_async_client so the tool doesn't bail early
|
||||||
with patch("tools.vision_tools._aux_async_client", mock_client), \
|
with (
|
||||||
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"):
|
patch("tools.vision_tools._aux_async_client", mock_client),
|
||||||
|
patch("tools.vision_tools.DEFAULT_VISION_MODEL", "test/model"),
|
||||||
|
):
|
||||||
# Make unlink fail to trigger cleanup warning
|
# Make unlink fail to trigger cleanup warning
|
||||||
original_unlink = Path.unlink
|
original_unlink = Path.unlink
|
||||||
|
|
||||||
def failing_unlink(self, *args, **kwargs):
|
def failing_unlink(self, *args, **kwargs):
|
||||||
raise PermissionError("no permission")
|
raise PermissionError("no permission")
|
||||||
|
|
||||||
|
|
@ -299,8 +342,12 @@ class TestErrorLoggingExcInfo:
|
||||||
"https://example.com/tempimg.jpg", "describe", "test/model"
|
"https://example.com/tempimg.jpg", "describe", "test/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
warning_records = [r for r in caplog.records if r.levelno == logging.WARNING
|
warning_records = [
|
||||||
and "temporary file" in r.getMessage().lower()]
|
r
|
||||||
|
for r in caplog.records
|
||||||
|
if r.levelno == logging.WARNING
|
||||||
|
and "temporary file" in r.getMessage().lower()
|
||||||
|
]
|
||||||
assert len(warning_records) >= 1
|
assert len(warning_records) >= 1
|
||||||
assert warning_records[0].exc_info is not None
|
assert warning_records[0].exc_info is not None
|
||||||
|
|
||||||
|
|
@ -309,6 +356,7 @@ class TestErrorLoggingExcInfo:
|
||||||
# check_vision_requirements & get_debug_session_info
|
# check_vision_requirements & get_debug_session_info
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestVisionRequirements:
|
class TestVisionRequirements:
|
||||||
def test_check_requirements_returns_bool(self):
|
def test_check_requirements_returns_bool(self):
|
||||||
result = check_vision_requirements()
|
result = check_vision_requirements()
|
||||||
|
|
@ -327,9 +375,11 @@ class TestVisionRequirements:
|
||||||
# Integration: registry entry
|
# Integration: registry entry
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestVisionRegistration:
|
class TestVisionRegistration:
|
||||||
def test_vision_analyze_registered(self):
|
def test_vision_analyze_registered(self):
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
|
|
||||||
entry = registry._tools.get("vision_analyze")
|
entry = registry._tools.get("vision_analyze")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry.toolset == "vision"
|
assert entry.toolset == "vision"
|
||||||
|
|
@ -337,6 +387,7 @@ class TestVisionRegistration:
|
||||||
|
|
||||||
def test_schema_has_required_fields(self):
|
def test_schema_has_required_fields(self):
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
|
|
||||||
entry = registry._tools.get("vision_analyze")
|
entry = registry._tools.get("vision_analyze")
|
||||||
schema = entry.schema
|
schema = entry.schema
|
||||||
assert schema["name"] == "vision_analyze"
|
assert schema["name"] == "vision_analyze"
|
||||||
|
|
@ -347,5 +398,6 @@ class TestVisionRegistration:
|
||||||
|
|
||||||
def test_handler_is_callable(self):
|
def test_handler_is_callable(self):
|
||||||
from tools.registry import registry
|
from tools.registry import registry
|
||||||
|
|
||||||
entry = registry._tools.get("vision_analyze")
|
entry = registry._tools.get("vision_analyze")
|
||||||
assert callable(entry.handler)
|
assert callable(entry.handler)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue