"""Tests for qa_service: text QA (US5) and image QA (US6)."""

import base64
import json
from unittest.mock import AsyncMock

import pytest

from app.core.exceptions import LLMCallError, LLMParseError, StorageError

# ---------------------------------------------------------------------------
# Shared fixtures / helpers
# ---------------------------------------------------------------------------

# Canned LLM reply. ensure_ascii=False keeps the CJK text raw, as a real model
# reply would be; the parsed value is identical either way.
SAMPLE_QA_JSON = json.dumps(
    [{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"}],
    ensure_ascii=False,
)

# Subject / predicate / object shared by the text and image request builders.
_TRIPLE = {"subject": "电缆接头", "predicate": "位于", "object": "配电箱左侧"}


def _call_messages(call_args):
    """Return the ``messages`` argument of a recorded mock call.

    Prefers ``messages=`` passed by keyword; otherwise falls back to the
    second positional argument, the position these tests have always assumed
    (presumably after a model/config argument — confirm against the LLM client).
    """
    if "messages" in call_args.kwargs:
        return call_args.kwargs["messages"]
    return call_args.args[1]


# ---------------------------------------------------------------------------
# T039 — Text QA service tests (US5)
# ---------------------------------------------------------------------------


def _text_request(**overrides):
    """Build a one-item GenTextQARequest; keyword args override item fields."""
    from app.models.qa_models import GenTextQARequest, TextQAItem

    fields = {**_TRIPLE, "source_snippet": "电缆接头位于配电箱左侧", **overrides}
    return GenTextQARequest(items=[TextQAItem(**fields)])


@pytest.mark.asyncio
async def test_gen_text_qa_prompt_contains_triples(mock_llm):
    """Triple fields and source_snippet must appear in the message sent to LLM."""
    from app.services.qa_service import gen_text_qa

    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)

    await gen_text_qa(_text_request(), mock_llm)

    # assert_awaited (not just .called) also catches a forgotten `await`.
    mock_llm.chat.assert_awaited()
    prompt_text = _call_messages(mock_llm.chat.call_args)[0]["content"]
    for fragment in ("电缆接头", "位于", "配电箱左侧", "电缆接头位于配电箱左侧"):
        assert fragment in prompt_text


@pytest.mark.asyncio
async def test_gen_text_qa_returns_qa_pair_list(mock_llm):
    """Parsed JSON must be returned as QAPair list."""
    from app.models.qa_models import QAPair
    from app.services.qa_service import gen_text_qa

    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)

    result = await gen_text_qa(_text_request(), mock_llm)

    assert len(result.pairs) == 1
    pair = result.pairs[0]
    assert isinstance(pair, QAPair)
    assert pair.question == "电缆接头位于哪里?"
    assert pair.answer == "配电箱左侧"


@pytest.mark.asyncio
async def test_gen_text_qa_llm_parse_error_on_malformed_response(mock_llm):
    """LLMParseError must be raised when LLM returns non-JSON."""
    from app.services.qa_service import gen_text_qa

    mock_llm.chat = AsyncMock(return_value="this is not json {{")

    with pytest.raises(LLMParseError):
        await gen_text_qa(_text_request(), mock_llm)


@pytest.mark.asyncio
async def test_gen_text_qa_llm_call_error_propagates(mock_llm):
    """LLMCallError from LLM client must propagate unchanged."""
    from app.services.qa_service import gen_text_qa

    mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))

    with pytest.raises(LLMCallError):
        await gen_text_qa(_text_request(), mock_llm)


# ---------------------------------------------------------------------------
# T040 — Image QA service tests (US6)
# ---------------------------------------------------------------------------

IMAGE_PATH = "crops/1/0.jpg"
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
EXPECTED_B64 = base64.b64encode(FAKE_IMAGE_BYTES).decode()


def _image_request(**overrides):
    """Build a one-item GenImageQARequest; keyword args override item fields."""
    from app.models.qa_models import GenImageQARequest, ImageQAItem

    fields = {**_TRIPLE, "cropped_image_path": IMAGE_PATH, **overrides}
    return GenImageQARequest(items=[ImageQAItem(**fields)])


def _content_parts(messages, part_type):
    """Return the parts of type *part_type* from the first message's multimodal content."""
    content = messages[0]["content"]
    assert isinstance(content, list), "expected multimodal (list) message content"
    return [part for part in content if part.get("type") == part_type]


@pytest.mark.asyncio
async def test_gen_image_qa_downloads_image_and_encodes_base64(mock_llm, mock_storage):
    """Storage.download_bytes must be awaited with the crop path, and the
    downloaded bytes must reach the LLM message base64-encoded."""
    from app.services.qa_service import gen_image_qa

    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    await gen_image_qa(_image_request(), mock_llm, mock_storage)

    mock_storage.download_bytes.assert_awaited_once()
    download_call = mock_storage.download_bytes.call_args
    # Accept the path positionally (with or without a leading bucket argument)
    # or by keyword, instead of hard-coding one calling convention.
    assert IMAGE_PATH in (*download_call.args, *download_call.kwargs.values())

    # The docstring's base64 claim: the encoded payload must reach the LLM.
    mock_llm.chat_vision.assert_awaited()
    messages = _call_messages(mock_llm.chat_vision.call_args)
    urls = [part["image_url"]["url"] for part in _content_parts(messages, "image_url")]
    assert any(EXPECTED_B64 in url for url in urls)


@pytest.mark.asyncio
async def test_gen_image_qa_multimodal_message_format(mock_llm, mock_storage):
    """Multimodal message must contain inline base64 image_url and text."""
    from app.services.qa_service import gen_image_qa

    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    await gen_image_qa(_image_request(qualifier="2024检修"), mock_llm, mock_storage)

    mock_llm.chat_vision.assert_awaited()
    messages = _call_messages(mock_llm.chat_vision.call_args)

    # Must have an image_url part with an inline base64 data URI.
    image_parts = _content_parts(messages, "image_url")
    assert image_parts
    assert image_parts[0]["image_url"]["url"] == f"data:image/jpeg;base64,{EXPECTED_B64}"

    # Must have a text part carrying the full quad, qualifier included
    # (the qualifier is set in this test specifically to check it is forwarded).
    text_parts = _content_parts(messages, "text")
    assert text_parts
    text = text_parts[0]["text"]
    for fragment in ("电缆接头", "位于", "配电箱左侧", "2024检修"):
        assert fragment in text


@pytest.mark.asyncio
async def test_gen_image_qa_returns_image_qa_pair_with_image_path(mock_llm, mock_storage):
    """Result ImageQAPair must include image_path from the item."""
    from app.models.qa_models import ImageQAPair
    from app.services.qa_service import gen_image_qa

    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    result = await gen_image_qa(_image_request(), mock_llm, mock_storage)

    assert len(result.pairs) == 1
    pair = result.pairs[0]
    assert isinstance(pair, ImageQAPair)
    assert pair.question == "电缆接头位于哪里?"
    assert pair.answer == "配电箱左侧"
    assert pair.image_path == IMAGE_PATH


@pytest.mark.asyncio
async def test_gen_image_qa_storage_error_propagates(mock_llm, mock_storage):
    """StorageError from download must propagate unchanged, before any LLM call."""
    from app.services.qa_service import gen_image_qa

    mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
    # Install an explicit mock so the "no LLM call" check works whatever
    # object the mock_llm fixture provides.
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)

    with pytest.raises(StorageError):
        await gen_image_qa(_image_request(), mock_llm, mock_storage)

    mock_llm.chat_vision.assert_not_awaited()