feat(US2): image quad extraction — POST /api/v1/image/extract
- app/models/image_models.py: BBox, QuadrupleItem, ImageExtract{Request,Response}
- app/services/image_service.py: download → base64 LLM → bbox clamp → crop upload
- app/routers/image.py: POST /image/extract handler
- tests: 4 service + 3 router tests, 7/7 passing
This commit is contained in:
BIN
app/models/__pycache__/image_models.cpython-312.pyc
Normal file
BIN
app/models/__pycache__/image_models.cpython-312.pyc
Normal file
Binary file not shown.
28
app/models/image_models.py
Normal file
28
app/models/image_models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BBox(BaseModel):
    """Axis-aligned rectangle expressed as an origin plus a size."""

    # Horizontal position of the box origin.
    x: int
    # Vertical position of the box origin.
    y: int
    # Horizontal extent of the box.
    w: int
    # Vertical extent of the box.
    h: int
|
||||
|
||||
|
||||
class QuadrupleItem(BaseModel):
    """One extracted quadruple together with the image region it refers to."""

    # Subject / predicate / object triple of the extracted statement.
    subject: str
    predicate: str
    object: str
    # Optional fourth element; absent when no qualifier was extracted.
    qualifier: str | None = None
    # Region of the source image associated with this statement.
    bbox: BBox
    # Storage path of the crop produced for ``bbox``.
    cropped_image_path: str
|
||||
|
||||
|
||||
class ImageExtractRequest(BaseModel):
    """Request body for the image quadruple extraction endpoint."""

    # Path of the image to analyse.
    file_path: str
    # Identifier of the task this extraction belongs to.
    task_id: int
    # Optional model override; ``None`` means the caller supplied none.
    model: str | None = None
    # Optional prompt override; ``None`` means the caller supplied none.
    prompt_template: str | None = None
|
||||
|
||||
|
||||
class ImageExtractResponse(BaseModel):
    """Response body for the image quadruple extraction endpoint."""

    # Extracted quadruples, one entry per item.
    items: list[QuadrupleItem]
|
||||
Binary file not shown.
@@ -1,3 +1,18 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.clients.storage.base import StorageClient
|
||||
from app.core.dependencies import get_llm_client, get_storage_client
|
||||
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
|
||||
from app.services import image_service
|
||||
|
||||
# Router grouping the image endpoints under the "Image" tag.
router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageExtractResponse:
    """Handle ``POST /image/extract`` by delegating to the image service.

    Args:
        req: Parsed request body.
        llm: LLM client resolved through ``get_llm_client``.
        storage: Storage client resolved through ``get_storage_client``.

    Returns:
        Whatever ``image_service.extract_quads`` produces for the request.
    """
    response = await image_service.extract_quads(req, llm, storage)
    return response
|
||||
|
||||
BIN
app/services/__pycache__/image_service.cpython-312.pyc
Normal file
BIN
app/services/__pycache__/image_service.cpython-312.pyc
Normal file
Binary file not shown.
90
app/services/image_service.py
Normal file
90
app/services/image_service.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import base64
|
||||
import io
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.clients.storage.base import StorageClient
|
||||
from app.core.config import get_config
|
||||
from app.core.json_utils import extract_json
|
||||
from app.core.logging import get_logger
|
||||
from app.models.image_models import (
|
||||
BBox,
|
||||
ImageExtractRequest,
|
||||
ImageExtractResponse,
|
||||
QuadrupleItem,
|
||||
)
|
||||
|
||||
# Module-level logger obtained from the project's logging helper, keyed by
# this module's name.
logger = get_logger(__name__)
|
||||
|
||||
# Default instruction sent to the vision model when the request carries no
# ``prompt_template``. It asks for a JSON array of quadruples, each with
# subject, predicate, object, an optional qualifier and a pixel bbox.
#
# The text is used verbatim (it is never passed through ``str.format``), so
# the braces around the bbox fields must be single: doubled ``{{ }}`` would be
# sent to the model literally.
_DEFAULT_PROMPT = (
    "请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
    "subject(主体实体)、predicate(关系/属性)、object(客体实体)、"
    "qualifier(修饰信息,可为 null)、bbox({x, y, w, h} 像素坐标)。"
)
|
||||
|
||||
|
||||
def _sniff_image_mime(data: bytes) -> str:
    """Return the MIME type implied by the leading magic bytes of *data*.

    Falls back to ``image/jpeg`` when the signature is not recognised.
    """
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith(b"BM"):
        return "image/bmp"
    return "image/jpeg"


