feat(US1): text triple extraction — POST /api/v1/text/extract
- app/models/text_models.py: TripleItem, SourceOffset, TextExtract{Request,Response}
- app/services/text_service.py: TXT/PDF/DOCX parsing + LLM call + JSON parse
- app/routers/text.py: POST /text/extract handler with Depends injection
- tests/test_text_service.py: 6 unit tests (formats, errors)
- tests/test_text_router.py: 4 router tests (200, 400, 502×2)
- 10/10 tests passing
This commit is contained in:
BIN
app/__pycache__/main.cpython-312.pyc
Normal file
BIN
app/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/models/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/models/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/models/__pycache__/text_models.cpython-312.pyc
Normal file
BIN
app/models/__pycache__/text_models.cpython-312.pyc
Normal file
Binary file not shown.
25
app/models/text_models.py
Normal file
25
app/models/text_models.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# Character-offset span that locates a triple's evidence inside the source text.
# The default extraction prompt asks the LLM for this as {start, end}.
# NOTE(review): whether `end` is inclusive or exclusive is not established here.
class SourceOffset(BaseModel):
    # Offset at which the evidence snippet begins.
    start: int
    # Offset at which the evidence snippet ends.
    end: int
|
||||||
|
|
||||||
|
|
||||||
|
# One extracted knowledge triple together with the evidence backing it.
# Field names mirror the JSON keys the LLM is asked to return, which is why
# the third field is called `object` even though it shadows the builtin name
# inside the class namespace.
class TripleItem(BaseModel):
    # Subject of the triple.
    subject: str
    # Predicate (relation) of the triple.
    predicate: str
    # Object (value / target entity) of the triple.
    object: str
    # Excerpt of the source text quoted as evidence.
    source_snippet: str
    # Position of the evidence inside the source text.
    source_offset: SourceOffset
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /text/extract.
class TextExtractRequest(BaseModel):
    # Object key passed to the storage client when downloading the document.
    file_path: str
    # File name; the text service derives the document format from its extension.
    file_name: str
    # Optional LLM model name; the service falls back to the configured default.
    model: str | None = None
    # Optional prompt override; the service substitutes "{text}" when present,
    # otherwise it appends the document text after the template.
    prompt_template: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Response body for POST /text/extract.
class TextExtractResponse(BaseModel):
    # Triples extracted from the document, in the order the service built them.
    items: list[TripleItem]
|
||||||
BIN
app/routers/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/routers/__pycache__/finetune.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/finetune.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/routers/__pycache__/image.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/image.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/routers/__pycache__/qa.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/qa.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/routers/__pycache__/text.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/text.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/routers/__pycache__/video.cpython-312.pyc
Normal file
BIN
app/routers/__pycache__/video.cpython-312.pyc
Normal file
Binary file not shown.
@@ -1,3 +1,18 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.dependencies import get_llm_client, get_storage_client
|
||||||
|
from app.models.text_models import TextExtractRequest, TextExtractResponse
|
||||||
|
from app.services import text_service
|
||||||
|
|
||||||
router = APIRouter(tags=["Text"])
|
router = APIRouter(tags=["Text"])
|
||||||
|
|
||||||
|
|
||||||
|
# POST /text/extract: thin HTTP adapter around text_service.extract_triples.
# The LLM and storage clients are injected through FastAPI dependencies so
# tests can swap them out; all extraction logic lives in the service layer.
@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> TextExtractResponse:
    # Exceptions raised by the service are deliberately not caught here.
    extraction = await text_service.extract_triples(req, llm, storage)
    return extraction
