feat(US5+6): QA generation — POST /api/v1/qa/gen-text and /gen-image
- Add qa_models.py with TextQAItem, GenTextQARequest, QAPair, ImageQAItem, GenImageQARequest, ImageQAPair, TextQAResponse, ImageQAResponse - Implement gen_text_qa(): batch-formats triples into a single prompt, calls llm.chat(), parses JSON array via extract_json() - Implement gen_image_qa(): downloads cropped image from source-data bucket, base64-encodes inline (data URI), builds multimodal message, calls llm.chat_vision(), parses JSON; image_path preserved on ImageQAPair - Replace qa.py stub with full router: POST /qa/gen-text and /qa/gen-image using Depends(get_llm_client) and Depends(get_storage_client) - 15 new tests (8 service + 7 router), 53/53 total passing
This commit is contained in:
47
app/models/qa_models.py
Normal file
47
app/models/qa_models.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TextQAItem(BaseModel):
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
source_snippet: str
|
||||
|
||||
|
||||
class GenTextQARequest(BaseModel):
|
||||
items: list[TextQAItem]
|
||||
model: str | None = None
|
||||
prompt_template: str | None = None
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class ImageQAItem(BaseModel):
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
qualifier: str | None = None
|
||||
cropped_image_path: str
|
||||
|
||||
|
||||
class GenImageQARequest(BaseModel):
|
||||
items: list[ImageQAItem]
|
||||
model: str | None = None
|
||||
prompt_template: str | None = None
|
||||
|
||||
|
||||
class ImageQAPair(BaseModel):
|
||||
question: str
|
||||
answer: str
|
||||
image_path: str
|
||||
|
||||
|
||||
class TextQAResponse(BaseModel):
|
||||
pairs: list[QAPair]
|
||||
|
||||
|
||||
class ImageQAResponse(BaseModel):
|
||||
pairs: list[ImageQAPair]
|
||||
@@ -1,3 +1,31 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.clients.storage.base import StorageClient
|
||||
from app.core.dependencies import get_llm_client, get_storage_client
|
||||
from app.models.qa_models import (
|
||||
GenImageQARequest,
|
||||
GenTextQARequest,
|
||||
ImageQAResponse,
|
||||
TextQAResponse,
|
||||
)
|
||||
from app.services import qa_service
|
||||
|
||||
router = APIRouter(tags=["QA"])
|
||||
|
||||
|
||||
@router.post("/qa/gen-text", response_model=TextQAResponse)
|
||||
async def gen_text_qa(
|
||||
req: GenTextQARequest,
|
||||
llm: LLMClient = Depends(get_llm_client),
|
||||
) -> TextQAResponse:
|
||||
return await qa_service.gen_text_qa(req, llm)
|
||||
|
||||
|
||||
@router.post("/qa/gen-image", response_model=ImageQAResponse)
|
||||
async def gen_image_qa(
|
||||
req: GenImageQARequest,
|
||||
llm: LLMClient = Depends(get_llm_client),
|
||||
storage: StorageClient = Depends(get_storage_client),
|
||||
) -> ImageQAResponse:
|
||||
return await qa_service.gen_image_qa(req, llm, storage)
|
||||
|
||||
106
app/services/qa_service.py
Normal file
106
app/services/qa_service.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import base64
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.clients.storage.base import StorageClient
|
||||
from app.core.config import get_config
|
||||
from app.core.json_utils import extract_json
|
||||
from app.core.logging import get_logger
|
||||
from app.models.qa_models import (
|
||||
GenImageQARequest,
|
||||
GenTextQARequest,
|
||||
ImageQAPair,
|
||||
ImageQAResponse,
|
||||
QAPair,
|
||||
TextQAResponse,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_DEFAULT_TEXT_PROMPT = (
|
||||
"请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
|
||||
"三元组列表:\n{triples_text}"
|
||||
)
|
||||
|
||||
_DEFAULT_IMAGE_PROMPT = (
|
||||
"请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
|
||||
)
|
||||
|
||||
|
||||
async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
|
||||
cfg = get_config()
|
||||
model = req.model or cfg["models"]["default_text"]
|
||||
|
||||
# Format all triples + source snippets into a single batch prompt
|
||||
triple_lines: list[str] = []
|
||||
for item in req.items:
|
||||
triple_lines.append(
|
||||
f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
|
||||
)
|
||||
triples_text = "\n".join(triple_lines)
|
||||
|
||||
prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
|
||||
if "{triples_text}" in prompt_template:
|
||||
prompt = prompt_template.format(triples_text=triples_text)
|
||||
else:
|
||||
prompt = prompt_template + "\n\n" + triples_text
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
raw = await llm.chat(model, messages)
|
||||
|
||||
logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})
|
||||
|
||||
items_raw = extract_json(raw)
|
||||
pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
|
||||
return TextQAResponse(pairs=pairs)
|
||||
|
||||
|
||||
async def gen_image_qa(
|
||||
req: GenImageQARequest,
|
||||
llm: LLMClient,
|
||||
storage: StorageClient,
|
||||
) -> ImageQAResponse:
|
||||
cfg = get_config()
|
||||
bucket = cfg["storage"]["buckets"]["source_data"]
|
||||
model = req.model or cfg["models"]["default_vision"]
|
||||
|
||||
prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT
|
||||
|
||||
pairs: list[ImageQAPair] = []
|
||||
|
||||
for item in req.items:
|
||||
# Download cropped image bytes from storage
|
||||
image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)
|
||||
|
||||
# Base64 encode inline for multimodal message
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
image_data_url = f"data:image/jpeg;base64,{b64}"
|
||||
|
||||
# Build quad info text
|
||||
quad_text = f"{item.subject} — {item.predicate} — {item.object}"
|
||||
if item.qualifier:
|
||||
quad_text += f" ({item.qualifier})"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_data_url}},
|
||||
{"type": "text", "text": f"{prompt}\n\n{quad_text}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
raw = await llm.chat_vision(model, messages)
|
||||
logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})
|
||||
|
||||
items_raw = extract_json(raw)
|
||||
for qa in items_raw:
|
||||
pairs.append(
|
||||
ImageQAPair(
|
||||
question=qa["question"],
|
||||
answer=qa["answer"],
|
||||
image_path=item.cropped_image_path,
|
||||
)
|
||||
)
|
||||
|
||||
return ImageQAResponse(pairs=pairs)
|
||||
Reference in New Issue
Block a user