- app/models/image_models.py: BBox, QuadrupleItem, ImageExtract{Request,Response}
- app/services/image_service.py: download → base64 LLM → bbox clamp → crop upload
- app/routers/image.py: POST /image/extract handler
- tests: 4 service + 3 router tests, 7/7 passing
91 lines
2.7 KiB
Python
91 lines
2.7 KiB
Python
import base64
|
||
import io
|
||
|
||
import cv2
|
||
import numpy as np
|
||
|
||
from app.clients.llm.base import LLMClient
|
||
from app.clients.storage.base import StorageClient
|
||
from app.core.config import get_config
|
||
from app.core.json_utils import extract_json
|
||
from app.core.logging import get_logger
|
||
from app.models.image_models import (
|
||
BBox,
|
||
ImageExtractRequest,
|
||
ImageExtractResponse,
|
||
QuadrupleItem,
|
||
)
|
||
|
||
# Module-level logger, requested from the project's logging helper under this
# module's name.
logger = get_logger(__name__)
|
||
|
||
_DEFAULT_PROMPT = (
|
||
"请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
|
||
"subject(主体实体)、predicate(关系/属性)、object(客体实体)、"
|
||
"qualifier(修饰信息,可为 null)、bbox({{x, y, w, h}} 像素坐标)。"
|
||
)
|
||
|
||
|
||
def _clamp_bbox(raw_bbox, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    """Fit a model-reported box inside an ``img_w`` x ``img_h`` image.

    The origin is clamped into the image, and the size is limited so the box
    neither extends past the image edge nor collapses below one pixel. A zero
    or negative size would otherwise produce an empty crop, which cannot be
    JPEG-encoded.

    Args:
        raw_bbox: Mapping with ``x``, ``y``, ``w`` and ``h`` entries that
            ``int()`` accepts.
        img_w: Image width in pixels (must be >= 1).
        img_h: Image height in pixels (must be >= 1).

    Returns:
        The clamped ``(x, y, w, h)`` as integers.

    Raises:
        KeyError: If one of the four entries is missing.
        TypeError, ValueError: If an entry cannot be converted by ``int()``.
    """
    x = max(0, min(int(raw_bbox["x"]), img_w - 1))
    y = max(0, min(int(raw_bbox["y"]), img_h - 1))
    # x <= img_w - 1 and y <= img_h - 1, so the upper bounds below are >= 1
    # and flooring at 1 can never push the box outside the image.
    w = max(1, min(int(raw_bbox["w"]), img_w - x))
    h = max(1, min(int(raw_bbox["h"]), img_h - y))
    return x, y, w, h


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    """Extract knowledge quadruples from a stored image using a vision model.

    Steps:
      1. Download ``req.file_path`` from the configured ``source_data`` bucket.
      2. Send the image (as a base64 data URL) together with the prompt to
         ``llm.chat_vision`` and parse the reply with ``extract_json``.
      3. For every returned item, clamp its bbox to the image, crop that
         region, upload it as ``crops/{task_id}/{idx}.jpg`` to the same
         bucket, and collect a ``QuadrupleItem`` pointing at the crop.

    Args:
        req: Request carrying ``file_path``, ``task_id`` and the optional
            ``model`` / ``prompt_template`` overrides.
        llm: Client used for the vision chat call.
        storage: Client used to download the source image and upload crops.

    Returns:
        An ``ImageExtractResponse`` holding one item per parsed quadruple.

    Raises:
        ValueError: If the downloaded bytes cannot be decoded as an image.
        RuntimeError: If a crop cannot be JPEG-encoded.
        KeyError: If a parsed item lacks ``bbox``, ``subject``, ``predicate``
            or ``object`` (propagated unchanged).
    """
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)

    # Decode with OpenCV for cropping. imdecode signals failure by returning
    # None (and rejects an empty buffer), so fail early with a clear error
    # instead of an AttributeError on ``img.shape`` further down.
    img = None
    if image_bytes:
        img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"could not decode image at {req.file_path!r}")
    img_h, img_w = img.shape[:2]

    # The original bytes (not a re-encode) are what the model sees.
    # NOTE(review): the MIME type is fixed to image/jpeg whatever the real
    # file format is -- confirm the provider tolerates that for non-JPEG input.
    b64 = base64.b64encode(image_bytes).decode()
    image_data_url = f"data:image/jpeg;base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    # Assumes extract_json yields an iterable of dict-like items -- TODO confirm.
    items_raw = extract_json(raw)
    items: list[QuadrupleItem] = []

    for idx, item in enumerate(items_raw):
        x, y, w, h = _clamp_bbox(item["bbox"], img_w, img_h)

        crop = img[y : y + h, x : x + w]
        ok, crop_buf = cv2.imencode(".jpg", crop)
        if not ok:
            raise RuntimeError(f"failed to JPEG-encode crop {idx} of {req.file_path!r}")
        crop_bytes = crop_buf.tobytes()

        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(bucket, crop_path, crop_bytes, "image/jpeg")

        items.append(
            QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier"),
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )

    return ImageExtractResponse(items=items)
|