feat(US5+6): QA generation — POST /api/v1/qa/gen-text and /gen-image

- Add qa_models.py with TextQAItem, GenTextQARequest, QAPair, ImageQAItem,
  GenImageQARequest, ImageQAPair, TextQAResponse, ImageQAResponse
- Implement gen_text_qa(): batch-formats triples into a single prompt, calls
  llm.chat(), parses JSON array via extract_json()
- Implement gen_image_qa(): downloads cropped image from source-data bucket,
  base64-encodes inline (data URI), builds multimodal message, calls
  llm.chat_vision(), parses JSON; image_path preserved on ImageQAPair
- Replace qa.py stub with full router: POST /qa/gen-text and /qa/gen-image
  using Depends(get_llm_client) and Depends(get_storage_client)
- 15 new tests (8 service + 7 router), 53/53 total passing
This commit is contained in:
wh
2026-04-10 16:05:49 +08:00
parent 0274bb470a
commit 4211e587ee
5 changed files with 539 additions and 1 deletions

121
tests/test_qa_router.py Normal file
View File

@@ -0,0 +1,121 @@
"""Tests for QA router: /api/v1/qa/gen-text and /api/v1/qa/gen-image."""
import json
import pytest
from unittest.mock import AsyncMock
from app.core.exceptions import LLMCallError, LLMParseError, StorageError
SAMPLE_QA_JSON = json.dumps([
{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
])
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
TEXT_QA_PAYLOAD = {
"items": [
{
"subject": "电缆接头",
"predicate": "位于",
"object": "配电箱左侧",
"source_snippet": "电缆接头位于配电箱左侧",
}
]
}
IMAGE_QA_PAYLOAD = {
"items": [
{
"subject": "电缆接头",
"predicate": "位于",
"object": "配电箱左侧",
"cropped_image_path": "crops/1/0.jpg",
}
]
}
# ---------------------------------------------------------------------------
# POST /api/v1/qa/gen-text
# ---------------------------------------------------------------------------
def test_gen_text_qa_returns_200(client, mock_llm):
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
assert resp.status_code == 200
data = resp.json()
assert "pairs" in data
assert len(data["pairs"]) == 1
assert data["pairs"][0]["question"] == "电缆接头位于哪里?"
assert data["pairs"][0]["answer"] == "配电箱左侧"
def test_gen_text_qa_llm_parse_error_returns_502(client, mock_llm):
mock_llm.chat = AsyncMock(return_value="not valid json {{")
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
assert resp.status_code == 502
assert resp.json()["code"] == "LLM_PARSE_ERROR"
def test_gen_text_qa_llm_call_error_returns_503(client, mock_llm):
mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
assert resp.status_code == 503
assert resp.json()["code"] == "LLM_CALL_ERROR"
# ---------------------------------------------------------------------------
# POST /api/v1/qa/gen-image
# ---------------------------------------------------------------------------
def test_gen_image_qa_returns_200(client, mock_llm, mock_storage):
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
assert resp.status_code == 200
data = resp.json()
assert "pairs" in data
assert len(data["pairs"]) == 1
pair = data["pairs"][0]
assert pair["question"] == "电缆接头位于哪里?"
assert pair["answer"] == "配电箱左侧"
assert pair["image_path"] == "crops/1/0.jpg"
def test_gen_image_qa_llm_parse_error_returns_502(client, mock_llm, mock_storage):
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(return_value="bad json {{")
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
assert resp.status_code == 502
assert resp.json()["code"] == "LLM_PARSE_ERROR"
def test_gen_image_qa_llm_call_error_returns_503(client, mock_llm, mock_storage):
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(side_effect=LLMCallError("GLM vision timeout"))
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
assert resp.status_code == 503
assert resp.json()["code"] == "LLM_CALL_ERROR"
def test_gen_image_qa_storage_error_returns_502(client, mock_storage):
mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
assert resp.status_code == 502
assert resp.json()["code"] == "STORAGE_ERROR"

236
tests/test_qa_service.py Normal file
View File

@@ -0,0 +1,236 @@
"""Tests for qa_service: text QA (US5) and image QA (US6)."""
import base64
import json
import pytest
from unittest.mock import AsyncMock
from app.core.exceptions import LLMCallError, LLMParseError, StorageError
# ---------------------------------------------------------------------------
# Shared fixtures / helpers
# ---------------------------------------------------------------------------
SAMPLE_QA_JSON = json.dumps([
{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
])
# ---------------------------------------------------------------------------
# T039 — Text QA service tests (US5)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_gen_text_qa_prompt_contains_triples(mock_llm):
"""Triple fields and source_snippet must appear in the message sent to LLM."""
from app.models.qa_models import GenTextQARequest, TextQAItem
from app.services.qa_service import gen_text_qa
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
req = GenTextQARequest(items=[
TextQAItem(
subject="电缆接头",
predicate="位于",
object="配电箱左侧",
source_snippet="电缆接头位于配电箱左侧",
)
])
await gen_text_qa(req, mock_llm)
assert mock_llm.chat.called
call_args = mock_llm.chat.call_args
messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
prompt_text = messages[0]["content"]
assert "电缆接头" in prompt_text
assert "位于" in prompt_text
assert "配电箱左侧" in prompt_text
assert "电缆接头位于配电箱左侧" in prompt_text
@pytest.mark.asyncio
async def test_gen_text_qa_returns_qa_pair_list(mock_llm):
"""Parsed JSON must be returned as QAPair list."""
from app.models.qa_models import GenTextQARequest, QAPair, TextQAItem
from app.services.qa_service import gen_text_qa
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
req = GenTextQARequest(items=[
TextQAItem(
subject="电缆接头",
predicate="位于",
object="配电箱左侧",
source_snippet="电缆接头位于配电箱左侧",
)
])
result = await gen_text_qa(req, mock_llm)
assert len(result.pairs) == 1
pair = result.pairs[0]
assert isinstance(pair, QAPair)
assert pair.question == "电缆接头位于哪里?"
assert pair.answer == "配电箱左侧"
@pytest.mark.asyncio
async def test_gen_text_qa_llm_parse_error_on_malformed_response(mock_llm):
"""LLMParseError must be raised when LLM returns non-JSON."""
from app.models.qa_models import GenTextQARequest, TextQAItem
from app.services.qa_service import gen_text_qa
mock_llm.chat = AsyncMock(return_value="this is not json {{")
req = GenTextQARequest(items=[
TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
])
with pytest.raises(LLMParseError):
await gen_text_qa(req, mock_llm)
@pytest.mark.asyncio
async def test_gen_text_qa_llm_call_error_propagates(mock_llm):
"""LLMCallError from LLM client must propagate unchanged."""
from app.models.qa_models import GenTextQARequest, TextQAItem
from app.services.qa_service import gen_text_qa
mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
req = GenTextQARequest(items=[
TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
])
with pytest.raises(LLMCallError):
await gen_text_qa(req, mock_llm)
# ---------------------------------------------------------------------------
# T040 — Image QA service tests (US6)
# ---------------------------------------------------------------------------
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
@pytest.mark.asyncio
async def test_gen_image_qa_downloads_image_and_encodes_base64(mock_llm, mock_storage):
"""Storage.download_bytes must be called, result base64-encoded in LLM message."""
from app.models.qa_models import GenImageQARequest, ImageQAItem
from app.services.qa_service import gen_image_qa
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
req = GenImageQARequest(items=[
ImageQAItem(
subject="电缆接头",
predicate="位于",
object="配电箱左侧",
cropped_image_path="crops/1/0.jpg",
)
])
await gen_image_qa(req, mock_llm, mock_storage)
# Storage download must have been called with the correct path
mock_storage.download_bytes.assert_called_once()
call_args = mock_storage.download_bytes.call_args
path_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("path", call_args.kwargs.get("key"))
assert path_arg == "crops/1/0.jpg"
@pytest.mark.asyncio
async def test_gen_image_qa_multimodal_message_format(mock_llm, mock_storage):
"""Multimodal message must contain inline base64 image_url and text."""
from app.models.qa_models import GenImageQARequest, ImageQAItem
from app.services.qa_service import gen_image_qa
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
req = GenImageQARequest(items=[
ImageQAItem(
subject="电缆接头",
predicate="位于",
object="配电箱左侧",
qualifier="2024检修",
cropped_image_path="crops/1/0.jpg",
)
])
await gen_image_qa(req, mock_llm, mock_storage)
assert mock_llm.chat_vision.called
call_args = mock_llm.chat_vision.call_args
messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
# Find the content list in messages
content = messages[0]["content"]
assert isinstance(content, list)
# Must have an image_url part with inline base64 data URI
image_parts = [p for p in content if p.get("type") == "image_url"]
assert len(image_parts) >= 1
url = image_parts[0]["image_url"]["url"]
expected_b64 = base64.b64encode(FAKE_IMAGE_BYTES).decode()
assert url == f"data:image/jpeg;base64,{expected_b64}"
# Must have a text part containing quad info
text_parts = [p for p in content if p.get("type") == "text"]
assert len(text_parts) >= 1
text = text_parts[0]["text"]
assert "电缆接头" in text
assert "位于" in text
assert "配电箱左侧" in text
@pytest.mark.asyncio
async def test_gen_image_qa_returns_image_qa_pair_with_image_path(mock_llm, mock_storage):
"""Result ImageQAPair must include image_path from the item."""
from app.models.qa_models import GenImageQARequest, ImageQAItem, ImageQAPair
from app.services.qa_service import gen_image_qa
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
req = GenImageQARequest(items=[
ImageQAItem(
subject="电缆接头",
predicate="位于",
object="配电箱左侧",
cropped_image_path="crops/1/0.jpg",
)
])
result = await gen_image_qa(req, mock_llm, mock_storage)
assert len(result.pairs) == 1
pair = result.pairs[0]
assert isinstance(pair, ImageQAPair)
assert pair.question == "电缆接头位于哪里?"
assert pair.answer == "配电箱左侧"
assert pair.image_path == "crops/1/0.jpg"
@pytest.mark.asyncio
async def test_gen_image_qa_storage_error_propagates(mock_llm, mock_storage):
"""StorageError from download must propagate unchanged."""
from app.models.qa_models import GenImageQARequest, ImageQAItem
from app.services.qa_service import gen_image_qa
mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
req = GenImageQARequest(items=[
ImageQAItem(
subject="s",
predicate="p",
object="o",
cropped_image_path="crops/1/0.jpg",
)
])
with pytest.raises(StorageError):
await gen_image_qa(req, mock_llm, mock_storage)