96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
|
|
import asyncio
import io

import docx
import pdfplumber

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import UnsupportedFileTypeError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.text_models import (
    SourceOffset,
    TextExtractRequest,
    TextExtractResponse,
    TripleItem,
)
|
|||
|
|
|
|||
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Extensions (lower-case, leading dot included) that extract_triples can parse.
_SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".txt"}

# Default extraction prompt (Chinese). In English it reads roughly:
# "Extract knowledge triples from the following text and return them as a JSON
# array; each entry has subject, predicate, object, source_snippet (the
# supporting excerpt) and source_offset ({start, end} character offsets).
# Text: {text}"
# The template is rendered with str.format(text=...), which is why the literal
# braces around "start, end" are escaped as "{{...}}".
_DEFAULT_PROMPT = "".join(
    (
        "请从以下文本中提取知识三元组,以 JSON 数组格式返回,每条包含字段:",
        "subject(主语)、predicate(谓语)、object(宾语)、",
        "source_snippet(原文证据片段)、source_offset({{start, end}} 字符偏移)。\n\n",
        "文本内容:\n{text}",
    )
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _file_extension(file_name: str) -> str:
|
|||
|
|
idx = file_name.rfind(".")
|
|||
|
|
return file_name[idx:].lower() if idx != -1 else ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_txt(data: bytes) -> str:
|
|||
|
|
return data.decode("utf-8", errors="replace")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_pdf(data: bytes) -> str:
    """Return the text of every page of an in-memory PDF, pages joined by newlines.

    Pages for which pdfplumber extracts nothing (``None``) contribute an empty
    string, so each page still occupies its own newline-separated slot.
    """
    buffer = io.BytesIO(data)
    with pdfplumber.open(buffer) as document:
        page_texts = []
        for page in document.pages:
            extracted = page.extract_text()
            page_texts.append(extracted if extracted else "")
    return "\n".join(page_texts)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_docx(data: bytes) -> str:
    """Return the text of an in-memory .docx document, one paragraph per line.

    Only ``Document.paragraphs`` is read. NOTE(review): this appears to cover
    top-level body paragraphs only, so text in tables, headers and footers is
    likely not included — confirm if that matters for extraction quality.
    """
    document = docx.Document(io.BytesIO(data))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(paragraph_texts)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _fill_prompt(template: str, text: str) -> str:
    """Insert the document *text* into a prompt *template*.

    Templates containing ``{text}`` are rendered with ``str.format`` so that
    ``{{``/``}}`` escapes (as used by the module's default prompt) keep
    working. A caller-supplied template may also contain other literal braces,
    such as a JSON example, which ``str.format`` rejects. In that case the
    placeholder is substituted literally instead.

    Templates without ``{text}`` get the text appended after a blank line.
    """
    if "{text}" not in template:
        return template + "\n\n" + text
    try:
        return template.format(text=text)
    except (KeyError, IndexError, ValueError, AttributeError, TypeError):
        # Stray fields like '{"subject": ...}', '{0}' or a lone '{' are not
        # format fields the author intended; replace only the placeholder.
        return template.replace("{text}", text)


def _to_triple(item) -> "TripleItem | None":
    """Build a TripleItem from one LLM-produced mapping, or return None if malformed.

    An item is malformed when it is not subscriptable like a dict or lacks any of
    subject / predicate / object / source_snippet / source_offset.start / .end.
    """
    try:
        offset = item["source_offset"]
        return TripleItem(
            subject=item["subject"],
            predicate=item["predicate"],
            object=item["object"],
            source_snippet=item["source_snippet"],
            source_offset=SourceOffset(start=offset["start"], end=offset["end"]),
        )
    except (KeyError, TypeError, ValueError):
        # ValueError is meant to cover model validation failures — assumes
        # TripleItem/SourceOffset raise a ValueError subclass (true for
        # pydantic models); TODO confirm against app.models.text_models.
        return None


async def extract_triples(
    req: TextExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> TextExtractResponse:
    """Extract knowledge triples from an uploaded document using an LLM.

    The file at ``req.file_path`` is downloaded from the configured
    ``source_data`` bucket and converted to plain text. The parser is chosen by
    the extension of ``req.file_name`` (.txt / .pdf / .docx). The text is then
    embedded in ``req.prompt_template`` (or the module default prompt) and sent
    to ``req.model`` (or the configured ``default_text`` model). Finally the
    JSON in the reply is converted to TripleItem objects.

    Args:
        req: Extraction request (file name and path, optional model and
            prompt template).
        llm: Chat client used to run the extraction prompt.
        storage: Storage client the source file is downloaded from.

    Returns:
        A TextExtractResponse with the parsed triples. It is empty, and the LLM
        is not called, when the document yields no non-whitespace text. Reply
        entries missing a required field are skipped and counted in a warning
        log instead of failing the whole request.

    Raises:
        UnsupportedFileTypeError: if the extension is not .txt, .pdf or .docx.
    """
    ext = _file_extension(req.file_name)
    if ext not in _SUPPORTED_EXTENSIONS:
        raise UnsupportedFileTypeError(f"不支持的文件格式: {ext}")

    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_text"]

    data = await storage.download_bytes(bucket, req.file_path)

    if ext == ".txt":
        parser = _parse_txt
    elif ext == ".pdf":
        parser = _parse_pdf
    else:
        parser = _parse_docx
    # The parsers are synchronous (pdfplumber / python-docx do their work in
    # the calling thread). Run them in the default executor so a large document
    # does not block other coroutines on the event loop.
    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(None, parser, data)

    if not text.strip():
        # Nothing to extract from (e.g. a PDF whose pages yield no text).
        # Skip the LLM call rather than let the model invent triples.
        logger.warning("text_extract_empty", extra={"file": req.file_name})
        return TextExtractResponse(items=[])

    prompt = _fill_prompt(req.prompt_template or _DEFAULT_PROMPT, text)
    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    logger.info("text_extract", extra={"file": req.file_name, "model": model})

    items = []
    dropped = 0
    for item in extract_json(raw):
        triple = _to_triple(item)
        if triple is None:
            dropped += 1
        else:
            items.append(triple)
    if dropped:
        logger.warning(
            "text_extract_dropped_items",
            extra={"file": req.file_name, "model": model, "dropped": dropped},
        )

    return TextExtractResponse(items=items)
|