feat(US5+6): QA generation — POST /api/v1/qa/gen-text and /gen-image
- Add qa_models.py with TextQAItem, GenTextQARequest, QAPair, ImageQAItem, GenImageQARequest, ImageQAPair, TextQAResponse, ImageQAResponse - Implement gen_text_qa(): batch-formats triples into a single prompt, calls llm.chat(), parses JSON array via extract_json() - Implement gen_image_qa(): downloads cropped image from source-data bucket, base64-encodes inline (data URI), builds multimodal message, calls llm.chat_vision(), parses JSON; image_path preserved on ImageQAPair - Replace qa.py stub with full router: POST /qa/gen-text and /qa/gen-image using Depends(get_llm_client) and Depends(get_storage_client) - 15 new tests (8 service + 7 router), 53/53 total passing
This commit is contained in:
47
app/models/qa_models.py
Normal file
47
app/models/qa_models.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TextQAItem(BaseModel):
|
||||||
|
subject: str
|
||||||
|
predicate: str
|
||||||
|
object: str
|
||||||
|
source_snippet: str
|
||||||
|
|
||||||
|
|
||||||
|
class GenTextQARequest(BaseModel):
|
||||||
|
items: list[TextQAItem]
|
||||||
|
model: str | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class QAPair(BaseModel):
|
||||||
|
question: str
|
||||||
|
answer: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageQAItem(BaseModel):
|
||||||
|
subject: str
|
||||||
|
predicate: str
|
||||||
|
object: str
|
||||||
|
qualifier: str | None = None
|
||||||
|
cropped_image_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class GenImageQARequest(BaseModel):
|
||||||
|
items: list[ImageQAItem]
|
||||||
|
model: str | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ImageQAPair(BaseModel):
|
||||||
|
question: str
|
||||||
|
answer: str
|
||||||
|
image_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class TextQAResponse(BaseModel):
|
||||||
|
pairs: list[QAPair]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageQAResponse(BaseModel):
|
||||||
|
pairs: list[ImageQAPair]
|
||||||
@@ -1,3 +1,31 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.dependencies import get_llm_client, get_storage_client
|
||||||
|
from app.models.qa_models import (
|
||||||
|
GenImageQARequest,
|
||||||
|
GenTextQARequest,
|
||||||
|
ImageQAResponse,
|
||||||
|
TextQAResponse,
|
||||||
|
)
|
||||||
|
from app.services import qa_service
|
||||||
|
|
||||||
router = APIRouter(tags=["QA"])
|
router = APIRouter(tags=["QA"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/qa/gen-text", response_model=TextQAResponse)
|
||||||
|
async def gen_text_qa(
|
||||||
|
req: GenTextQARequest,
|
||||||
|
llm: LLMClient = Depends(get_llm_client),
|
||||||
|
) -> TextQAResponse:
|
||||||
|
return await qa_service.gen_text_qa(req, llm)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/qa/gen-image", response_model=ImageQAResponse)
|
||||||
|
async def gen_image_qa(
|
||||||
|
req: GenImageQARequest,
|
||||||
|
llm: LLMClient = Depends(get_llm_client),
|
||||||
|
storage: StorageClient = Depends(get_storage_client),
|
||||||
|
) -> ImageQAResponse:
|
||||||
|
return await qa_service.gen_image_qa(req, llm, storage)
|
||||||
|
|||||||
106
app/services/qa_service.py
Normal file
106
app/services/qa_service.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.config import get_config
|
||||||
|
from app.core.json_utils import extract_json
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.qa_models import (
|
||||||
|
GenImageQARequest,
|
||||||
|
GenTextQARequest,
|
||||||
|
ImageQAPair,
|
||||||
|
ImageQAResponse,
|
||||||
|
QAPair,
|
||||||
|
TextQAResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_TEXT_PROMPT = (
|
||||||
|
"请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
|
||||||
|
"三元组列表:\n{triples_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_DEFAULT_IMAGE_PROMPT = (
|
||||||
|
"请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
|
||||||
|
cfg = get_config()
|
||||||
|
model = req.model or cfg["models"]["default_text"]
|
||||||
|
|
||||||
|
# Format all triples + source snippets into a single batch prompt
|
||||||
|
triple_lines: list[str] = []
|
||||||
|
for item in req.items:
|
||||||
|
triple_lines.append(
|
||||||
|
f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
|
||||||
|
)
|
||||||
|
triples_text = "\n".join(triple_lines)
|
||||||
|
|
||||||
|
prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
|
||||||
|
if "{triples_text}" in prompt_template:
|
||||||
|
prompt = prompt_template.format(triples_text=triples_text)
|
||||||
|
else:
|
||||||
|
prompt = prompt_template + "\n\n" + triples_text
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
raw = await llm.chat(model, messages)
|
||||||
|
|
||||||
|
logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})
|
||||||
|
|
||||||
|
items_raw = extract_json(raw)
|
||||||
|
pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
|
||||||
|
return TextQAResponse(pairs=pairs)
|
||||||
|
|
||||||
|
|
||||||
|
async def gen_image_qa(
|
||||||
|
req: GenImageQARequest,
|
||||||
|
llm: LLMClient,
|
||||||
|
storage: StorageClient,
|
||||||
|
) -> ImageQAResponse:
|
||||||
|
cfg = get_config()
|
||||||
|
bucket = cfg["storage"]["buckets"]["source_data"]
|
||||||
|
model = req.model or cfg["models"]["default_vision"]
|
||||||
|
|
||||||
|
prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT
|
||||||
|
|
||||||
|
pairs: list[ImageQAPair] = []
|
||||||
|
|
||||||
|
for item in req.items:
|
||||||
|
# Download cropped image bytes from storage
|
||||||
|
image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)
|
||||||
|
|
||||||
|
# Base64 encode inline for multimodal message
|
||||||
|
b64 = base64.b64encode(image_bytes).decode()
|
||||||
|
image_data_url = f"data:image/jpeg;base64,{b64}"
|
||||||
|
|
||||||
|
# Build quad info text
|
||||||
|
quad_text = f"{item.subject} — {item.predicate} — {item.object}"
|
||||||
|
if item.qualifier:
|
||||||
|
quad_text += f" ({item.qualifier})"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image_url", "image_url": {"url": image_data_url}},
|
||||||
|
{"type": "text", "text": f"{prompt}\n\n{quad_text}"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
raw = await llm.chat_vision(model, messages)
|
||||||
|
logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})
|
||||||
|
|
||||||
|
items_raw = extract_json(raw)
|
||||||
|
for qa in items_raw:
|
||||||
|
pairs.append(
|
||||||
|
ImageQAPair(
|
||||||
|
question=qa["question"],
|
||||||
|
answer=qa["answer"],
|
||||||
|
image_path=item.cropped_image_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageQAResponse(pairs=pairs)
|
||||||
121
tests/test_qa_router.py
Normal file
121
tests/test_qa_router.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Tests for QA router: /api/v1/qa/gen-text and /api/v1/qa/gen-image."""
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from app.core.exceptions import LLMCallError, LLMParseError, StorageError
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_QA_JSON = json.dumps([
|
||||||
|
{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
|
||||||
|
])
|
||||||
|
|
||||||
|
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
|
||||||
|
|
||||||
|
TEXT_QA_PAYLOAD = {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"subject": "电缆接头",
|
||||||
|
"predicate": "位于",
|
||||||
|
"object": "配电箱左侧",
|
||||||
|
"source_snippet": "电缆接头位于配电箱左侧",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
IMAGE_QA_PAYLOAD = {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"subject": "电缆接头",
|
||||||
|
"predicate": "位于",
|
||||||
|
"object": "配电箱左侧",
|
||||||
|
"cropped_image_path": "crops/1/0.jpg",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/qa/gen-text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_text_qa_returns_200(client, mock_llm):
|
||||||
|
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "pairs" in data
|
||||||
|
assert len(data["pairs"]) == 1
|
||||||
|
assert data["pairs"][0]["question"] == "电缆接头位于哪里?"
|
||||||
|
assert data["pairs"][0]["answer"] == "配电箱左侧"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_text_qa_llm_parse_error_returns_502(client, mock_llm):
|
||||||
|
mock_llm.chat = AsyncMock(return_value="not valid json {{")
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert resp.json()["code"] == "LLM_PARSE_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_text_qa_llm_call_error_returns_503(client, mock_llm):
|
||||||
|
mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-text", json=TEXT_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 503
|
||||||
|
assert resp.json()["code"] == "LLM_CALL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# POST /api/v1/qa/gen-image
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_image_qa_returns_200(client, mock_llm, mock_storage):
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "pairs" in data
|
||||||
|
assert len(data["pairs"]) == 1
|
||||||
|
pair = data["pairs"][0]
|
||||||
|
assert pair["question"] == "电缆接头位于哪里?"
|
||||||
|
assert pair["answer"] == "配电箱左侧"
|
||||||
|
assert pair["image_path"] == "crops/1/0.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_image_qa_llm_parse_error_returns_502(client, mock_llm, mock_storage):
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(return_value="bad json {{")
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert resp.json()["code"] == "LLM_PARSE_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_image_qa_llm_call_error_returns_503(client, mock_llm, mock_storage):
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(side_effect=LLMCallError("GLM vision timeout"))
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 503
|
||||||
|
assert resp.json()["code"] == "LLM_CALL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gen_image_qa_storage_error_returns_502(client, mock_storage):
|
||||||
|
mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/qa/gen-image", json=IMAGE_QA_PAYLOAD)
|
||||||
|
|
||||||
|
assert resp.status_code == 502
|
||||||
|
assert resp.json()["code"] == "STORAGE_ERROR"
|
||||||
236
tests/test_qa_service.py
Normal file
236
tests/test_qa_service.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Tests for qa_service: text QA (US5) and image QA (US6)."""
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from app.core.exceptions import LLMCallError, LLMParseError, StorageError
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared fixtures / helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SAMPLE_QA_JSON = json.dumps([
|
||||||
|
{"question": "电缆接头位于哪里?", "answer": "配电箱左侧"},
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# T039 — Text QA service tests (US5)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_text_qa_prompt_contains_triples(mock_llm):
|
||||||
|
"""Triple fields and source_snippet must appear in the message sent to LLM."""
|
||||||
|
from app.models.qa_models import GenTextQARequest, TextQAItem
|
||||||
|
from app.services.qa_service import gen_text_qa
|
||||||
|
|
||||||
|
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
req = GenTextQARequest(items=[
|
||||||
|
TextQAItem(
|
||||||
|
subject="电缆接头",
|
||||||
|
predicate="位于",
|
||||||
|
object="配电箱左侧",
|
||||||
|
source_snippet="电缆接头位于配电箱左侧",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
await gen_text_qa(req, mock_llm)
|
||||||
|
|
||||||
|
assert mock_llm.chat.called
|
||||||
|
call_args = mock_llm.chat.call_args
|
||||||
|
messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
|
||||||
|
prompt_text = messages[0]["content"]
|
||||||
|
assert "电缆接头" in prompt_text
|
||||||
|
assert "位于" in prompt_text
|
||||||
|
assert "配电箱左侧" in prompt_text
|
||||||
|
assert "电缆接头位于配电箱左侧" in prompt_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_text_qa_returns_qa_pair_list(mock_llm):
|
||||||
|
"""Parsed JSON must be returned as QAPair list."""
|
||||||
|
from app.models.qa_models import GenTextQARequest, QAPair, TextQAItem
|
||||||
|
from app.services.qa_service import gen_text_qa
|
||||||
|
|
||||||
|
mock_llm.chat = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
req = GenTextQARequest(items=[
|
||||||
|
TextQAItem(
|
||||||
|
subject="电缆接头",
|
||||||
|
predicate="位于",
|
||||||
|
object="配电箱左侧",
|
||||||
|
source_snippet="电缆接头位于配电箱左侧",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
result = await gen_text_qa(req, mock_llm)
|
||||||
|
|
||||||
|
assert len(result.pairs) == 1
|
||||||
|
pair = result.pairs[0]
|
||||||
|
assert isinstance(pair, QAPair)
|
||||||
|
assert pair.question == "电缆接头位于哪里?"
|
||||||
|
assert pair.answer == "配电箱左侧"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_text_qa_llm_parse_error_on_malformed_response(mock_llm):
|
||||||
|
"""LLMParseError must be raised when LLM returns non-JSON."""
|
||||||
|
from app.models.qa_models import GenTextQARequest, TextQAItem
|
||||||
|
from app.services.qa_service import gen_text_qa
|
||||||
|
|
||||||
|
mock_llm.chat = AsyncMock(return_value="this is not json {{")
|
||||||
|
|
||||||
|
req = GenTextQARequest(items=[
|
||||||
|
TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
|
||||||
|
])
|
||||||
|
|
||||||
|
with pytest.raises(LLMParseError):
|
||||||
|
await gen_text_qa(req, mock_llm)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_text_qa_llm_call_error_propagates(mock_llm):
|
||||||
|
"""LLMCallError from LLM client must propagate unchanged."""
|
||||||
|
from app.models.qa_models import GenTextQARequest, TextQAItem
|
||||||
|
from app.services.qa_service import gen_text_qa
|
||||||
|
|
||||||
|
mock_llm.chat = AsyncMock(side_effect=LLMCallError("GLM timeout"))
|
||||||
|
|
||||||
|
req = GenTextQARequest(items=[
|
||||||
|
TextQAItem(subject="s", predicate="p", object="o", source_snippet="snip")
|
||||||
|
])
|
||||||
|
|
||||||
|
with pytest.raises(LLMCallError):
|
||||||
|
await gen_text_qa(req, mock_llm)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# T040 — Image QA service tests (US6)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
FAKE_IMAGE_BYTES = b"\xff\xd8\xff\xe0fake_jpeg_content"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_image_qa_downloads_image_and_encodes_base64(mock_llm, mock_storage):
|
||||||
|
"""Storage.download_bytes must be called, result base64-encoded in LLM message."""
|
||||||
|
from app.models.qa_models import GenImageQARequest, ImageQAItem
|
||||||
|
from app.services.qa_service import gen_image_qa
|
||||||
|
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
req = GenImageQARequest(items=[
|
||||||
|
ImageQAItem(
|
||||||
|
subject="电缆接头",
|
||||||
|
predicate="位于",
|
||||||
|
object="配电箱左侧",
|
||||||
|
cropped_image_path="crops/1/0.jpg",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
await gen_image_qa(req, mock_llm, mock_storage)
|
||||||
|
|
||||||
|
# Storage download must have been called with the correct path
|
||||||
|
mock_storage.download_bytes.assert_called_once()
|
||||||
|
call_args = mock_storage.download_bytes.call_args
|
||||||
|
path_arg = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs.get("path", call_args.kwargs.get("key"))
|
||||||
|
assert path_arg == "crops/1/0.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_image_qa_multimodal_message_format(mock_llm, mock_storage):
|
||||||
|
"""Multimodal message must contain inline base64 image_url and text."""
|
||||||
|
from app.models.qa_models import GenImageQARequest, ImageQAItem
|
||||||
|
from app.services.qa_service import gen_image_qa
|
||||||
|
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
req = GenImageQARequest(items=[
|
||||||
|
ImageQAItem(
|
||||||
|
subject="电缆接头",
|
||||||
|
predicate="位于",
|
||||||
|
object="配电箱左侧",
|
||||||
|
qualifier="2024检修",
|
||||||
|
cropped_image_path="crops/1/0.jpg",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
await gen_image_qa(req, mock_llm, mock_storage)
|
||||||
|
|
||||||
|
assert mock_llm.chat_vision.called
|
||||||
|
call_args = mock_llm.chat_vision.call_args
|
||||||
|
messages = call_args.args[1] if call_args.args else call_args.kwargs["messages"]
|
||||||
|
|
||||||
|
# Find the content list in messages
|
||||||
|
content = messages[0]["content"]
|
||||||
|
assert isinstance(content, list)
|
||||||
|
|
||||||
|
# Must have an image_url part with inline base64 data URI
|
||||||
|
image_parts = [p for p in content if p.get("type") == "image_url"]
|
||||||
|
assert len(image_parts) >= 1
|
||||||
|
url = image_parts[0]["image_url"]["url"]
|
||||||
|
expected_b64 = base64.b64encode(FAKE_IMAGE_BYTES).decode()
|
||||||
|
assert url == f"data:image/jpeg;base64,{expected_b64}"
|
||||||
|
|
||||||
|
# Must have a text part containing quad info
|
||||||
|
text_parts = [p for p in content if p.get("type") == "text"]
|
||||||
|
assert len(text_parts) >= 1
|
||||||
|
text = text_parts[0]["text"]
|
||||||
|
assert "电缆接头" in text
|
||||||
|
assert "位于" in text
|
||||||
|
assert "配电箱左侧" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_image_qa_returns_image_qa_pair_with_image_path(mock_llm, mock_storage):
|
||||||
|
"""Result ImageQAPair must include image_path from the item."""
|
||||||
|
from app.models.qa_models import GenImageQARequest, ImageQAItem, ImageQAPair
|
||||||
|
from app.services.qa_service import gen_image_qa
|
||||||
|
|
||||||
|
mock_storage.download_bytes = AsyncMock(return_value=FAKE_IMAGE_BYTES)
|
||||||
|
mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QA_JSON)
|
||||||
|
|
||||||
|
req = GenImageQARequest(items=[
|
||||||
|
ImageQAItem(
|
||||||
|
subject="电缆接头",
|
||||||
|
predicate="位于",
|
||||||
|
object="配电箱左侧",
|
||||||
|
cropped_image_path="crops/1/0.jpg",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
result = await gen_image_qa(req, mock_llm, mock_storage)
|
||||||
|
|
||||||
|
assert len(result.pairs) == 1
|
||||||
|
pair = result.pairs[0]
|
||||||
|
assert isinstance(pair, ImageQAPair)
|
||||||
|
assert pair.question == "电缆接头位于哪里?"
|
||||||
|
assert pair.answer == "配电箱左侧"
|
||||||
|
assert pair.image_path == "crops/1/0.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gen_image_qa_storage_error_propagates(mock_llm, mock_storage):
|
||||||
|
"""StorageError from download must propagate unchanged."""
|
||||||
|
from app.models.qa_models import GenImageQARequest, ImageQAItem
|
||||||
|
from app.services.qa_service import gen_image_qa
|
||||||
|
|
||||||
|
mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
|
||||||
|
|
||||||
|
req = GenImageQARequest(items=[
|
||||||
|
ImageQAItem(
|
||||||
|
subject="s",
|
||||||
|
predicate="p",
|
||||||
|
object="o",
|
||||||
|
cropped_image_path="crops/1/0.jpg",
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
with pytest.raises(StorageError):
|
||||||
|
await gen_image_qa(req, mock_llm, mock_storage)
|
||||||
Reference in New Issue
Block a user