feat(US2): image quad extraction — POST /api/v1/image/extract

- app/models/image_models.py: BBox, QuadrupleItem, ImageExtract{Request,Response}
- app/services/image_service.py: download → base64 LLM → bbox clamp → crop upload
- app/routers/image.py: POST /image/extract handler
- tests: 4 service + 3 router tests, 7/7 passing
This commit is contained in:
wh
2026-04-10 15:40:56 +08:00
parent dd8da386f4
commit 2876c179ac
10 changed files with 299 additions and 1 deletions

Binary file not shown.

View File

@@ -0,0 +1,28 @@
from pydantic import BaseModel
class BBox(BaseModel):
    """Axis-aligned rectangle described by integer x, y, w and h.

    The image service uses it as a pixel region of the source image:
    rows ``y`` to ``y + h`` and columns ``x`` to ``x + w`` are cropped.
    """

    # Horizontal offset (column) of the rectangle's left edge.
    x: int
    # Vertical offset (row) of the rectangle's top edge.
    y: int
    # Width of the rectangle.
    w: int
    # Height of the rectangle.
    h: int
class QuadrupleItem(BaseModel):
    """One extracted quadruple plus the image region it refers to."""

    # Subject entity of the statement.
    subject: str
    # Relation or attribute linking subject and object.
    predicate: str
    # Object entity of the statement (field name intentionally matches the
    # JSON key returned by the model).
    object: str
    # Optional qualifying information; None when absent.
    qualifier: str | None = None
    # Region of the source image associated with this quadruple.
    bbox: BBox
    # Storage path under which the cropped region was uploaded.
    cropped_image_path: str
class ImageExtractRequest(BaseModel):
    """Request body for the image quadruple-extraction endpoint."""

    # Path of the image inside the storage bucket it is downloaded from.
    file_path: str
    # Task identifier; the image service uses it to build crop upload paths.
    task_id: int
    # Optional model name; the service falls back to its configured default
    # vision model when this is None or empty.
    model: str | None = None
    # Optional prompt text; the service falls back to its built-in default
    # prompt when this is None or empty.
    prompt_template: str | None = None
class ImageExtractResponse(BaseModel):
    """Response body for the image quadruple-extraction endpoint."""

    # Extracted quadruples, one entry per item returned by the model.
    items: list[QuadrupleItem]

View File

@@ -1,3 +1,18 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

# Router that groups the image endpoints under the "Image" tag.
# (The fastapi import and this assignment previously appeared twice on one
# line each — old and new diff sides fused together — which is not valid
# Python; only the new side is kept.)
router = APIRouter(tags=["Image"])
@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageExtractResponse:
    """Handle POST /image/extract by delegating to the image service.

    Args:
        req: Parsed request body.
        llm: LLM client injected via ``get_llm_client``.
        storage: Storage client injected via ``get_storage_client``.

    Returns:
        The response produced by ``image_service.extract_quads``.
    """
    extraction = await image_service.extract_quads(req, llm, storage)
    return extraction

Binary file not shown.

View File

@@ -0,0 +1,90 @@
import base64
import io
import cv2
import numpy as np
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.image_models import (
BBox,
ImageExtractRequest,
ImageExtractResponse,
QuadrupleItem,
)
# Logger for this module, obtained from the project's logging helper.
logger = get_logger(__name__)

# Prompt sent with the image when the request carries no prompt_template.
# It asks (in Chinese) for a JSON array of knowledge quadruples with the
# fields subject, predicate, object, qualifier (may be null) and
# bbox {x, y, w, h} in pixel coordinates.
# NOTE(review): extract_quads uses this string as-is (no .format() call), so
# the doubled braces around "x, y, w, h" reach the model literally — confirm
# that this is intended.
_DEFAULT_PROMPT = (
    "请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
    "subject主体实体、predicate关系/属性、object客体实体"
    "qualifier修饰信息可为 null、bbox{{x, y, w, h}} 像素坐标)。"
)
def _sniff_image_mime(data: bytes) -> str:
    """Guess the image MIME type from magic bytes, defaulting to JPEG."""
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith(b"BM"):
        return "image/bmp"
    return "image/jpeg"