def _clamp_bbox(raw_bbox, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    """Clip a model-provided box to the image; return ``(x, y, w, h)``, at least 1x1."""
    left, top = int(raw_bbox["x"]), int(raw_bbox["y"])
    # Compute the far edges BEFORE clamping the origin so that a box starting
    # at a negative coordinate shrinks instead of being shifted.
    right, bottom = left + int(raw_bbox["w"]), top + int(raw_bbox["h"])

    x0 = max(0, min(left, img_w - 1))
    y0 = max(0, min(top, img_h - 1))
    # Force at least one pixel in each direction: an empty crop cannot be
    # JPEG-encoded, and a negative size would make the slice wrap around.
    x1 = max(x0 + 1, min(right, img_w))
    y1 = max(y0 + 1, min(bottom, img_h))
    return x0, y0, x1 - x0, y1 - y0


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    """Extract quadruples from an image and upload one crop per quadruple.

    The image is downloaded from the ``source_data`` bucket, sent to the
    vision model as a base64 data URL, and the model's JSON answer is turned
    into ``QuadrupleItem`` objects. Each item's bbox is clipped to the image,
    the region is cropped, JPEG-encoded and uploaded to
    ``crops/{task_id}/{index}.jpg`` in the same bucket.

    Args:
        req: Request with the image path, task id and optional model/prompt
            overrides.
        llm: Client used for the vision chat call.
        storage: Client used to download the image and upload the crops.

    Returns:
        Response holding one item per entry in the model output.

    Raises:
        ValueError: If the downloaded bytes are empty or cannot be decoded as
            an image.
        RuntimeError: If a crop cannot be JPEG-encoded.
        KeyError: If an entry lacks ``bbox``, ``subject``, ``predicate`` or
            ``object``.
    """
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)
    if not image_bytes:
        raise ValueError(f"Image is empty: {req.file_path}")

    # Decode with OpenCV for cropping; imdecode signals failure by returning
    # None rather than raising, so check before touching ``shape``.
    img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Cannot decode image: {req.file_path}")
    img_h, img_w = img.shape[:2]

    # The original bytes are forwarded untouched, so label the data URL with
    # the format they actually have.
    b64 = base64.b64encode(image_bytes).decode()
    image_data_url = f"data:{_sniff_image_mime(image_bytes)};base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    items_raw = extract_json(raw)
    items: list[QuadrupleItem] = []

    for idx, item in enumerate(items_raw):
        x, y, w, h = _clamp_bbox(item["bbox"], img_w, img_h)

        crop = img[y : y + h, x : x + w]
        ok, crop_buf = cv2.imencode(".jpg", crop)
        if not ok:
            raise RuntimeError(f"Failed to encode crop {idx} of {req.file_path}")

        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(bucket, crop_path, crop_buf.tobytes(), "image/jpeg")

        items.append(
            QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier"),
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )

    return ImageExtractResponse(items=items)
|
||||
BIN
tests/__pycache__/test_image_router.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_image_router.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
Binary file not shown.
63
tests/test_image_router.py
Normal file
63
tests/test_image_router.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from app.core.exceptions import StorageError
|
||||
|
||||
|
||||
def _make_test_image_bytes() -> bytes:
    """Return JPEG bytes of an all-black image, 100 px wide and 80 px high."""
    blank = np.zeros((80, 100, 3), dtype=np.uint8)
    encoded = cv2.imencode(".jpg", blank)[1]
    return encoded.tobytes()
|
||||
|
||||
|
||||
# Canned model answer: a JSON array holding a single quadruple whose bbox lies
# inside the 100x80 test image.
SAMPLE_QUADS_JSON = json.dumps(
    [
        dict(
            subject="电缆接头",
            predicate="位于",
            object="配电箱左侧",
            qualifier="2024年检修",
            bbox=dict(x=5, y=5, w=20, h=15),
        )
    ]
)
|
||||
|
||||
|
||||
def test_image_extract_returns_200(client, mock_llm, mock_storage):
    """Happy path: the endpoint answers 200 with the extracted item and crop path."""
    mock_storage.upload_bytes = AsyncMock(return_value=None)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 200
    body = resp.json()
    assert "items" in body
    first = body["items"][0]
    assert first["subject"] == "电缆接头"
    assert first["cropped_image_path"] == "crops/1/0.jpg"
|
||||
|
||||
|
||||
def test_image_extract_llm_parse_error_returns_502(client, mock_llm, mock_storage):
    """Unparseable model output is reported as 502 with code LLM_PARSE_ERROR."""
    mock_llm.chat_vision = AsyncMock(return_value="not json {{")
    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 502
    error_body = resp.json()
    assert error_body["code"] == "LLM_PARSE_ERROR"
|
||||
|
||||
|
||||
def test_image_extract_storage_error_returns_502(client, mock_storage):
    """A storage failure on download is reported as 502 with code STORAGE_ERROR."""
    failure = StorageError("RustFS down")
    mock_storage.download_bytes = AsyncMock(side_effect=failure)

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 502
    error_body = resp.json()
    assert error_body["code"] == "STORAGE_ERROR"
