119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
import base64
|
|
|
|
from app.clients.llm.base import LLMClient
|
|
from app.clients.storage.base import StorageClient
|
|
from app.core.config import get_config
|
|
from app.core.exceptions import LLMParseError
|
|
from app.core.json_utils import extract_json
|
|
from app.core.logging import get_logger
|
|
from app.models.qa_models import (
|
|
GenImageQARequest,
|
|
GenTextQARequest,
|
|
ImageQAPair,
|
|
ImageQAResponse,
|
|
QAPair,
|
|
TextQAResponse,
|
|
)
|
|
|
|
logger = get_logger(__name__)


# Default prompt for text QA generation: asks the model to return a JSON
# array of objects with "question" and "answer" fields, built from knowledge
# triples. The {triples_text} placeholder is filled with the formatted
# triple list in gen_text_qa.
_DEFAULT_TEXT_PROMPT = (
    "请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
    "三元组列表:\n{triples_text}"
)

# Default prompt for image QA generation: asks the vision model to return a
# JSON array of question/answer objects based on the image plus quad info.
# Quad text is appended to this prompt in gen_image_qa (no placeholder).
_DEFAULT_IMAGE_PROMPT = (
    "请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
)
|
|
|
|
|
|
async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
    """Generate question-answer pairs from knowledge triples via an LLM.

    Formats every triple in ``req.items`` (with its source snippet) into a
    single batched prompt, sends one chat request, and parses the model's
    JSON-array reply into ``QAPair`` objects.

    Args:
        req: Request holding the triples, optional model override, and an
            optional custom prompt template.
        llm: Chat-capable LLM client.

    Returns:
        TextQAResponse wrapping the parsed QA pairs.

    Raises:
        LLMParseError: If the model output is not a JSON array of objects
            containing "question" and "answer" keys.
    """
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]

    # Format all triples + source snippets into a single batch prompt.
    triples_text = "\n".join(
        f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
        for item in req.items
    )

    prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
    if "{triples_text}" in prompt_template:
        prompt = prompt_template.format(triples_text=triples_text)
    else:
        # Custom templates without the placeholder get the triples appended.
        prompt = prompt_template + "\n\n" + triples_text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    items_raw = extract_json(raw)
    logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})

    if not isinstance(items_raw, list):
        raise LLMParseError("大模型返回的问答对格式不正确")
    try:
        pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
    except (KeyError, TypeError) as exc:
        # Chain the original error so the malformed payload stays traceable.
        raise LLMParseError("大模型返回的问答对格式不正确") from exc
    return TextQAResponse(pairs=pairs)
|
|
|
|
|
|
async def gen_image_qa(
    req: GenImageQARequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageQAResponse:
    """Generate QA pairs for cropped images described by knowledge quads.

    For each item: downloads the cropped image from storage, sends it inline
    (as a base64 data URL) together with the quad text to a vision model, and
    parses the JSON-array reply into ``ImageQAPair`` objects tagged with the
    image path.

    Args:
        req: Request holding quad items, optional model override, and an
            optional custom prompt template.
        llm: Vision-capable LLM client.
        storage: Storage client used to fetch cropped image bytes.

    Returns:
        ImageQAResponse wrapping all parsed QA pairs across items.

    Raises:
        LLMParseError: If any model response is not a JSON array of objects
            containing "question" and "answer" keys.
    """
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT

    pairs: list[ImageQAPair] = []

    for item in req.items:
        # Download cropped image bytes from storage.
        image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)

        # Base64-encode inline for the multimodal message payload.
        # NOTE(review): assumes crops are JPEG — confirm against the cropper.
        b64 = base64.b64encode(image_bytes).decode()
        image_data_url = f"data:image/jpeg;base64,{b64}"

        # Build quad info text: subject — predicate — object (qualifier).
        quad_text = f"{item.subject} — {item.predicate} — {item.object}"
        if item.qualifier:
            quad_text += f" ({item.qualifier})"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": f"{prompt}\n\n{quad_text}"},
                ],
            }
        ]

        raw = await llm.chat_vision(model, messages)

        items_raw = extract_json(raw)
        logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})

        if not isinstance(items_raw, list):
            raise LLMParseError("大模型返回的问答对格式不正确")
        try:
            # Build this item's pairs fully before extending the accumulator,
            # so `pairs` never holds a partial batch when a later entry in the
            # same response is malformed.
            item_pairs = [
                ImageQAPair(
                    question=qa["question"],
                    answer=qa["answer"],
                    image_path=item.cropped_image_path,
                )
                for qa in items_raw
            ]
        except (KeyError, TypeError) as exc:
            # Chain the original error so the malformed payload stays traceable.
            raise LLMParseError("大模型返回的问答对格式不正确") from exc
        pairs.extend(item_pairs)

    return ImageQAResponse(pairs=pairs)
|