diff --git a/app/services/qa_service.py b/app/services/qa_service.py index a8136b7..fe23c80 100644 --- a/app/services/qa_service.py +++ b/app/services/qa_service.py @@ -3,6 +3,7 @@ import base64 from app.clients.llm.base import LLMClient from app.clients.storage.base import StorageClient from app.core.config import get_config +from app.core.exceptions import LLMParseError from app.core.json_utils import extract_json from app.core.logging import get_logger from app.models.qa_models import ( @@ -47,10 +48,15 @@ async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse: messages = [{"role": "user", "content": prompt}] raw = await llm.chat(model, messages) + items_raw = extract_json(raw) logger.info("gen_text_qa", extra={"items": len(req.items), "model": model}) - items_raw = extract_json(raw) - pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw] + if not isinstance(items_raw, list): + raise LLMParseError("大模型返回的问答对格式不正确") + try: + pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw] + except (KeyError, TypeError): + raise LLMParseError("大模型返回的问答对格式不正确") return TextQAResponse(pairs=pairs) @@ -91,16 +97,22 @@ async def gen_image_qa( ] raw = await llm.chat_vision(model, messages) - logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model}) items_raw = extract_json(raw) - for qa in items_raw: - pairs.append( - ImageQAPair( - question=qa["question"], - answer=qa["answer"], - image_path=item.cropped_image_path, + logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model}) + + if not isinstance(items_raw, list): + raise LLMParseError("大模型返回的问答对格式不正确") + try: + for qa in items_raw: + pairs.append( + ImageQAPair( + question=qa["question"], + answer=qa["answer"], + image_path=item.cropped_image_path, + ) ) - ) + except (KeyError, TypeError): + raise LLMParseError("大模型返回的问答对格式不正确") return ImageQAResponse(pairs=pairs)