def _clamp_bbox(raw_bbox: dict, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    """Clamp a model-supplied bbox to a non-empty region inside the image."""
    x = max(0, min(int(raw_bbox["x"]), img_w - 1))
    y = max(0, min(int(raw_bbox["y"]), img_h - 1))
    # x <= img_w - 1 and y <= img_h - 1, so the upper bounds are >= 1; the
    # lower bound of 1 keeps zero/negative sizes from yielding an empty crop.
    w = max(1, min(int(raw_bbox["w"]), img_w - x))
    h = max(1, min(int(raw_bbox["h"]), img_h - y))
    return x, y, w, h


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    """Extract quadruples from an image and upload one crop per quadruple.

    Steps: download the image from the ``source_data`` bucket, send it to the
    vision model as a base64 data URL together with the prompt, parse the
    reply with ``extract_json``, clamp every bbox to the image, crop the
    region, upload it as ``crops/{task_id}/{idx}.jpg`` and collect the items.

    Args:
        req: Request with the image path, task id and optional model/prompt
            overrides.
        llm: Client whose ``chat_vision(model, messages)`` returns the raw
            model reply.
        storage: Client providing ``download_bytes`` and ``upload_bytes``.

    Returns:
        Response containing one ``QuadrupleItem`` per entry in the model reply.

    Raises:
        ValueError: If the downloaded bytes cannot be decoded as an image or a
            crop cannot be JPEG-encoded.
        KeyError: If a returned item lacks a required field.
        Exception: Errors from the storage client, the LLM client and
            ``extract_json`` propagate unchanged.
    """
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)

    # Decode with OpenCV for cropping. imdecode signals failure by returning
    # None (and rejects an empty buffer), so fail early with a clear message.
    img = None
    if image_bytes:
        img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"cannot decode image: {req.file_path}")
    img_h, img_w = img.shape[:2]

    # The original bytes (not a re-encode) go to the model, so label the data
    # URL with the format actually detected instead of assuming JPEG.
    b64 = base64.b64encode(image_bytes).decode()
    image_data_url = f"data:{_sniff_image_mime(image_bytes)};base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    items_raw = extract_json(raw)

    items: list[QuadrupleItem] = []
    for idx, item in enumerate(items_raw):
        x, y, w, h = _clamp_bbox(item["bbox"], img_w, img_h)

        crop = img[y : y + h, x : x + w]
        ok, crop_buf = cv2.imencode(".jpg", crop)
        if not ok:
            raise ValueError(f"cannot encode crop {idx} of {req.file_path}")

        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(bucket, crop_path, crop_buf.tobytes(), "image/jpeg")

        items.append(
            QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier"),
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )

    return ImageExtractResponse(items=items)

View File

@@ -0,0 +1,63 @@
import json
import numpy as np
import cv2
import pytest
from unittest.mock import AsyncMock
from app.core.exceptions import StorageError
def _make_test_image_bytes() -> bytes:
    """Return JPEG bytes of an all-black 3-channel image, 100 wide by 80 high."""
    blank = np.zeros((80, 100, 3), dtype=np.uint8)
    encoded = cv2.imencode(".jpg", blank)[1]
    return encoded.tobytes()
# Canned model reply: a JSON array holding a single quadruple whose bbox lies
# inside the 100x80 test image.
SAMPLE_QUADS_JSON = json.dumps(
    [
        dict(
            subject="电缆接头",
            predicate="位于",
            object="配电箱左侧",
            qualifier="2024年检修",
            bbox=dict(x=5, y=5, w=20, h=15),
        ),
    ]
)
def test_image_extract_returns_200(client, mock_llm, mock_storage):
    """A valid image and a well-formed model reply yield 200 with one item."""
    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 200
    body = resp.json()
    assert "items" in body
    first_item = body["items"][0]
    assert first_item["subject"] == "电缆接头"
    assert first_item["cropped_image_path"] == "crops/1/0.jpg"
def test_image_extract_llm_parse_error_returns_502(client, mock_llm, mock_storage):
    """An unparseable model reply is reported as 502 / LLM_PARSE_ERROR."""
    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())
    mock_llm.chat_vision = AsyncMock(return_value="not json {{")

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 502
    error_body = resp.json()
    assert error_body["code"] == "LLM_PARSE_ERROR"
def test_image_extract_storage_error_returns_502(client, mock_storage):
    """A StorageError raised by the download is reported as 502 / STORAGE_ERROR."""
    failing_download = AsyncMock(side_effect=StorageError("RustFS down"))
    mock_storage.download_bytes = failing_download

    payload = {"file_path": "image/test.jpg", "task_id": 1}
    resp = client.post("/api/v1/image/extract", json=payload)

    assert resp.status_code == 502
    error_body = resp.json()
    assert error_body["code"] == "STORAGE_ERROR"

