Files
label_ai_service/app/services/image_service.py
wh 2876c179ac feat(US2): image quad extraction — POST /api/v1/image/extract
- app/models/image_models.py: BBox, QuadrupleItem, ImageExtract{Request,Response}
- app/services/image_service.py: download → base64 LLM → bbox clamp → crop upload
- app/routers/image.py: POST /image/extract handler
- tests: 4 service + 3 router tests, 7/7 passing
2026-04-10 15:40:56 +08:00

91 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import io
import cv2
import numpy as np
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.image_models import (
BBox,
ImageExtractRequest,
ImageExtractResponse,
QuadrupleItem,
)
# Module-level logger, namespaced to this module.
logger = get_logger(__name__)
# Default vision prompt (Chinese): asks the model to extract knowledge
# quadruples from the image as a JSON array with fields subject / predicate /
# object / qualifier (nullable) / bbox {x, y, w, h} in pixel coordinates.
# NOTE(review): the doubled braces `{{x, y, w, h}}` look like leftover
# str.format escaping — since this string is never `.format()`ed, the model
# receives literal double braces; likewise the closing `)。` has no matching
# opening paren. Confirm whether the braces should be single.
_DEFAULT_PROMPT = (
"请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
"subject主体实体、predicate关系/属性、object客体实体"
"qualifier修饰信息可为 null、bbox{{x, y, w, h}} 像素坐标)。"
)
async def extract_quads(
req: ImageExtractRequest,
llm: LLMClient,
storage: StorageClient,
) -> ImageExtractResponse:
cfg = get_config()
bucket = cfg["storage"]["buckets"]["source_data"]
model = req.model or cfg["models"]["default_vision"]
image_bytes = await storage.download_bytes(bucket, req.file_path)
# Decode with OpenCV for cropping; encode as base64 for LLM
nparr = np.frombuffer(image_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
img_h, img_w = img.shape[:2]
b64 = base64.b64encode(image_bytes).decode()
image_data_url = f"data:image/jpeg;base64,{b64}"
prompt = req.prompt_template or _DEFAULT_PROMPT
messages = [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_data_url}},
{"type": "text", "text": prompt},
],
}
]
raw = await llm.chat_vision(model, messages)
logger.info("image_extract", extra={"file": req.file_path, "model": model})
items_raw = extract_json(raw)
items: list[QuadrupleItem] = []
for idx, item in enumerate(items_raw):
b = item["bbox"]
# Clamp bbox to image dimensions
x = max(0, min(int(b["x"]), img_w - 1))
y = max(0, min(int(b["y"]), img_h - 1))
w = min(int(b["w"]), img_w - x)
h = min(int(b["h"]), img_h - y)
crop = img[y : y + h, x : x + w]
_, crop_buf = cv2.imencode(".jpg", crop)
crop_bytes = crop_buf.tobytes()
crop_path = f"crops/{req.task_id}/{idx}.jpg"
await storage.upload_bytes(bucket, crop_path, crop_bytes, "image/jpeg")
items.append(
QuadrupleItem(
subject=item["subject"],
predicate=item["predicate"],
object=item["object"],
qualifier=item.get("qualifier"),
bbox=BBox(x=x, y=y, w=w, h=h),
cropped_image_path=crop_path,
)
)
return ImageExtractResponse(items=items)