diff --git a/app/models/__pycache__/image_models.cpython-312.pyc b/app/models/__pycache__/image_models.cpython-312.pyc
new file mode 100644
index 0000000..35c1d1e
Binary files /dev/null and b/app/models/__pycache__/image_models.cpython-312.pyc differ
diff --git a/app/models/image_models.py b/app/models/image_models.py
new file mode 100644
index 0000000..79f6a3b
--- /dev/null
+++ b/app/models/image_models.py
@@ -0,0 +1,36 @@
+from pydantic import BaseModel
+
+
+class BBox(BaseModel):
+    """Pixel rectangle given by its top-left corner, width and height."""
+
+    x: int
+    y: int
+    w: int
+    h: int
+
+
+class QuadrupleItem(BaseModel):
+    """One extracted quadruple together with its cropped image region."""
+
+    subject: str
+    predicate: str
+    object: str
+    qualifier: str | None = None
+    bbox: BBox
+    cropped_image_path: str
+
+
+class ImageExtractRequest(BaseModel):
+    """Request body of ``POST /image/extract``."""
+
+    file_path: str
+    task_id: int
+    model: str | None = None
+    prompt_template: str | None = None
+
+
+class ImageExtractResponse(BaseModel):
+    """Response body of ``POST /image/extract``."""
+
+    items: list[QuadrupleItem]
diff --git a/app/routers/__pycache__/image.cpython-312.pyc b/app/routers/__pycache__/image.cpython-312.pyc
index 9c78dda..513b0f1 100644
Binary files a/app/routers/__pycache__/image.cpython-312.pyc and b/app/routers/__pycache__/image.cpython-312.pyc differ
diff --git a/app/routers/image.py b/app/routers/image.py
index 30aefbc..7ede9f4 100644
--- a/app/routers/image.py
+++ b/app/routers/image.py
@@ -1,3 +1,19 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+
+from app.clients.llm.base import LLMClient
+from app.clients.storage.base import StorageClient
+from app.core.dependencies import get_llm_client, get_storage_client
+from app.models.image_models import ImageExtractRequest, ImageExtractResponse
+from app.services import image_service
 
 router = APIRouter(tags=["Image"])