102
tests/test_image_service.py Normal file
View File

@@ -0,0 +1,102 @@
import io
import json
import pytest
import numpy as np
import cv2
from unittest.mock import AsyncMock
from app.core.exceptions import LLMParseError
from app.models.image_models import ImageExtractRequest
def _make_test_image_bytes(width=100, height=80) -> bytes:
    """Return JPEG bytes of a black ``width`` x ``height`` image with a filled rectangle.

    Rows 10-49 and columns 10-59 are painted (255, 0, 0).
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[10:50, 10:60] = (255, 0, 0)  # blue rectangle
    encode_result = cv2.imencode(".jpg", canvas)
    return encode_result[1].tobytes()
# Canned model reply: a JSON array holding a single quadruple whose bbox lies
# inside the default 100x80 test image.
SAMPLE_QUADS_JSON = json.dumps(
    [
        dict(
            subject="电缆接头",
            predicate="位于",
            object="配电箱左侧",
            qualifier="2024年检修",
            bbox=dict(x=10, y=10, w=40, h=30),
        ),
    ]
)
@pytest.fixture
def image_bytes():
    """Provide the default-sized encoded test image."""
    encoded_image = _make_test_image_bytes()
    return encoded_image
@pytest.fixture
def req():
    """Provide a minimal extraction request for task 1."""
    request_model = ImageExtractRequest(file_path="image/test.jpg", task_id=1)
    return request_model
@pytest.mark.asyncio
async def test_extract_quads_returns_items(mock_llm, mock_storage, image_bytes, req):
    """A well-formed reply produces one item with the expected fields."""
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    from app.services.image_service import extract_quads

    result = await extract_quads(req, mock_llm, mock_storage)

    assert len(result.items) == 1
    quad = result.items[0]
    assert quad.subject == "电缆接头"
    assert quad.predicate == "位于"
    assert quad.bbox.x == 10
    assert quad.bbox.y == 10
    assert quad.cropped_image_path == "crops/1/0.jpg"
@pytest.mark.asyncio
async def test_crop_is_uploaded(mock_llm, mock_storage, image_bytes, req):
    """Exactly one upload happens, targeting the expected crop path."""
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    from app.services.image_service import extract_quads

    await extract_quads(req, mock_llm, mock_storage)

    # upload_bytes called once for the crop
    mock_storage.upload_bytes.assert_called_once()
    positional_args = mock_storage.upload_bytes.call_args.args
    assert positional_args[1] == "crops/1/0.jpg"
@pytest.mark.asyncio
async def test_out_of_bounds_bbox_is_clamped(mock_llm, mock_storage, req):
    """A bbox extending past the image is clamped to the image bounds.

    The image is 50x40 and the bbox starts at (30, 20) with size 100x100, so
    the origin is kept and the size shrinks to the remaining 20x20 pixels.
    """
    img = _make_test_image_bytes(width=50, height=40)
    mock_storage.download_bytes = AsyncMock(return_value=img)

    # bbox goes outside image boundary
    oob_json = json.dumps([{
        "subject": "test",
        "predicate": "rel",
        "object": "obj",
        "qualifier": None,
        "bbox": {"x": 30, "y": 20, "w": 100, "h": 100},  # extends beyond 50x40
    }])
    mock_llm.chat_vision = AsyncMock(return_value=oob_json)
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    from app.services.image_service import extract_quads

    # Should not raise; bbox is clamped
    result = await extract_quads(req, mock_llm, mock_storage)
    assert len(result.items) == 1

    # Verify the clamped values, not merely the absence of an exception.
    bbox = result.items[0].bbox
    assert (bbox.x, bbox.y) == (30, 20)
    assert (bbox.w, bbox.h) == (20, 20)
@pytest.mark.asyncio
async def test_llm_parse_error_raised(mock_llm, mock_storage, image_bytes, req):
    """An unparseable model reply makes extract_quads raise LLMParseError."""
    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
    malformed_reply = "bad json {{"
    mock_llm.chat_vision = AsyncMock(return_value=malformed_reply)

    from app.services.image_service import extract_quads

    with pytest.raises(LLMParseError):
        await extract_quads(req, mock_llm, mock_storage)