|
||||||
|
|||||||
BIN
app/services/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/services/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/services/__pycache__/text_service.cpython-312.pyc
Normal file
BIN
app/services/__pycache__/text_service.cpython-312.pyc
Normal file
Binary file not shown.
95
app/services/text_service.py
Normal file
95
app/services/text_service.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import io
|
||||||
|
|
||||||
|
import pdfplumber
|
||||||
|
import docx
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.config import get_config
|
||||||
|
from app.core.exceptions import UnsupportedFileTypeError
|
||||||
|
from app.core.json_utils import extract_json
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.text_models import (
|
||||||
|
SourceOffset,
|
||||||
|
TextExtractRequest,
|
||||||
|
TextExtractResponse,
|
||||||
|
TripleItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
_SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".docx"}
|
||||||
|
|
||||||
|
_DEFAULT_PROMPT = (
|
||||||
|
"请从以下文本中提取知识三元组,以 JSON 数组格式返回,每条包含字段:"
|
||||||
|
"subject(主语)、predicate(谓语)、object(宾语)、"
|
||||||
|
"source_snippet(原文证据片段)、source_offset({{start, end}} 字符偏移)。\n\n"
|
||||||
|
"文本内容:\n{text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _file_extension(file_name: str) -> str:
|
||||||
|
idx = file_name.rfind(".")
|
||||||
|
return file_name[idx:].lower() if idx != -1 else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_txt(data: bytes) -> str:
|
||||||
|
return data.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pdf(data: bytes) -> str:
    """Extract the text of every page of a PDF and join the pages with newlines."""
    stream = io.BytesIO(data)
    page_texts = []
    with pdfplumber.open(stream) as document:
        for page in document.pages:
            # A falsy extraction result is normalised to "" so the join
            # below never sees a non-string.
            page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_docx(data: bytes) -> str:
    """Return the paragraph texts of a DOCX document joined by newlines.

    Only ``doc.paragraphs`` is read; nothing else from the document is used.
    """
    document = docx.Document(io.BytesIO(data))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(paragraph_texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_prompt(template: str, text: str) -> str:
    """Insert the document *text* into *template* and return the final prompt.

    A template without a "{text}" placeholder gets the text appended after a
    blank line. Otherwise ``str.format`` is tried first, which keeps "{{"/"}}"
    escapes working as they do in the default prompt. If the template contains
    other literal braces (for example a JSON sample) ``str.format`` fails, and
    the placeholder is then substituted verbatim instead of crashing.
    """
    if "{text}" not in template:
        return template + "\n\n" + text
    try:
        return template.format(text=text)
    except (KeyError, IndexError, ValueError, AttributeError):
        # The template is not a valid format string; only "{text}" is replaced.
        return template.replace("{text}", text)


def _build_triples(items_raw: object) -> list[TripleItem]:
    """Convert decoded LLM JSON into TripleItem models.

    Raises:
        LLMParseError: the payload is not an array, or an entry lacks a
            required field or has a value of the wrong type.
    """
    # Function-scope import keeps this block self-contained.
    # NOTE(review): assumes LLMParseError takes a single message argument,
    # like the other exceptions from app.core.exceptions used in this project.
    from app.core.exceptions import LLMParseError

    if not isinstance(items_raw, (list, tuple)):
        raise LLMParseError(
            f"LLM 输出不是 JSON 数组: {type(items_raw).__name__}"
        )

    triples = []
    for index, item in enumerate(items_raw):
        try:
            offset = item["source_offset"]
            triples.append(
                TripleItem(
                    subject=item["subject"],
                    predicate=item["predicate"],
                    object=item["object"],
                    source_snippet=item["source_snippet"],
                    source_offset=SourceOffset(
                        start=offset["start"],
                        end=offset["end"],
                    ),
                )
            )
        except (KeyError, TypeError, ValueError) as exc:
            # pydantic validation errors are ValueError subclasses, so field
            # type problems are covered as well as missing keys.
            raise LLMParseError(
                f"LLM 输出的第 {index} 条三元组格式无效: {exc}"
            ) from exc
    return triples


async def extract_triples(
    req: TextExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> TextExtractResponse:
    """Download a document, ask the LLM for knowledge triples and parse them.

    Args:
        req: Extraction request (storage key, file name, optional model and
            prompt template).
        llm: Chat client used to run the extraction prompt.
        storage: Client used to download the document bytes.

    Returns:
        The triples parsed from the LLM answer.

    Raises:
        UnsupportedFileTypeError: the file extension is not .txt/.pdf/.docx.
            Raised before any download happens.
        LLMParseError: the LLM answer does not have the expected triple
            structure.

    Exceptions raised by the storage client, the LLM client or
    ``extract_json`` are not caught here.
    """
    ext = _file_extension(req.file_name)
    if ext not in _SUPPORTED_EXTENSIONS:
        raise UnsupportedFileTypeError(f"不支持的文件格式: {ext}")

    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_text"]

    data = await storage.download_bytes(bucket, req.file_path)

    if ext == ".txt":
        text = _parse_txt(data)
    elif ext == ".pdf":
        text = _parse_pdf(data)
    else:
        # Only ".docx" can remain after the supported-extension check above.
        text = _parse_docx(data)

    prompt = _render_prompt(req.prompt_template or _DEFAULT_PROMPT, text)
    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    logger.info("text_extract", extra={"file": req.file_name, "model": model})

    return TextExtractResponse(items=_build_triples(extract_json(raw)))
|
||||||
BIN
tests/__pycache__/test_text_router.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_text_router.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_text_service.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_text_service.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
63
tests/test_text_router.py
Normal file
63
tests/test_text_router.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_TRIPLES_JSON = '''[
|
||||||
|
{
|
||||||
|
"subject": "变压器",
|
||||||
|
"predicate": "额定电压",
|
||||||
|
"object": "110kV",
|
||||||
|
"source_snippet": "该变压器额定电压为110kV",
|
||||||
|
"source_offset": {"start": 0, "end": 12}
|
||||||
|
}
|
||||||
|
]'''
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_extract_returns_200(client, mock_llm, mock_storage):
    """Happy path: a .txt request yields 200 and the parsed triple."""
    mock_storage.download_bytes = AsyncMock(return_value=b"some text content")
    mock_llm.chat = AsyncMock(return_value=SAMPLE_TRIPLES_JSON)
    payload = {"file_path": "text/test.txt", "file_name": "test.txt"}

    response = client.post("/api/v1/text/extract", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert "items" in body
    first_item = body["items"][0]
    assert first_item["subject"] == "变压器"
    assert first_item["source_offset"]["start"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_extract_unsupported_format_returns_400(client, mock_storage):
    """An .xlsx file name is rejected with 400 / UNSUPPORTED_FILE_TYPE."""
    mock_storage.download_bytes = AsyncMock(return_value=b"data")
    payload = {"file_path": "text/test.xlsx", "file_name": "data.xlsx"}

    response = client.post("/api/v1/text/extract", json=payload)

    assert response.status_code == 400
    assert response.json()["code"] == "UNSUPPORTED_FILE_TYPE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_extract_storage_error_returns_502(client, mock_llm, mock_storage):
    """A StorageError during download surfaces as 502 / STORAGE_ERROR."""
    from app.core.exceptions import StorageError

    failure = StorageError("RustFS unreachable")
    mock_storage.download_bytes = AsyncMock(side_effect=failure)
    payload = {"file_path": "text/test.txt", "file_name": "test.txt"}

    response = client.post("/api/v1/text/extract", json=payload)

    assert response.status_code == 502
    assert response.json()["code"] == "STORAGE_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_extract_llm_parse_error_returns_502(client, mock_llm, mock_storage):
    """An LLM answer that is not JSON surfaces as 502 / LLM_PARSE_ERROR."""
    mock_storage.download_bytes = AsyncMock(return_value=b"content")
    mock_llm.chat = AsyncMock(return_value="not json {{{{")
    payload = {"file_path": "text/test.txt", "file_name": "test.txt"}

    response = client.post("/api/v1/text/extract", json=payload)

    assert response.status_code == 502
    assert response.json()["code"] == "LLM_PARSE_ERROR"
|
||||||
122
tests/test_text_service.py
Normal file
122
tests/test_text_service.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from app.core.exceptions import LLMParseError, StorageError, UnsupportedFileTypeError
|
||||||
|
from app.models.text_models import TextExtractRequest
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_TRIPLES_JSON = '''[
|
||||||
|
{
|
||||||
|
"subject": "变压器",
|
||||||
|
"predicate": "额定电压",
|
||||||
|
"object": "110kV",
|
||||||
|
"source_snippet": "该变压器额定电压为110kV",
|
||||||
|
"source_offset": {"start": 0, "end": 12}
|
||||||
|
}
|
||||||
|
]'''
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def req_txt():
    """Extraction request pointing at a plain-text document."""
    request = TextExtractRequest(file_path="text/test.txt", file_name="test.txt")
    return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def req_pdf():
    """Extraction request pointing at a PDF document."""
    request = TextExtractRequest(file_path="text/test.pdf", file_name="report.pdf")
    return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def req_docx():
    """Extraction request pointing at a DOCX document."""
    request = TextExtractRequest(file_path="text/test.docx", file_name="doc.docx")
    return request
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def llm(mock_llm):
    """The shared mock LLM, with ``chat`` stubbed to answer with the sample triples."""
    canned_chat = AsyncMock(return_value=SAMPLE_TRIPLES_JSON)
    mock_llm.chat = canned_chat
    return mock_llm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_txt_extraction_returns_triples(llm, mock_storage):
    """A .txt document yields the single triple contained in the LLM answer."""
    from app.services.text_service import extract_triples

    mock_storage.download_bytes = AsyncMock(return_value=b"test content")
    request = TextExtractRequest(file_path="text/test.txt", file_name="test.txt")

    result = await extract_triples(request, llm, mock_storage)

    assert len(result.items) == 1
    triple = result.items[0]
    assert triple.subject == "变压器"
    assert triple.predicate == "额定电压"
    assert triple.object == "110kV"
    assert triple.source_offset.start == 0
    assert triple.source_offset.end == 12
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pdf_extraction(llm, mock_storage, monkeypatch):
    """A .pdf document is parsed page by page and its text reaches the LLM.

    ``pdfplumber.open`` is replaced by a stub, so no real PDF is needed. The
    service looks ``pdfplumber.open`` up at call time, so patching the
    attribute is enough and the service module does not need reloading.
    """
    from app.services import text_service

    mock_storage.download_bytes = AsyncMock(return_value=b"%PDF fake")

    mock_page = MagicMock()
    mock_page.extract_text.return_value = "PDF content here"
    mock_pdf = MagicMock()
    # The service uses the opened PDF as a context manager.
    mock_pdf.__enter__ = lambda s: s
    mock_pdf.__exit__ = MagicMock(return_value=False)
    mock_pdf.pages = [mock_page]
    monkeypatch.setattr("pdfplumber.open", lambda f: mock_pdf)

    req = TextExtractRequest(file_path="text/test.pdf", file_name="doc.pdf")
    result = await text_service.extract_triples(req, llm, mock_storage)

    assert len(result.items) == 1
    # The extracted page text must be embedded in the prompt sent to the LLM;
    # chat() is called positionally as chat(model, messages).
    messages = llm.chat.call_args.args[1]
    assert "PDF content here" in messages[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_docx_extraction(llm, mock_storage, monkeypatch):
    """A .docx document is parsed by paragraph and its text reaches the LLM.

    ``docx.Document`` is replaced by a stub, so no real DOCX is needed. The
    service looks ``docx.Document`` up at call time, so patching the
    attribute is enough and the service module does not need reloading.
    """
    from app.services import text_service

    mock_storage.download_bytes = AsyncMock(return_value=b"PK fake docx bytes")

    mock_para = MagicMock()
    mock_para.text = "Word paragraph content"
    mock_doc = MagicMock()
    mock_doc.paragraphs = [mock_para]
    monkeypatch.setattr("docx.Document", lambda f: mock_doc)

    req = TextExtractRequest(file_path="text/test.docx", file_name="doc.docx")
    result = await text_service.extract_triples(req, llm, mock_storage)

    assert len(result.items) == 1
    # The paragraph text must be embedded in the prompt sent to the LLM;
    # chat() is called positionally as chat(model, messages).
    messages = llm.chat.call_args.args[1]
    assert "Word paragraph content" in messages[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_unsupported_format_raises_error(llm, mock_storage):
    """An .xlsx file name makes the service raise UnsupportedFileTypeError."""
    from app.services.text_service import extract_triples

    mock_storage.download_bytes = AsyncMock(return_value=b"data")
    request = TextExtractRequest(file_path="text/test.xlsx", file_name="data.xlsx")

    with pytest.raises(UnsupportedFileTypeError):
        await extract_triples(request, llm, mock_storage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_storage_error_propagates(llm, mock_storage):
    """A StorageError raised while downloading is not swallowed by the service."""
    from app.services.text_service import extract_triples

    failure = StorageError("not found")
    mock_storage.download_bytes = AsyncMock(side_effect=failure)
    request = TextExtractRequest(file_path="text/test.txt", file_name="test.txt")

    with pytest.raises(StorageError):
        await extract_triples(request, llm, mock_storage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_llm_parse_error_propagates(mock_llm, mock_storage):
    """An LLM answer that is not JSON makes the service raise LLMParseError."""
    from app.services.text_service import extract_triples

    mock_storage.download_bytes = AsyncMock(return_value=b"content")
    mock_llm.chat = AsyncMock(return_value="not json {{")
    request = TextExtractRequest(file_path="text/test.txt", file_name="test.txt")

    with pytest.raises(LLMParseError):
        await extract_triples(request, mock_llm, mock_storage)
|
||||||
Reference in New Issue
Block a user