Files
label_ai_service/app/services/text_service.py
wh dd8da386f4 feat(US1): text triple extraction — POST /api/v1/text/extract
- app/models/text_models.py: TripleItem, SourceOffset, TextExtract{Request,Response}
- app/services/text_service.py: TXT/PDF/DOCX parsing + LLM call + JSON parse
- app/routers/text.py: POST /text/extract handler with Depends injection
- tests/test_text_service.py: 6 unit tests (formats, errors)
- tests/test_text_router.py: 4 router tests (200, 400, 502×2)
- 10/10 tests passing
2026-04-10 15:27:27 +08:00

96 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import pdfplumber
import docx
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import UnsupportedFileTypeError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.text_models import (
SourceOffset,
TextExtractRequest,
TextExtractResponse,
TripleItem,
)
logger = get_logger(__name__)

# Extensions accepted by extract_triples; the parser selection there must
# stay in agreement with this set.
_SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".docx"}

# Default extraction prompt (Chinese). "{text}" is substituted via
# str.format, so the literal braces in the offset example are doubled
# ("{{start, end}}" renders as "{start, end}").
_DEFAULT_PROMPT = (
    "请从以下文本中提取知识三元组,以 JSON 数组格式返回,每条包含字段:"
    "subject主语、predicate谓语、object宾语、"
    "source_snippet原文证据片段、source_offset({{start, end}} 字符偏移)。\n\n"
    "文本内容:\n{text}"
)
def _file_extension(file_name: str) -> str:
idx = file_name.rfind(".")
return file_name[idx:].lower() if idx != -1 else ""
def _parse_txt(data: bytes) -> str:
return data.decode("utf-8", errors="replace")
def _parse_pdf(data: bytes) -> str:
    """Extract the text of every PDF page, joined with newlines.

    Pages from which pdfplumber extracts no text contribute an empty string.
    """
    page_texts = []
    with pdfplumber.open(io.BytesIO(data)) as document:
        for page in document.pages:
            page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
def _parse_docx(data: bytes) -> str:
    """Concatenate the text of all paragraphs in a DOCX document."""
    document = docx.Document(io.BytesIO(data))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(paragraph_texts)
async def extract_triples(
    req: TextExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> TextExtractResponse:
    """Extract knowledge triples from a stored document via an LLM.

    Downloads the file from object storage, converts it to plain text
    according to its extension, prompts the LLM, and parses the JSON reply
    into ``TripleItem`` models.

    Args:
        req: Request carrying ``file_name``, ``file_path`` and optional
            ``model`` / ``prompt_template`` overrides.
        llm: Chat-completion client used for the extraction call.
        storage: Object-storage client used to fetch the source file.

    Returns:
        TextExtractResponse with one TripleItem per extracted triple.

    Raises:
        UnsupportedFileTypeError: if the file extension is not supported.
        KeyError: if an item in the LLM reply lacks a required field.
    """
    ext = _file_extension(req.file_name)
    if ext not in _SUPPORTED_EXTENSIONS:
        raise UnsupportedFileTypeError(f"不支持的文件格式: {ext}")

    # Dispatch table instead of an if/elif chain so the extension-to-parser
    # mapping lives in one place; keys mirror _SUPPORTED_EXTENSIONS.
    parsers = {".txt": _parse_txt, ".pdf": _parse_pdf, ".docx": _parse_docx}

    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    # Per-request model override falls back to the configured default.
    model = req.model or cfg["models"]["default_text"]

    data = await storage.download_bytes(bucket, req.file_path)
    text = parsers[ext](data)

    # A custom template may embed the document via a "{text}" placeholder;
    # templates without the placeholder get the text appended instead.
    prompt_template = req.prompt_template or _DEFAULT_PROMPT
    if "{text}" in prompt_template:
        prompt = prompt_template.format(text=text)
    else:
        prompt = prompt_template + "\n\n" + text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)
    logger.info("text_extract", extra={"file": req.file_name, "model": model})

    # extract_json presumably pulls the JSON array out of the raw reply
    # (see app.core.json_utils); each element must be a dict with the
    # fields below — a missing field surfaces as KeyError to the caller.
    items = [
        TripleItem(
            subject=item["subject"],
            predicate=item["predicate"],
            object=item["object"],
            source_snippet=item["source_snippet"],
            source_offset=SourceOffset(
                start=item["source_offset"]["start"],
                end=item["source_offset"]["end"],
            ),
        )
        for item in extract_json(raw)
    ]
    return TextExtractResponse(items=items)