Files
label_ai_service/app/services/image_service.py

91 lines
2.7 KiB
Python
Raw Permalink Normal View History

import base64
import io
import cv2
import numpy as np
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.image_models import (
BBox,
ImageExtractRequest,
ImageExtractResponse,
QuadrupleItem,
)
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Fallback instruction sent to the vision model when the request does not
# supply its own prompt_template. The text (in Chinese) asks the model to
# analyse the image and return a JSON array of knowledge quadruples, each with
# the fields subject, predicate, object, qualifier (may be null) and a bbox
# {x, y, w, h} in pixel coordinates.
# NOTE(review): the doubled braces around "x, y, w, h" look like str.format
# escaping, but this string is not passed through str.format in this section,
# so the model would receive them literally -- confirm whether that is intended.
_DEFAULT_PROMPT = (
"请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
"subject主体实体、predicate关系/属性、object客体实体"
"qualifier修饰信息可为 null、bbox{{x, y, w, h}} 像素坐标)。"
)
def _sniff_image_mime(data: bytes) -> str:
    """Guess the MIME type of an encoded image from its leading magic bytes.

    Falls back to ``image/jpeg`` (the value that was previously hard-coded)
    when the signature is not recognised.
    """
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith(b"BM"):
        return "image/bmp"
    return "image/jpeg"


def _clamp_bbox(raw_bbox, img_w: int, img_h: int) -> tuple:
    """Clamp a model-supplied bbox to the image and return ``(x, y, w, h)``.

    The origin is clamped into ``[0, img_w - 1] x [0, img_h - 1]`` and the
    size is limited to what remains of the image from that origin, but never
    less than one pixel, so the resulting crop is always non-empty.

    Raises:
        KeyError, TypeError, ValueError, OverflowError: if ``raw_bbox`` is not
            a mapping with integer-convertible ``x``, ``y``, ``w``, ``h``.
    """
    x = max(0, min(int(raw_bbox["x"]), img_w - 1))
    y = max(0, min(int(raw_bbox["y"]), img_h - 1))
    # The lower bound of 1 guards against zero/negative sizes from the model,
    # which would otherwise produce an empty slice that cannot be encoded.
    w = max(1, min(int(raw_bbox["w"]), img_w - x))
    h = max(1, min(int(raw_bbox["h"]), img_h - y))
    return x, y, w, h


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    """Extract knowledge quadruples from an image with a vision model.

    Downloads ``req.file_path`` from the configured ``source_data`` bucket,
    sends it (as a base64 data URL) together with the prompt to the vision
    model, then, for every returned item, crops the bbox region out of the
    image and uploads the crop to ``crops/{req.task_id}/{idx}.jpg`` in the
    same bucket.

    Args:
        req: Extraction request; ``file_path``, ``task_id``, ``model`` and
            ``prompt_template`` are read from it.
        llm: Client used for the vision chat call.
        storage: Client used to download the image and upload the crops.

    Returns:
        An ``ImageExtractResponse`` holding one ``QuadrupleItem`` per usable
        model item. Items that are malformed (missing fields, missing or
        non-numeric bbox) are skipped with a warning instead of aborting the
        whole request.

    Raises:
        ValueError: if the downloaded object is empty or cannot be decoded as
            an image.
        RuntimeError: if a crop cannot be JPEG-encoded.
    """
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)
    if not image_bytes:
        raise ValueError(f"Downloaded image is empty: {req.file_path!r}")

    # Decode with OpenCV for cropping; the original bytes go to the LLM.
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure by returning None rather than raising.
        raise ValueError(f"Cannot decode image: {req.file_path!r}")
    img_h, img_w = img.shape[:2]

    b64 = base64.b64encode(image_bytes).decode()
    # Declare the real container format instead of always claiming JPEG.
    image_data_url = f"data:{_sniff_image_mime(image_bytes)};base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    items_raw = extract_json(raw)
    items: list[QuadrupleItem] = []
    for idx, item in enumerate(items_raw):
        # Validate the whole item before uploading anything, so a malformed
        # entry neither aborts the request nor leaves an orphaned crop.
        try:
            subject = item["subject"]
            predicate = item["predicate"]
            obj = item["object"]
            qualifier = item.get("qualifier")
            x, y, w, h = _clamp_bbox(item["bbox"], img_w, img_h)
        except (KeyError, TypeError, ValueError, OverflowError) as exc:
            logger.warning(
                "image_extract_skip_item",
                extra={
                    "file": req.file_path,
                    "item_index": idx,
                    "reason": repr(exc),
                },
            )
            continue

        crop = img[y : y + h, x : x + w]
        ok, crop_buf = cv2.imencode(".jpg", crop)
        if not ok:
            raise RuntimeError(
                f"Failed to JPEG-encode crop {idx} of {req.file_path!r}"
            )

        # idx is the position in the model output, so paths stay stable even
        # when earlier items were skipped.
        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(
            bucket, crop_path, crop_buf.tobytes(), "image/jpeg"
        )
        items.append(
            QuadrupleItem(
                subject=subject,
                predicate=predicate,
                object=obj,
                qualifier=qualifier,
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )
    return ImageExtractResponse(items=items)