+
+
+@router.post("/image/extract", response_model=ImageExtractResponse)
+async def extract_image(
+    req: ImageExtractRequest,
+    llm: LLMClient = Depends(get_llm_client),
+    storage: StorageClient = Depends(get_storage_client),
+) -> ImageExtractResponse:
+    """Extract knowledge quadruples from the image named in the request."""
+    return await image_service.extract_quads(req, llm, storage)
diff --git a/app/services/__pycache__/image_service.cpython-312.pyc b/app/services/__pycache__/image_service.cpython-312.pyc
new file mode 100644
index 0000000..334c011
Binary files /dev/null and b/app/services/__pycache__/image_service.cpython-312.pyc differ
diff --git a/app/services/image_service.py b/app/services/image_service.py
new file mode 100644
index 0000000..ad7bcc0
--- /dev/null
+++ b/app/services/image_service.py
@@ -0,0 +1,136 @@
+import base64
+import io
+import mimetypes
+
+import cv2
+import numpy as np
+
+from app.clients.llm.base import LLMClient
+from app.clients.storage.base import StorageClient
+from app.core.config import get_config
+from app.core.exceptions import LLMParseError
+from app.core.json_utils import extract_json
+from app.core.logging import get_logger
+from app.models.image_models import (
+    BBox,
+    ImageExtractRequest,
+    ImageExtractResponse,
+    QuadrupleItem,
+)
+
+logger = get_logger(__name__)
+
+# Sent to the model verbatim (never passed through str.format), so the bbox
+# shape is written with single braces.
+_DEFAULT_PROMPT = (
+    "请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
+    "subject(主体实体)、predicate(关系/属性)、object(客体实体)、"
+    "qualifier(修饰信息,可为 null)、bbox({x, y, w, h} 像素坐标)。"
+)
+
+# Data-URL MIME type used when the file extension does not identify an image.
+_FALLBACK_MIME = "image/jpeg"
+
+
+def _guess_image_mime(file_path: str) -> str:
+    """Return the image MIME type implied by the extension of *file_path*."""
+    mime, _ = mimetypes.guess_type(file_path)
+    if mime and mime.startswith("image/"):
+        return mime
+    return _FALLBACK_MIME
+
+
+def _clamp_bbox(raw: dict, img_w: int, img_h: int) -> BBox:
+    """Clamp a model-supplied bbox to a non-empty region inside the image."""
+    x = max(0, min(int(raw["x"]), img_w - 1))
+    y = max(0, min(int(raw["y"]), img_h - 1))
+    # Keep at least one pixel: cv2.imencode rejects an empty crop.
+    w = max(1, min(int(raw["w"]), img_w - x))
+    h = max(1, min(int(raw["h"]), img_h - y))
+    return BBox(x=x, y=y, w=w, h=h)
+
+
+async def extract_quads(
+    req: ImageExtractRequest,
+    llm: LLMClient,
+    storage: StorageClient,
+) -> ImageExtractResponse:
+    """Extract knowledge quadruples from a stored image with a vision LLM.
+
+    Downloads ``req.file_path`` from the source-data bucket, asks the model
+    for quadruples with bounding boxes, then uploads one JPEG crop per
+    quadruple to ``crops/{task_id}/{idx}.jpg`` in the same bucket.
+
+    Raises:
+        ValueError: The downloaded bytes are not a decodable image, or a
+            crop could not be JPEG-encoded.
+        LLMParseError: The model reply is not a JSON array of well-formed
+            quadruple objects.
+    """
+    cfg = get_config()
+    bucket = cfg["storage"]["buckets"]["source_data"]
+    model = req.model or cfg["models"]["default_vision"]
+
+    image_bytes = await storage.download_bytes(bucket, req.file_path)
+
+    # Decode with OpenCV for cropping. imdecode signals failure by returning
+    # None, so check here rather than crash on ``img.shape`` below.
+    img = None
+    if image_bytes:
+        nparr = np.frombuffer(image_bytes, np.uint8)
+        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+    if img is None:
+        raise ValueError(f"Cannot decode image: {req.file_path}")
+    img_h, img_w = img.shape[:2]
+
+    # The LLM receives the original bytes, so label them with their own type.
+    b64 = base64.b64encode(image_bytes).decode()
+    image_data_url = f"data:{_guess_image_mime(req.file_path)};base64,{b64}"
+
+    prompt = req.prompt_template or _DEFAULT_PROMPT
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image_url", "image_url": {"url": image_data_url}},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+
+    raw = await llm.chat_vision(model, messages)
+    logger.info("image_extract", extra={"file": req.file_path, "model": model})
+
+    items_raw = extract_json(raw)
+    if not isinstance(items_raw, list):
+        # NOTE(review): assumes LLMParseError accepts a message string, like
+        # StorageError does in the tests -- confirm against app.core.exceptions.
+        raise LLMParseError("Expected a JSON array of quadruples")
+
+    items: list[QuadrupleItem] = []
+    for idx, item in enumerate(items_raw):
+        crop_path = f"crops/{req.task_id}/{idx}.jpg"
+        # Validate before uploading so a malformed item surfaces as a parse
+        # error instead of a bare KeyError/TypeError.
+        try:
+            quad = QuadrupleItem(
+                subject=item["subject"],
+                predicate=item["predicate"],
+                object=item["object"],
+                qualifier=item.get("qualifier"),
+                bbox=_clamp_bbox(item["bbox"], img_w, img_h),
+                cropped_image_path=crop_path,
+            )
+        except (KeyError, TypeError, ValueError) as exc:
+            raise LLMParseError(f"Malformed quadruple at index {idx}") from exc
+
+        box = quad.bbox
+        crop = img[box.y : box.y + box.h, box.x : box.x + box.w]
+        ok, crop_buf = cv2.imencode(".jpg", crop)
+        if not ok:
+            raise ValueError(f"Cannot encode crop {idx} of {req.file_path}")
+        await storage.upload_bytes(bucket, crop_path, crop_buf.tobytes(), "image/jpeg")
+
+        items.append(quad)
+
+    return ImageExtractResponse(items=items)
diff --git a/tests/__pycache__/test_image_router.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/test_image_router.cpython-312-pytest-9.0.3.pyc
new file mode 100644
index 0000000..a38db64
Binary files /dev/null and b/tests/__pycache__/test_image_router.cpython-312-pytest-9.0.3.pyc differ
diff --git a/tests/__pycache__/test_image_service.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/test_image_service.cpython-312-pytest-9.0.3.pyc
new file mode 100644
index 0000000..161aaaf
Binary files /dev/null and b/tests/__pycache__/test_image_service.cpython-312-pytest-9.0.3.pyc differ
diff --git a/tests/test_image_router.py b/tests/test_image_router.py
new file mode 100644
index 0000000..e98ce31
--- /dev/null
+++ b/tests/test_image_router.py
@@ -0,0 +1,63 @@
+import json
+import numpy as np
+import cv2
+import pytest
+from unittest.mock import AsyncMock
+
+from app.core.exceptions import StorageError
+
+
+def _make_test_image_bytes() -> bytes:
+    img = np.zeros((80, 100, 3), dtype=np.uint8)
+    _, buf = cv2.imencode(".jpg", img)
+    return buf.tobytes()
+
+
+SAMPLE_QUADS_JSON = json.dumps([
+    {
+        "subject": "电缆接头",
+        "predicate": "位于",
+        "object": "配电箱左侧",
+        "qualifier": "2024年检修",
+        "bbox": {"x": 5, "y": 5, "w": 20, "h": 15},
+    }
+])
+
+
+def test_image_extract_returns_200(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    resp = client.post(
+        "/api/v1/image/extract",
+        json={"file_path": "image/test.jpg", "task_id": 1},
+    )
+    assert resp.status_code == 200
+    data = resp.json()
+    assert "items" in data
+    assert data["items"][0]["subject"] == "电缆接头"
+    assert data["items"][0]["cropped_image_path"] == "crops/1/0.jpg"
+
+
+def test_image_extract_llm_parse_error_returns_502(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=_make_test_image_bytes())
+    mock_llm.chat_vision = AsyncMock(return_value="not json {{")
+
+    resp = client.post(
+        "/api/v1/image/extract",
+        json={"file_path": "image/test.jpg", "task_id": 1},
+    )
+    assert resp.status_code == 502
+    assert resp.json()["code"] == "LLM_PARSE_ERROR"
+
+
+def test_image_extract_storage_error_returns_502(client, mock_storage):
+    mock_storage.download_bytes = AsyncMock(side_effect=StorageError("RustFS down"))
+
+    resp = client.post(
+        "/api/v1/image/extract",
+        json={"file_path": "image/test.jpg", "task_id": 1},
+    )
+    assert resp.status_code == 502
+    assert resp.json()["code"] == "STORAGE_ERROR"
diff --git a/tests/test_image_service.py b/tests/test_image_service.py
new file mode 100644
index 0000000..ee6e8ae
--- /dev/null
+++ b/tests/test_image_service.py
@@ -0,0 +1,132 @@
+import io
+import json
+import pytest
+import numpy as np
+import cv2
+from unittest.mock import AsyncMock
+
+from app.core.exceptions import LLMParseError
+from app.models.image_models import ImageExtractRequest
+
+
+def _make_test_image_bytes(width=100, height=80) -> bytes:
+    img = np.zeros((height, width, 3), dtype=np.uint8)
+    img[10:50, 10:60] = (255, 0, 0)  # blue rectangle
+    _, buf = cv2.imencode(".jpg", img)
+    return buf.tobytes()
+
+
+SAMPLE_QUADS_JSON = json.dumps([
+    {
+        "subject": "电缆接头",
+        "predicate": "位于",
+        "object": "配电箱左侧",
+        "qualifier": "2024年检修",
+        "bbox": {"x": 10, "y": 10, "w": 40, "h": 30},
+    }
+])
+
+
+@pytest.fixture
+def image_bytes():
+    return _make_test_image_bytes()
+
+
+@pytest.fixture
+def req():
+    return ImageExtractRequest(file_path="image/test.jpg", task_id=1)
+
+
+@pytest.mark.asyncio
+async def test_extract_quads_returns_items(mock_llm, mock_storage, image_bytes, req):
+    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    from app.services.image_service import extract_quads
+    result = await extract_quads(req, mock_llm, mock_storage)
+
+    assert len(result.items) == 1
+    item = result.items[0]
+    assert item.subject == "电缆接头"
+    assert item.predicate == "位于"
+    assert item.bbox.x == 10
+    assert item.bbox.y == 10
+    assert item.cropped_image_path == "crops/1/0.jpg"
+
+
+@pytest.mark.asyncio
+async def test_crop_is_uploaded(mock_llm, mock_storage, image_bytes, req):
+    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
+    mock_llm.chat_vision = AsyncMock(return_value=SAMPLE_QUADS_JSON)
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    from app.services.image_service import extract_quads
+    await extract_quads(req, mock_llm, mock_storage)
+
+    # upload_bytes called once for the crop
+    mock_storage.upload_bytes.assert_called_once()
+    call_args = mock_storage.upload_bytes.call_args
+    assert call_args.args[1] == "crops/1/0.jpg"
+
+
+@pytest.mark.asyncio
+async def test_out_of_bounds_bbox_is_clamped(mock_llm, mock_storage, req):
+    img = _make_test_image_bytes(width=50, height=40)
+    mock_storage.download_bytes = AsyncMock(return_value=img)
+
+    # bbox goes outside image boundary
+    oob_json = json.dumps([{
+        "subject": "test",
+        "predicate": "rel",
+        "object": "obj",
+        "qualifier": None,
+        "bbox": {"x": 30, "y": 20, "w": 100, "h": 100},  # extends beyond 50x40
+    }])
+    mock_llm.chat_vision = AsyncMock(return_value=oob_json)
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    from app.services.image_service import extract_quads
+    # Should not raise; bbox is clamped
+    result = await extract_quads(req, mock_llm, mock_storage)
+    assert len(result.items) == 1
+
+
+@pytest.mark.asyncio
+async def test_llm_parse_error_raised(mock_llm, mock_storage, image_bytes, req):
+    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
+    mock_llm.chat_vision = AsyncMock(return_value="bad json {{")
+
+    from app.services.image_service import extract_quads
+    with pytest.raises(LLMParseError):
+        await extract_quads(req, mock_llm, mock_storage)
+
+
+@pytest.mark.asyncio
+async def test_degenerate_bbox_is_kept_non_empty(mock_llm, mock_storage, image_bytes, req):
+    mock_storage.download_bytes = AsyncMock(return_value=image_bytes)
+
+    # zero / negative size would otherwise produce an empty crop
+    bad_json = json.dumps([{
+        "subject": "test",
+        "predicate": "rel",
+        "object": "obj",
+        "qualifier": None,
+        "bbox": {"x": 10, "y": 10, "w": 0, "h": -5},
+    }])
+    mock_llm.chat_vision = AsyncMock(return_value=bad_json)
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    from app.services.image_service import extract_quads
+    result = await extract_quads(req, mock_llm, mock_storage)
+    assert result.items[0].bbox.w == 1
+    assert result.items[0].bbox.h == 1
+
+
+@pytest.mark.asyncio
+async def test_undecodable_image_raises(mock_llm, mock_storage, req):
+    mock_storage.download_bytes = AsyncMock(return_value=b"not an image")
+
+    from app.services.image_service import extract_quads
+    with pytest.raises(ValueError):
+        await extract_quads(req, mock_llm, mock_storage)