|
||||
102
tests/test_image_service.py
Normal file
102
tests/test_image_service.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import io
|
||||
import json
|
||||
import pytest
|
||||
import numpy as np
|
||||
import cv2
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from app.core.exceptions import LLMParseError
|
||||
from app.models.image_models import ImageExtractRequest
|
||||
|
||||
|
||||
def _make_test_image_bytes(width=100, height=80) -> bytes:
    """Return JPEG bytes of a black ``width`` x ``height`` image with a blue patch.

    The patch covers rows 10-49 and columns 10-59 (clipped by the array
    bounds for smaller images).
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[10:50, 10:60] = (255, 0, 0)  # blue rectangle
    encoded = cv2.imencode(".jpg", canvas)[1]
    return encoded.tobytes()
|
||||
|
||||
|
||||
# Canned model answer: a JSON array holding a single quadruple whose bbox lies
# inside the default 100x80 test image.
SAMPLE_QUADS_JSON = json.dumps(
    [
        dict(
            subject="电缆接头",
            predicate="位于",
            object="配电箱左侧",
            qualifier="2024年检修",
            bbox=dict(x=10, y=10, w=40, h=30),
        )
    ]
)
|
||||
|
||||
|
||||
@pytest.fixture
def image_bytes():
    """Provide the default-sized encoded test image."""
    encoded = _make_test_image_bytes()
    return encoded
|
||||
|
||||
|
||||
@pytest.fixture
def req():
    """Provide a minimal extraction request for task 1."""
    fields = {"file_path": "image/test.jpg", "task_id": 1}
    return ImageExtractRequest(**fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_quads_returns_items(mock_llm, mock_storage, image_bytes, req):
    """A well-formed model answer yields one item with the expected fields."""
    from app.services.image_service import extract_quads

    mock_storage.upload_bytes = AsyncMock(return_value=None)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)

    result = await extract_quads(req, mock_llm, mock_storage)

    assert len(result.items) == 1
    (only,) = result.items
    assert only.subject == "电缆接头"
    assert only.predicate == "位于"
    assert only.bbox.x == 10
    assert only.bbox.y == 10
    assert only.cropped_image_path == "crops/1/0.jpg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_crop_is_uploaded(mock_llm, mock_storage, image_bytes, req):
    """The crop is uploaded once, as a JPEG of the bbox size, to the task path.

    Besides the destination path this also checks the content type and that
    the uploaded bytes decode to an image with the bbox dimensions, so a wrong
    or empty crop cannot pass unnoticed.
    """
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    from app.services.image_service import extract_quads
    await extract_quads(req, mock_llm, mock_storage)

    # upload_bytes called once for the crop
    mock_storage.upload_bytes.assert_called_once()
    uploaded = mock_storage.upload_bytes.call_args.args
    assert uploaded[1] == "crops/1/0.jpg"
    assert uploaded[3] == "image/jpeg"

    # SAMPLE_QUADS_JSON asks for w=40, h=30, fully inside the 100x80 image.
    crop = cv2.imdecode(np.frombuffer(uploaded[2], np.uint8), cv2.IMREAD_COLOR)
    assert crop is not None
    assert crop.shape[:2] == (30, 40)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_out_of_bounds_bbox_is_clamped(mock_llm, mock_storage, req):
    """A bbox extending past the image is clipped to the image bounds.

    The image is 50x40 and the box starts at (30, 20) with size 100x100, so
    the clipped box must keep its origin and shrink to 20x20.
    """
    img = _make_test_image_bytes(width=50, height=40)
    mock_storage.download_bytes = AsyncMock(return_value=img)

    # bbox goes outside image boundary
    oob_json = json.dumps([{
        "subject": "test",
        "predicate": "rel",
        "object": "obj",
        "qualifier": None,
        "bbox": {"x": 30, "y": 20, "w": 100, "h": 100},  # extends beyond 50x40
    }])
    mock_llm.chat_vision = AsyncMock(return_value=oob_json)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    from app.services.image_service import extract_quads
    # Should not raise; bbox is clamped
    result = await extract_quads(req, mock_llm, mock_storage)
    assert len(result.items) == 1

    # Verify the clamp itself, not merely that nothing raised.
    bbox = result.items[0].bbox
    assert (bbox.x, bbox.y) == (30, 20)
    assert (bbox.w, bbox.h) == (20, 20)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_llm_parse_error_raised(mock_llm, mock_storage, image_bytes, req):
    """Unparseable model output makes the service raise LLMParseError."""
    from app.services.image_service import extract_quads

    mock_llm.chat_vision = AsyncMock(return_value="bad json {{")
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)

    with pytest.raises(LLMParseError):
        await extract_quads(req, mock_llm, mock_storage)
|
||||
Reference in New Issue
Block a user