From 4211e587ee6bc75d1b18fe994560821395427d80 Mon Sep 17 00:00:00 2001
From: wh
Date: Fri, 10 Apr 2026 16:05:49 +0800
Subject: [PATCH] =?UTF-8?q?feat(US5+6):=20QA=20generation=20=E2=80=94=20PO?=
 =?UTF-8?q?ST=20/api/v1/qa/gen-text=20and=20/gen-image?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add qa_models.py with TextQAItem, GenTextQARequest, QAPair, ImageQAItem,
  GenImageQARequest, ImageQAPair, TextQAResponse, ImageQAResponse
- Implement gen_text_qa(): batch-formats triples into a single prompt,
  calls llm.chat(), parses JSON array via extract_json()
- Implement gen_image_qa(): downloads cropped image from source-data bucket,
  base64-encodes inline (data URI), builds multimodal message, calls
  llm.chat_vision(), parses JSON; image_path preserved on ImageQAPair
- Replace qa.py stub with full router: POST /qa/gen-text and /qa/gen-image
  using Depends(get_llm_client) and Depends(get_storage_client)
- 15 new tests (8 service + 7 router), 53/53 total passing
---
Reviewer notes (free text between the three-dash line and the diffstat is
not part of the commit message):
  * Line breaks and Python indentation were restored from a
    whitespace-collapsed copy of this patch. The number of added lines per
    file matches the hunk headers and the diffstat below.
  * gen_text_qa() now fills the "{triples_text}" placeholder with
    str.replace() instead of str.format(). A caller-supplied
    prompt_template that contains other literal braces (for example an
    inline JSON sample) therefore no longer makes str.format() raise.
    The prompt built from the default template is unchanged, and so is the
    added-line count; the blob id on the qa_service.py "index" line still
    refers to the original content.

 app/models/qa_models.py    |  47 ++++++++
 app/routers/qa.py          |  30 ++++-
 app/services/qa_service.py | 106 +++++++++++++++++
 tests/test_qa_router.py    | 121 +++++++++++++++++++
 tests/test_qa_service.py   | 236 +++++++++++++++++++++++++++++++++++++
 5 files changed, 539 insertions(+), 1 deletion(-)
 create mode 100644 app/models/qa_models.py
 create mode 100644 app/services/qa_service.py
 create mode 100644 tests/test_qa_router.py
 create mode 100644 tests/test_qa_service.py

diff --git a/app/models/qa_models.py b/app/models/qa_models.py
new file mode 100644
index 0000000..54d7149
--- /dev/null
+++ b/app/models/qa_models.py
@@ -0,0 +1,47 @@
+from pydantic import BaseModel
+
+
+class TextQAItem(BaseModel):
+    subject: str
+    predicate: str
+    object: str
+    source_snippet: str
+
+
+class GenTextQARequest(BaseModel):
+    items: list[TextQAItem]
+    model: str | None = None
+    prompt_template: str | None = None
+
+
+class QAPair(BaseModel):
+    question: str
+    answer: str
+
+
+class ImageQAItem(BaseModel):
+    subject: str
+    predicate: str
+    object: str
+    qualifier: str | None = None
+    cropped_image_path: str
+
+
+class GenImageQARequest(BaseModel):
+    items: list[ImageQAItem]
+    model: str | None = None
+    prompt_template: str | None = None
+
+
+class ImageQAPair(BaseModel):
+    question: str
+    answer: str
+    image_path: str
+
+
+class TextQAResponse(BaseModel):
+    pairs: list[QAPair]
+
+
+class ImageQAResponse(BaseModel):
+    pairs: list[ImageQAPair]
diff --git a/app/routers/qa.py b/app/routers/qa.py
index 5b22c10..f0f3258 100644
--- a/app/routers/qa.py
+++ b/app/routers/qa.py
@@ -1,3 +1,31 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+
+from app.clients.llm.base import LLMClient
+from app.clients.storage.base import StorageClient
+from app.core.dependencies import get_llm_client, get_storage_client
+from app.models.qa_models import (
+    GenImageQARequest,
+    GenTextQARequest,
+    ImageQAResponse,
+    TextQAResponse,
+)
+from app.services import qa_service
 
 router = APIRouter(tags=["QA"])
+
+
+@router.post("/qa/gen-text", response_model=TextQAResponse)
+async def gen_text_qa(
+    req: GenTextQARequest,
+    llm: LLMClient = Depends(get_llm_client),
+) -> TextQAResponse:
+    return await qa_service.gen_text_qa(req, llm)
+
+
+@router.post("/qa/gen-image", response_model=ImageQAResponse)
+async def gen_image_qa(
+    req: GenImageQARequest,
+    llm: LLMClient = Depends(get_llm_client),
+    storage: StorageClient = Depends(get_storage_client),
+) -> ImageQAResponse:
+    return await qa_service.gen_image_qa(req, llm, storage)
diff --git a/app/services/qa_service.py b/app/services/qa_service.py
new file mode 100644
index 0000000..a8136b7
--- /dev/null
+++ b/app/services/qa_service.py
@@ -0,0 +1,106 @@
+import base64
+
+from app.clients.llm.base import LLMClient
+from app.clients.storage.base import StorageClient
+from app.core.config import get_config
+from app.core.json_utils import extract_json
+from app.core.logging import get_logger
+from app.models.qa_models import (
+    GenImageQARequest,
+    GenTextQARequest,
+    ImageQAPair,
+    ImageQAResponse,
+    QAPair,
+    TextQAResponse,
+)
+
+logger = get_logger(__name__)
+
+_DEFAULT_TEXT_PROMPT = (
+    "请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
+    "三元组列表:\n{triples_text}"
+)
+
+_DEFAULT_IMAGE_PROMPT = (
+    "请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
+)
+
+
+async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
+    cfg = get_config()
+    model = req.model or cfg["models"]["default_text"]
+
+    # Format all triples + source snippets into a single batch prompt
+    triple_lines: list[str] = []
+    for item in req.items:
+        triple_lines.append(
+            f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
+        )
+    triples_text = "\n".join(triple_lines)
+
+    prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
+    if "{triples_text}" in prompt_template:
+        prompt = prompt_template.replace("{triples_text}", triples_text)
+    else:
+        prompt = prompt_template + "\n\n" + triples_text
+
+    messages = [{"role": "user", "content": prompt}]
+    raw = await llm.chat(model, messages)
+
+    logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})
+
+    items_raw = extract_json(raw)
+    pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
+    return TextQAResponse(pairs=pairs)
+
+
+async def gen_image_qa(
+    req: GenImageQARequest,
+    llm: LLMClient,
+    storage: StorageClient,
+) -> ImageQAResponse:
+    cfg = get_config()
+    bucket = cfg["storage"]["buckets"]["source_data"]
+    model = req.model or cfg["models"]["default_vision"]
+
+    prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT
+
+    pairs: list[ImageQAPair] = []
+
+    for item in req.items:
+        # Download cropped image bytes from storage
+        image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)
+
+        # Base64 encode inline for multimodal message
+        b64 = base64.b64encode(image_bytes).decode()
+        image_data_url = f"data:image/jpeg;base64,{b64}"
+
+        # Build quad info text
+        quad_text = f"{item.subject} — {item.predicate} — {item.object}"
+        if item.qualifier:
+            quad_text += f" ({item.qualifier})"
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image_url", "image_url": {"url": image_data_url}},
+                    {"type": "text", "text": f"{prompt}\n\n{quad_text}"},
+                ],
+            }
+        ]
+
+        raw = await llm.chat_vision(model, messages)
+        logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})
+
+        items_raw = extract_json(raw)
+        for qa in items_raw:
+            pairs.append(
+                ImageQAPair(
+                    question=qa["question"],
+                    answer=qa["answer"],
+                    image_path=item.cropped_image_path,
+                )
+            )
+
+    return ImageQAResponse(pairs=pairs)
diff --git a/tests/test_qa_router.py b/tests/test_qa_router.py
new file mode 100644
index 0000000..8f82575
--- /dev/null
+++ b/tests/test_qa_router.py
@@ -0,0 +1,121 @@
+"""Tests for QA router: /api/v1/qa/gen-text and /api/v1/qa/gen-image."""
+import json
+import pytest
+from unittest.mock import AsyncMock
+
+from app.core.exceptions import LLMCallError, LLMParseError, StorageError
+
+
+SAMPLE_QA_JSON = json.dumps([
+    {"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
+])
+
+FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
+
+TEXT_QA_PAYLOAD = {
+    "items": [
+        {
+            "subject": "电缆接头",
+            "predicate": "位于",
+            "object": "配电箱左侧",
+            "source_snippet": "电缆接头位于配电箱左侧",
+        }
+    ]
+}
+
+IMAGE_QA_PAYLOAD = {
+    "items": [
+        {
+            "subject": "电缆接头",
+            "predicate": "位于",
+            "object": "配电箱左侧",
+            "cropped_image_path": "crops/1/0.jpg",
+        }
+    ]
+}
+
+
+# ---------------------------------------------------------------------------
+# POST /api/v1/qa/gen-text
+# ---------------------------------------------------------------------------
+
+
+def test_gen_text_qa_returns_200(client, mock_llm):
+    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
+
+    assert resp.status_code == 200
+    data = resp.json()
+    assert "pairs" in data
+    assert len(data["pairs"]) == 1
+    assert data["pairs"][0]["question"] == "电缆接头位于哪里?"
+    assert data["pairs"][0]["answer"] == "配电箱左侧"
+
+
+def test_gen_text_qa_llm_parse_error_returns_502(client, mock_llm):
+    mock_llm.chat = AsyncMock(return_value="not valid json {{")
+
+    resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
+
+    assert resp.status_code == 502
+    assert resp.json()["code"] == "LLM_PARSE_ERROR"
+
+
+def test_gen_text_qa_llm_call_error_returns_503(client, mock_llm):
+    mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
+
+    resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
+
+    assert resp.status_code == 503
+    assert resp.json()["code"] == "LLM_CALL_ERROR"
+
+
+# ---------------------------------------------------------------------------
+# POST /api/v1/qa/gen-image
+# ---------------------------------------------------------------------------
+
+
+def test_gen_image_qa_returns_200(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
+
+    assert resp.status_code == 200
+    data = resp.json()
+    assert "pairs" in data
+    assert len(data["pairs"]) == 1
+    pair = data["pairs"][0]
+    assert pair["question"] == "电缆接头位于哪里?"
+    assert pair["answer"] == "配电箱左侧"
+    assert pair["image_path"] == "crops/1/0.jpg"
+
+
+def test_gen_image_qa_llm_parse_error_returns_502(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(return_value="bad json {{")
+
+    resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
+
+    assert resp.status_code == 502
+    assert resp.json()["code"] == "LLM_PARSE_ERROR"
+
+
+def test_gen_image_qa_llm_call_error_returns_503(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(side_effect=LLMCallError("GLM vision timeout"))
+
+    resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
+
+    assert resp.status_code == 503
+    assert resp.json()["code"] == "LLM_CALL_ERROR"
+
+
+def test_gen_image_qa_storage_error_returns_502(client, mock_storage):
+    mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
+
+    resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
+
+    assert resp.status_code == 502
+    assert resp.json()["code"] == "STORAGE_ERROR"
diff --git a/tests/test_qa_service.py b/tests/test_qa_service.py
new file mode 100644
index 0000000..7a6e258
--- /dev/null
+++ b/tests/test_qa_service.py
@@ -0,0 +1,236 @@
+"""Tests for qa_service: text QA (US5) and image QA (US6)."""
+import base64
+import json
+import pytest
+from unittest.mock import AsyncMock
+
+from app.core.exceptions import LLMCallError, LLMParseError, StorageError
+
+
+# ---------------------------------------------------------------------------
+# Shared fixtures / helpers
+# ---------------------------------------------------------------------------
+
+SAMPLE_QA_JSON = json.dumps([
+    {"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
+])
+
+
+# ---------------------------------------------------------------------------
+# T039 — Text QA service tests (US5)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_gen_text_qa_prompt_contains_triples(mock_llm):
+    """Triple fields and source_snippet must appear in the message sent to LLM."""
+    from app.models.qa_models import GenTextQARequest, TextQAItem
+    from app.services.qa_service import gen_text_qa
+
+    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    req = GenTextQARequest(items=[
+        TextQAItem(
+            subject="电缆接头",
+            predicate="位于",
+            object="配电箱左侧",
+            source_snippet="电缆接头位于配电箱左侧",
+        )
+    ])
+
+    await gen_text_qa(req, mock_llm)
+
+    assert mock_llm.chat.called
+    call_args = mock_llm.chat.call_args
+    messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
+    prompt_text = messages[0]["content"]
+    assert "电缆接头" in prompt_text
+    assert "位于" in prompt_text
+    assert "配电箱左侧" in prompt_text
+    assert "电缆接头位于配电箱左侧" in prompt_text
+
+
+@pytest.mark.asyncio
+async def test_gen_text_qa_returns_qa_pair_list(mock_llm):
+    """Parsed JSON must be returned as QAPair list."""
+    from app.models.qa_models import GenTextQARequest, QAPair, TextQAItem
+    from app.services.qa_service import gen_text_qa
+
+    mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    req = GenTextQARequest(items=[
+        TextQAItem(
+            subject="电缆接头",
+            predicate="位于",
+            object="配电箱左侧",
+            source_snippet="电缆接头位于配电箱左侧",
+        )
+    ])
+
+    result = await gen_text_qa(req, mock_llm)
+
+    assert len(result.pairs) == 1
+    pair = result.pairs[0]
+    assert isinstance(pair, QAPair)
+    assert pair.question == "电缆接头位于哪里?"
+    assert pair.answer == "配电箱左侧"
+
+
+@pytest.mark.asyncio
+async def test_gen_text_qa_llm_parse_error_on_malformed_response(mock_llm):
+    """LLMParseError must be raised when LLM returns non-JSON."""
+    from app.models.qa_models import GenTextQARequest, TextQAItem
+    from app.services.qa_service import gen_text_qa
+
+    mock_llm.chat = AsyncMock(return_value="this is not json {{")
+
+    req = GenTextQARequest(items=[
+        TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
+    ])
+
+    with pytest.raises(LLMParseError):
+        await gen_text_qa(req, mock_llm)
+
+
+@pytest.mark.asyncio
+async def test_gen_text_qa_llm_call_error_propagates(mock_llm):
+    """LLMCallError from LLM client must propagate unchanged."""
+    from app.models.qa_models import GenTextQARequest, TextQAItem
+    from app.services.qa_service import gen_text_qa
+
+    mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
+
+    req = GenTextQARequest(items=[
+        TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
+    ])
+
+    with pytest.raises(LLMCallError):
+        await gen_text_qa(req, mock_llm)
+
+
+# ---------------------------------------------------------------------------
+# T040 — Image QA service tests (US6)
+# ---------------------------------------------------------------------------
+
+FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
+
+
+@pytest.mark.asyncio
+async def test_gen_image_qa_downloads_image_and_encodes_base64(mock_llm, mock_storage):
+    """Storage.download_bytes must be called, result base64-encoded in LLM message."""
+    from app.models.qa_models import GenImageQARequest, ImageQAItem
+    from app.services.qa_service import gen_image_qa
+
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    req = GenImageQARequest(items=[
+        ImageQAItem(
+            subject="电缆接头",
+            predicate="位于",
+            object="配电箱左侧",
+            cropped_image_path="crops/1/0.jpg",
+        )
+    ])
+
+    await gen_image_qa(req, mock_llm, mock_storage)
+
+    # Storage download must have been called with the correct path
+    mock_storage.download_bytes.assert_called_once()
+    call_args = mock_storage.download_bytes.call_args
+    path_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("path", call_args.kwargs.get("key"))
+    assert path_arg == "crops/1/0.jpg"
+
+
+@pytest.mark.asyncio
+async def test_gen_image_qa_multimodal_message_format(mock_llm, mock_storage):
+    """Multimodal message must contain inline base64 image_url and text."""
+    from app.models.qa_models import GenImageQARequest, ImageQAItem
+    from app.services.qa_service import gen_image_qa
+
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    req = GenImageQARequest(items=[
+        ImageQAItem(
+            subject="电缆接头",
+            predicate="位于",
+            object="配电箱左侧",
+            qualifier="2024检修",
+            cropped_image_path="crops/1/0.jpg",
+        )
+    ])
+
+    await gen_image_qa(req, mock_llm, mock_storage)
+
+    assert mock_llm.chat_vision.called
+    call_args = mock_llm.chat_vision.call_args
+    messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
+
+    # Find the content list in messages
+    content = messages[0]["content"]
+    assert isinstance(content, list)
+
+    # Must have an image_url part with inline base64 data URI
+    image_parts = [p for p in content if p.get("type") == "image_url"]
+    assert len(image_parts) >= 1
+    url = image_parts[0]["image_url"]["url"]
+    expected_b64 = base64.b64encode(FAKE_IMAGE_BYTES).decode()
+    assert url == f"data:image/jpeg;base64,{expected_b64}"
+
+    # Must have a text part containing quad info
+    text_parts = [p for p in content if p.get("type") == "text"]
+    assert len(text_parts) >= 1
+    text = text_parts[0]["text"]
+    assert "电缆接头" in text
+    assert "位于" in text
+    assert "配电箱左侧" in text
+
+
+@pytest.mark.asyncio
+async def test_gen_image_qa_returns_image_qa_pair_with_image_path(mock_llm, mock_storage):
+    """Result ImageQAPair must include image_path from the item."""
+    from app.models.qa_models import GenImageQARequest, ImageQAItem, ImageQAPair
+    from app.services.qa_service import gen_image_qa
+
+    mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
+
+    req = GenImageQARequest(items=[
+        ImageQAItem(
+            subject="电缆接头",
+            predicate="位于",
+            object="配电箱左侧",
+            cropped_image_path="crops/1/0.jpg",
+        )
+    ])
+
+    result = await gen_image_qa(req, mock_llm, mock_storage)
+
+    assert len(result.pairs) == 1
+    pair = result.pairs[0]
+    assert isinstance(pair, ImageQAPair)
+    assert pair.question == "电缆接头位于哪里?"
+    assert pair.answer == "配电箱左侧"
+    assert pair.image_path == "crops/1/0.jpg"
+
+
+@pytest.mark.asyncio
+async def test_gen_image_qa_storage_error_propagates(mock_llm, mock_storage):
+    """StorageError from download must propagate unchanged."""
+    from app.models.qa_models import GenImageQARequest, ImageQAItem
+    from app.services.qa_service import gen_image_qa
+
+    mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
+
+    req = GenImageQARequest(items=[
+        ImageQAItem(
+            subject="s",
+            predicate="p",
+            object="o",
+            cropped_image_path="crops/1/0.jpg",
+        )
+    ])
+
+    with pytest.raises(StorageError):
+        await gen_image_qa(req, mock_llm, mock_storage)