AI Service Implementation Plan

For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Implement label_ai_service, a Python FastAPI service that provides text triple extraction, image quadruple extraction, video processing, QA-pair generation, and GLM fine-tuning management for the knowledge-graph labeling platform.

Architecture: layered — routers (HTTP entry) → services (business logic) → clients (external adapters). LLMClient and StorageClient are both ABCs, currently implemented by ZhipuAIClient and RustFSClient and injected via FastAPI Depends, so the services layer never sees a concrete implementation. Video jobs run asynchronously via FastAPI BackgroundTasks and call back to the Java backend on completion.
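The layering above can be sketched in miniature — names here (`EchoLLM`, `summarize`) are illustrative, not part of the plan; only the ABC-plus-injection pattern mirrors the real design:

```python
import asyncio
from abc import ABC, abstractmethod


# clients layer: the ABC the services layer depends on
class LLMClient(ABC):
    @abstractmethod
    async def chat(self, messages: list[dict], model: str) -> str: ...


# a concrete adapter (stand-in for ZhipuAIClient)
class EchoLLM(LLMClient):
    async def chat(self, messages: list[dict], model: str) -> str:
        return f"{model}: {messages[-1]['content']}"


# services layer: typed against the ABC, unaware of the concrete class
async def summarize(text: str, llm: LLMClient) -> str:
    return await llm.chat([{"role": "user", "content": text}], model="glm-4-flash")


print(asyncio.run(summarize("hello", EchoLLM())))  # glm-4-flash: hello
```

Swapping ZhipuAIClient for another vendor then only touches the wiring, never the services.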

Tech Stack: Python 3.12 (conda label env), FastAPI, ZhipuAI SDK, boto3 (S3), OpenCV, pdfplumber, python-docx, httpx, pytest


Task 1: Project Scaffolding

Files:

  • Create: app/__init__.py

  • Create: app/core/__init__.py

  • Create: app/clients/__init__.py

  • Create: app/clients/llm/__init__.py

  • Create: app/clients/storage/__init__.py

  • Create: app/services/__init__.py

  • Create: app/routers/__init__.py

  • Create: app/models/__init__.py

  • Create: tests/__init__.py

  • Create: tests/conftest.py

  • Create: config.yaml

  • Create: .env

  • Create: requirements.txt

  • Step 1: Create the package directory structure

mkdir -p app/core app/clients/llm app/clients/storage app/services app/routers app/models tests
touch app/__init__.py app/core/__init__.py
touch app/clients/__init__.py app/clients/llm/__init__.py app/clients/storage/__init__.py
touch app/services/__init__.py app/routers/__init__.py app/models/__init__.py
touch tests/__init__.py
  • Step 2: Create config.yaml
server:
  port: 8000
  log_level: INFO

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}

video:
  frame_sample_count: 8
  max_file_size_mb: 200

models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
  • Step 3: Create .env
ZHIPUAI_API_KEY=your-zhipuai-api-key
STORAGE_ACCESS_KEY=minioadmin
STORAGE_SECRET_KEY=minioadmin
STORAGE_ENDPOINT=http://rustfs:9000
BACKEND_CALLBACK_URL=http://backend:8080/internal/video-job/callback
# MAX_VIDEO_SIZE_MB=200   # optional; overrides the video size limit in config.yaml
  • Step 4: Create requirements.txt
fastapi>=0.111
uvicorn[standard]>=0.29
pydantic>=2.7
python-dotenv>=1.0
pyyaml>=6.0
zhipuai>=2.1
boto3>=1.34
pdfplumber>=0.11
python-docx>=1.1
opencv-python-headless>=4.9
numpy>=1.26
httpx>=0.27
pytest>=8.0
pytest-asyncio>=0.23
  • Step 5: Create tests/conftest.py
import pytest
from unittest.mock import AsyncMock, MagicMock
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient


@pytest.fixture
def mock_llm():
    client = MagicMock(spec=LLMClient)
    client.chat = AsyncMock()
    client.chat_vision = AsyncMock()
    return client


@pytest.fixture
def mock_storage():
    client = MagicMock(spec=StorageClient)
    client.download_bytes = AsyncMock()
    client.upload_bytes = AsyncMock()
    client.get_presigned_url = MagicMock(return_value="https://example.com/presigned/crop.jpg")
    client.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # default 10 MB, below the limit
    return client
  • Step 6: Install dependencies
conda run -n label pip install -r requirements.txt

Expected: all packages install successfully, no errors

  • Step 7: Commit
git add app/ tests/ config.yaml .env requirements.txt
git commit -m "feat: project scaffold - directory structure and config files"

Task 2: Core Config Module

Files:

  • Create: app/core/config.py

  • Create: tests/test_config.py

  • Step 1: Write the failing tests

tests/test_config.py:

import pytest
from unittest.mock import patch, mock_open
from app.core.config import get_config

MOCK_YAML = """
server:
  port: 8000
  log_level: INFO
storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"
backend: {}
video:
  frame_sample_count: 8
models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
"""


def _fresh_config(monkeypatch, extra_env: dict | None = None):
    """Clear the lru_cache and set required env vars before each test."""
    get_config.cache_clear()
    base_env = {
        "ZHIPUAI_API_KEY": "test-key",
        "STORAGE_ACCESS_KEY": "test-access",
        "STORAGE_SECRET_KEY": "test-secret",
        "STORAGE_ENDPOINT": "http://localhost:9000",
        "BACKEND_CALLBACK_URL": "http://localhost:8080/callback",
    }
    if extra_env:
        base_env.update(extra_env)
    for k, v in base_env.items():
        monkeypatch.setenv(k, v)


def test_env_overrides_yaml(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
    assert cfg["zhipuai"]["api_key"] == "test-key"
    assert cfg["storage"]["access_key"] == "test-access"
    assert cfg["storage"]["endpoint"] == "http://localhost:9000"
    assert cfg["backend"]["callback_url"] == "http://localhost:8080/callback"
    get_config.cache_clear()


def test_yaml_values_preserved(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
    assert cfg["models"]["default_text"] == "glm-4-flash"
    assert cfg["video"]["frame_sample_count"] == 8
    assert cfg["storage"]["buckets"]["source_data"] == "source-data"
    get_config.cache_clear()


def test_missing_api_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.delenv("ZHIPUAI_API_KEY", raising=False)
    monkeypatch.setenv("STORAGE_ACCESS_KEY", "a")
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="ZHIPUAI_API_KEY"):
                get_config()
    get_config.cache_clear()


def test_missing_storage_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.setenv("ZHIPUAI_API_KEY", "key")
    monkeypatch.delenv("STORAGE_ACCESS_KEY", raising=False)
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="STORAGE_ACCESS_KEY"):
                get_config()
    get_config.cache_clear()
  • Step 2: Run and confirm failure
conda run -n label pytest tests/test_config.py -v

Expected: ImportError: cannot import name 'get_config'

  • Step 3: Implement app/core/config.py
import os
import yaml
from functools import lru_cache
from pathlib import Path
from dotenv import load_dotenv

_ROOT = Path(__file__).parent.parent.parent

_ENV_OVERRIDES = {
    "ZHIPUAI_API_KEY":       ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY":    ["storage", "access_key"],
    "STORAGE_SECRET_KEY":    ["storage", "secret_key"],
    "STORAGE_ENDPOINT":      ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL":  ["backend", "callback_url"],
    "LOG_LEVEL":             ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB":     ["video", "max_file_size_mb"],
}


def _set_nested(d: dict, keys: list[str], value: str) -> None:
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


@lru_cache(maxsize=1)
def get_config() -> dict:
    load_dotenv(_ROOT / ".env")
    with open(_ROOT / "config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    for env_key, yaml_path in _ENV_OVERRIDES.items():
        val = os.environ.get(env_key)
        if val:
            _set_nested(cfg, yaml_path, val)
    _validate(cfg)
    return cfg


def _validate(cfg: dict) -> None:
    checks = [
        (["zhipuai", "api_key"],    "ZHIPUAI_API_KEY"),
        (["storage", "access_key"], "STORAGE_ACCESS_KEY"),
        (["storage", "secret_key"], "STORAGE_SECRET_KEY"),
    ]
    for path, name in checks:
        val = cfg
        for k in path:
            val = (val or {}).get(k, "")
        if not val:
            raise RuntimeError(f"Missing required config: {name}")
  • Step 4: Run and confirm pass
conda run -n label pytest tests/test_config.py -v

Expected: 4 passed

  • Step 5: Commit
git add app/core/config.py tests/test_config.py
git commit -m "feat: core config module with YAML + env layered loading"
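One caveat worth noting when consuming this config: `_set_nested` writes env overrides in verbatim as strings, so a numeric override like `MAX_VIDEO_SIZE_MB` arrives as `"200"`, not `200`, and must be coerced at the use site. A minimal sketch of the same override logic:

```python
import os


def _set_nested(d: dict, keys: list[str], value: str) -> None:
    # walk to the parent dict, creating levels as needed, then set the leaf
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


cfg = {"video": {"max_file_size_mb": 200}}           # int, as loaded from YAML
os.environ["MAX_VIDEO_SIZE_MB"] = "500"
_set_nested(cfg, ["video", "max_file_size_mb"], os.environ["MAX_VIDEO_SIZE_MB"])

# the env override is a raw string, so coerce before any numeric comparison
limit_mb = int(cfg["video"]["max_file_size_mb"])
assert limit_mb == 500
```

Any size check downstream (e.g. the Task 5 video-size guard) should apply `int(...)` rather than compare against the raw config value.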

Task 3: Core Logging, Exceptions, JSON Utils

Files:

  • Create: app/core/logging.py

  • Create: app/core/exceptions.py

  • Create: app/core/json_utils.py

  • Step 1: Implement app/core/logging.py

import json
import logging
import time
from typing import Callable

from fastapi import Request, Response


class _JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        entry: dict = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry, ensure_ascii=False)


def setup_logging(log_level: str = "INFO") -> None:
    handler = logging.StreamHandler()
    handler.setFormatter(_JSONFormatter())
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(getattr(logging, log_level.upper(), logging.INFO))


async def request_logging_middleware(request: Request, call_next: Callable) -> Response:
    start = time.monotonic()
    response = await call_next(request)
    duration_ms = round((time.monotonic() - start) * 1000, 2)
    logging.getLogger("api").info(
        f"method={request.method} path={request.url.path} "
        f"status={response.status_code} duration_ms={duration_ms}"
    )
    return response
  • Step 2: Implement app/core/exceptions.py
import logging
from fastapi import Request
from fastapi.responses import JSONResponse


class UnsupportedFileTypeError(Exception):
    def __init__(self, ext: str):
        super().__init__(f"Unsupported file type: {ext}")


class StorageDownloadError(Exception):
    pass


class LLMResponseParseError(Exception):
    pass


class LLMCallError(Exception):
    pass


async def unsupported_file_type_handler(request: Request, exc: UnsupportedFileTypeError):
    return JSONResponse(
        status_code=400,
        content={"code": "UNSUPPORTED_FILE_TYPE", "message": str(exc)},
    )


async def storage_download_handler(request: Request, exc: StorageDownloadError):
    return JSONResponse(
        status_code=502,
        content={"code": "STORAGE_ERROR", "message": str(exc)},
    )


async def llm_parse_handler(request: Request, exc: LLMResponseParseError):
    return JSONResponse(
        status_code=502,
        content={"code": "LLM_PARSE_ERROR", "message": str(exc)},
    )


async def llm_call_handler(request: Request, exc: LLMCallError):
    return JSONResponse(
        status_code=503,
        content={"code": "LLM_CALL_ERROR", "message": str(exc)},
    )


async def generic_error_handler(request: Request, exc: Exception):
    logging.getLogger("error").exception("Unhandled exception")
    return JSONResponse(
        status_code=500,
        content={"code": "INTERNAL_ERROR", "message": "Internal server error"},
    )
  • Step 3: Implement app/core/json_utils.py
import json
from app.core.exceptions import LLMResponseParseError


def parse_json_response(raw: str) -> list | dict:
    """Parse JSON from a GLM response, tolerating markdown code-fence wrapping."""
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    content = content.strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        raise LLMResponseParseError(
            f"GLM response is not parseable as JSON: {raw[:200]}"
        ) from e
  • Step 4: Commit
git add app/core/logging.py app/core/exceptions.py app/core/json_utils.py
git commit -m "feat: core logging, exceptions, json utils"
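The fence-stripping behavior of `parse_json_response` can be exercised standalone; this sketch inlines the same splitting logic (raising plain `ValueError` here instead of the project's `LLMResponseParseError`):

```python
import json


def parse_json_response(raw: str):
    # strip an optional markdown code fence, then parse the remainder
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    try:
        return json.loads(content.strip())
    except json.JSONDecodeError as e:
        raise ValueError(f"not parseable as JSON: {raw[:200]}") from e


assert parse_json_response('[1, 2]') == [1, 2]
assert parse_json_response('```json\n{"a": 1}\n```') == {"a": 1}
assert parse_json_response('Here you go:\n```\n{"a": 1}\n```') == {"a": 1}
```

This covers the three response shapes GLM commonly produces: bare JSON, a `json`-tagged fence, and an untagged fence with leading chatter.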

Task 4: LLM Adapter Layer

Files:

  • Create: app/clients/llm/base.py

  • Create: app/clients/llm/zhipuai_client.py

  • Create: tests/test_llm_client.py

  • Step 1: Write the failing tests

tests/test_llm_client.py:

import asyncio
import pytest
from unittest.mock import MagicMock, patch
from app.clients.llm.zhipuai_client import ZhipuAIClient


@pytest.fixture
def zhipuai_client():
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        client = ZhipuAIClient(api_key="test-key")
        client._mock_sdk = mock_sdk
        yield client


def test_chat_returns_content(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "三元组提取结果"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat(
            messages=[{"role": "user", "content": "提取三元组"}],
            model="glm-4-flash",
        )
    )
    assert result == "三元组提取结果"
    zhipuai_client._mock_sdk.chat.completions.create.assert_called_once()


def test_chat_vision_calls_same_endpoint(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "图像分析结果"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat_vision(
            messages=[{"role": "user", "content": [{"type": "text", "text": "分析"}]}],
            model="glm-4v-flash",
        )
    )
    assert result == "图像分析结果"
  • Step 2: Run and confirm failure
conda run -n label pytest tests/test_llm_client.py -v

Expected: ImportError

  • Step 3: Implement app/clients/llm/base.py
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        """Text-only chat; returns the model's output text."""

    @abstractmethod
    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        """Multimodal chat (mixed text/image input); returns the model's output text."""
  • Step 4: Implement app/clients/llm/zhipuai_client.py
import asyncio
from zhipuai import ZhipuAI
from app.clients.llm.base import LLMClient


class ZhipuAIClient(LLMClient):
    def __init__(self, api_key: str):
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None,
            lambda: self._client.chat.completions.create(
                model=model, messages=messages, **kwargs
            ),
        )
        return resp.choices[0].message.content

    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        # GLM-4V shares the text chat endpoint; image parts are flagged by the image_url content type
        return await self.chat(messages, model, **kwargs)
  • Step 5: Run and confirm pass
conda run -n label pytest tests/test_llm_client.py -v

Expected: 2 passed

  • Step 6: Commit
git add app/clients/llm/ tests/test_llm_client.py
git commit -m "feat: LLMClient ABC and ZhipuAI implementation"
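Per the design review, image QA sends base64 rather than presigned URLs (RustFS is internal, GLM-4V is cloud-hosted). A sketch of the message shape `chat_vision` expects — `build_vision_message` is illustrative, and whether the `url` field takes a bare base64 string or a data URL should be confirmed against the ZhipuAI SDK docs:

```python
import base64


def build_vision_message(image_bytes: bytes, question: str) -> list[dict]:
    # OpenAI-style content parts: one image_url part, one text part.
    # Base64 is used because RustFS is not reachable from the cloud model.
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": b64}},
            {"type": "text", "text": question},
        ],
    }]


msgs = build_vision_message(b"\xff\xd8fake-jpeg-bytes", "图中的设备是什么?")
assert msgs[0]["content"][0]["type"] == "image_url"
```

A message list of this shape would be passed straight to `ZhipuAIClient.chat_vision(messages, model="glm-4v-flash")`.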

Task 5: Storage Adapter Layer

Files:

  • Create: app/clients/storage/base.py

  • Create: app/clients/storage/rustfs_client.py

  • Create: tests/test_storage_client.py

  • Step 1: Write the failing tests

tests/test_storage_client.py:

import asyncio
import pytest
from unittest.mock import MagicMock, patch
from app.clients.storage.rustfs_client import RustFSClient


@pytest.fixture
def rustfs_client():
    with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
        mock_s3 = MagicMock()
        mock_boto3.client.return_value = mock_s3
        client = RustFSClient(
            endpoint="http://localhost:9000",
            access_key="minioadmin",
            secret_key="minioadmin",
        )
        client._mock_s3 = mock_s3
        yield client


def test_download_bytes(rustfs_client):
    mock_body = MagicMock()
    mock_body.read.return_value = b"file content"
    rustfs_client._mock_s3.get_object.return_value = {"Body": mock_body}

    result = asyncio.run(
        rustfs_client.download_bytes("source-data", "text/202404/1.txt")
    )
    assert result == b"file content"
    rustfs_client._mock_s3.get_object.assert_called_once_with(
        Bucket="source-data", Key="text/202404/1.txt"
    )


def test_upload_bytes(rustfs_client):
    asyncio.run(
        rustfs_client.upload_bytes("source-data", "crops/1/0.jpg", b"img", "image/jpeg")
    )
    rustfs_client._mock_s3.put_object.assert_called_once_with(
        Bucket="source-data", Key="crops/1/0.jpg", Body=b"img", ContentType="image/jpeg"
    )


def test_get_presigned_url(rustfs_client):
    rustfs_client._mock_s3.generate_presigned_url.return_value = "https://example.com/signed"
    url = rustfs_client.get_presigned_url("source-data", "crops/1/0.jpg", expires=3600)
    assert url == "https://example.com/signed"
    rustfs_client._mock_s3.generate_presigned_url.assert_called_once_with(
        "get_object",
        Params={"Bucket": "source-data", "Key": "crops/1/0.jpg"},
        ExpiresIn=3600,
    )


def test_get_object_size(rustfs_client):
    rustfs_client._mock_s3.head_object.return_value = {"ContentLength": 1024 * 1024 * 50}
    size = asyncio.run(rustfs_client.get_object_size("source-data", "video/1.mp4"))
    assert size == 1024 * 1024 * 50
    rustfs_client._mock_s3.head_object.assert_called_once_with(
        Bucket="source-data", Key="video/1.mp4"
    )
  • Step 2: Run and confirm failure
conda run -n label pytest tests/test_storage_client.py -v

Expected: ImportError

  • Step 3: Implement app/clients/storage/base.py
from abc import ABC, abstractmethod


class StorageClient(ABC):
    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object from storage; returns its bytes."""

    @abstractmethod
    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        """Upload bytes to object storage."""

    @abstractmethod
    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Generate a presigned access URL."""

    @abstractmethod
    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return the object size in bytes, for size checks before downloading."""
  • Step 4: Implement app/clients/storage/rustfs_client.py
import asyncio
import boto3
from app.clients.storage.base import StorageClient


class RustFSClient(StorageClient):
    def __init__(self, endpoint: str, access_key: str, secret_key: str):
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
        )
        return resp["Body"].read()

    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._s3.put_object(
                Bucket=bucket, Key=path, Body=data, ContentType=content_type
            ),
        )

    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": bucket, "Key": path},
            ExpiresIn=expires,
        )

    async def get_object_size(self, bucket: str, path: str) -> int:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
        )
        return resp["ContentLength"]
  • Step 5: Run and confirm pass
conda run -n label pytest tests/test_storage_client.py -v

Expected: 4 passed

  • Step 6: Commit
git add app/clients/storage/ tests/test_storage_client.py
git commit -m "feat: StorageClient ABC and RustFS S3 implementation"
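`get_object_size` exists so the video endpoint can reject oversized files with a cheap HEAD request before scheduling the background task. A sketch of that guard under stated assumptions — `FakeStorage`, `check_video_size`, and the `ValueError` are illustrative; the real limit comes from `cfg["video"]["max_file_size_mb"]`:

```python
import asyncio

MAX_VIDEO_SIZE_MB = 200  # stand-in for cfg["video"]["max_file_size_mb"]


class FakeStorage:
    async def get_object_size(self, bucket: str, path: str) -> int:
        return 350 * 1024 * 1024  # pretend HEAD reports 350 MB


async def check_video_size(storage, bucket: str, path: str) -> None:
    # HEAD-only size check; nothing is downloaded if the file is too big
    size = await storage.get_object_size(bucket, path)
    if size > MAX_VIDEO_SIZE_MB * 1024 * 1024:
        raise ValueError(f"video exceeds {MAX_VIDEO_SIZE_MB} MB limit: {size} bytes")


try:
    asyncio.run(check_video_size(FakeStorage(), "source-data", "video/1.mp4"))
except ValueError as e:
    print(f"rejected: {e}")
```

Running the check inside the request handler (not the background task) lets the API return a synchronous 4xx instead of a failed callback.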

Task 6: Dependency Injection + FastAPI App Entry

Files:

  • Create: app/core/dependencies.py

  • Create: app/main.py

  • Step 1: Implement app/core/dependencies.py

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient

_llm_client: LLMClient | None = None
_storage_client: StorageClient | None = None


def set_clients(llm: LLMClient, storage: StorageClient) -> None:
    global _llm_client, _storage_client
    _llm_client, _storage_client = llm, storage


def get_llm_client() -> LLMClient:
    return _llm_client


def get_storage_client() -> StorageClient:
    return _storage_client
  • Step 2: Implement app/main.py

Note: the routers are created in later tasks, so the include_router lines start out commented; uncomment each one as its router is implemented.

import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.config import get_config
from app.core.dependencies import set_clients
from app.core.exceptions import (
    LLMCallError,
    LLMResponseParseError,
    StorageDownloadError,
    UnsupportedFileTypeError,
    generic_error_handler,
    llm_call_handler,
    llm_parse_handler,
    storage_download_handler,
    unsupported_file_type_handler,
)
from app.core.logging import request_logging_middleware, setup_logging
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.rustfs_client import RustFSClient


@asynccontextmanager
async def lifespan(app: FastAPI):
    cfg = get_config()
    setup_logging(cfg["server"]["log_level"])
    set_clients(
        llm=ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"]),
        storage=RustFSClient(
            endpoint=cfg["storage"]["endpoint"],
            access_key=cfg["storage"]["access_key"],
            secret_key=cfg["storage"]["secret_key"],
        ),
    )
    logging.getLogger("startup").info("AI service started")
    yield
    logging.getLogger("startup").info("AI service shutting down")


app = FastAPI(title="Label AI Service", version="1.0.0", lifespan=lifespan)

app.middleware("http")(request_logging_middleware)


@app.get("/health", tags=["Health"])
async def health():
    return {"status": "ok"}

app.add_exception_handler(UnsupportedFileTypeError, unsupported_file_type_handler)
app.add_exception_handler(StorageDownloadError, storage_download_handler)
app.add_exception_handler(LLMResponseParseError, llm_parse_handler)
app.add_exception_handler(LLMCallError, llm_call_handler)
app.add_exception_handler(Exception, generic_error_handler)

# Routers registered after each task:
# from app.routers import text, image, video, qa, finetune
# app.include_router(text.router, prefix="/api/v1")
# app.include_router(image.router, prefix="/api/v1")
# app.include_router(video.router, prefix="/api/v1")
# app.include_router(qa.router, prefix="/api/v1")
# app.include_router(finetune.router, prefix="/api/v1")
  • Step 3: Verify the /health endpoint
conda run -n label python -c "
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
r = client.get('/health')
assert r.status_code == 200 and r.json() == {'status': 'ok'}, r.json()
print('health check OK')
"

Expected: health check OK

  • Step 4: Commit
git add app/core/dependencies.py app/main.py
git commit -m "feat: DI dependencies, FastAPI app entry with lifespan and /health endpoint"

Task 7: Text Pydantic Models

Files:

  • Create: app/models/text_models.py

  • Step 1: Implement app/models/text_models.py

from pydantic import BaseModel


class SourceOffset(BaseModel):
    start: int
    end: int


class TripleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str
    source_offset: SourceOffset


class TextExtractRequest(BaseModel):
    file_path: str
    file_name: str
    model: str | None = None
    prompt_template: str | None = None


class TextExtractResponse(BaseModel):
    items: list[TripleItem]
  • Step 2: Quick schema check
conda run -n label python -c "
from app.models.text_models import TextExtractRequest, TextExtractResponse, TripleItem, SourceOffset
req = TextExtractRequest(file_path='text/1.txt', file_name='1.txt')
item = TripleItem(subject='A', predicate='B', object='C', source_snippet='ABC', source_offset=SourceOffset(start=0, end=3))
resp = TextExtractResponse(items=[item])
print(resp.model_dump())
"

Expected: prints the full dict, no errors

  • Step 3: Commit
git add app/models/text_models.py
git commit -m "feat: text Pydantic models"

Task 8: Text Service

Files:

  • Create: app/services/text_service.py

  • Create: tests/test_text_service.py

  • Step 1: Write the failing tests

tests/test_text_service.py:

import pytest
from app.services.text_service import extract_triples, _extract_text_from_bytes
from app.core.exceptions import UnsupportedFileTypeError, LLMResponseParseError, StorageDownloadError

TRIPLE_JSON = '[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压为110kV","source_offset":{"start":0,"end":10}}]'


@pytest.mark.asyncio
async def test_extract_triples_txt(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = b"变压器额定电压为110kV"
    mock_llm.chat.return_value = TRIPLE_JSON

    result = await extract_triples(
        file_path="text/1.txt",
        file_name="test.txt",
        model="glm-4-flash",
        prompt_template="提取三元组:",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1
    assert result[0].subject == "变压器"
    assert result[0].predicate == "额定电压"
    assert result[0].object == "110kV"
    assert result[0].source_offset.start == 0


@pytest.mark.asyncio
async def test_extract_triples_markdown_wrapped_json(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = b"some text"
    mock_llm.chat.return_value = f"```json\n{TRIPLE_JSON}\n```"

    result = await extract_triples(
        file_path="text/1.txt",
        file_name="test.txt",
        model="glm-4-flash",
        prompt_template="",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1


@pytest.mark.asyncio
async def test_extract_triples_storage_error(mock_llm, mock_storage):
    mock_storage.download_bytes.side_effect = Exception("connection refused")

    with pytest.raises(StorageDownloadError):
        await extract_triples(
            file_path="text/1.txt",
            file_name="test.txt",
            model="glm-4-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


@pytest.mark.asyncio
async def test_extract_triples_llm_parse_error(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = b"some text"
    mock_llm.chat.return_value = "这不是JSON"

    with pytest.raises(LLMResponseParseError):
        await extract_triples(
            file_path="text/1.txt",
            file_name="test.txt",
            model="glm-4-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


def test_unsupported_file_type_raises():
    with pytest.raises(UnsupportedFileTypeError):
        _extract_text_from_bytes(b"content", "doc.xlsx")


def test_parse_txt_bytes():
    result = _extract_text_from_bytes("你好世界".encode("utf-8"), "file.txt")
    assert result == "你好世界"
  • Step 2: Run and confirm failure
conda run -n label pytest tests/test_text_service.py -v

Expected: ImportError

  • Step 3: Implement app/services/text_service.py
import logging
from pathlib import Path

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError, UnsupportedFileTypeError
from app.core.json_utils import parse_json_response
from app.models.text_models import SourceOffset, TripleItem

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """请从以下文本中提取知识三元组。
对每个三元组提供:
- subject(主语实体)
- predicate(谓语关系)
- object(宾语实体)
- source_snippet(原文中的证据片段,直接引用原文)
- source_offset(证据片段的字符偏移 {"start": N, "end": M})

以 JSON 数组格式返回,例如:
[{"subject":"...","predicate":"...","object":"...","source_snippet":"...","source_offset":{"start":0,"end":50}}]

文本内容:
"""


def _parse_txt(data: bytes) -> str:
    return data.decode("utf-8")


def _parse_pdf(data: bytes) -> str:
    import io
    import pdfplumber
    with pdfplumber.open(io.BytesIO(data)) as pdf:
        return "\n".join(page.extract_text() or "" for page in pdf.pages)


def _parse_docx(data: bytes) -> str:
    import io
    import docx
    doc = docx.Document(io.BytesIO(data))
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())


_PARSERS = {
    ".txt": _parse_txt,
    ".pdf": _parse_pdf,
    ".docx": _parse_docx,
}


def _extract_text_from_bytes(data: bytes, filename: str) -> str:
    ext = Path(filename).suffix.lower()
    parser = _PARSERS.get(ext)
    if parser is None:
        raise UnsupportedFileTypeError(ext)
    return parser(data)


async def extract_triples(
    file_path: str,
    file_name: str,
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    bucket: str = "source-data",
) -> list[TripleItem]:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        raise StorageDownloadError(f"Failed to download {file_path}: {e}") from e

    text = _extract_text_from_bytes(data, file_name)
    prompt = prompt_template or DEFAULT_PROMPT

    messages = [
        {"role": "system", "content": "你是专业的知识图谱构建助手,擅长从文本中提取结构化知识三元组。"},
        {"role": "user", "content": prompt + text},
    ]

    try:
        raw = await llm.chat(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM call failed: {e}") from e

    logger.info(f"text_extract file={file_path} model={model}")

    items_raw = parse_json_response(raw)

    result = []
    for item in items_raw:
        try:
            offset = item.get("source_offset", {})
            result.append(TripleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                source_snippet=item.get("source_snippet", ""),
                source_offset=SourceOffset(
                    start=offset.get("start", 0),
                    end=offset.get("end", 0),
                ),
            ))
        except (KeyError, TypeError) as e:
            logger.warning(f"Skipping incomplete triple: {item}, error: {e}")

    return result
  • Step 4: Run and confirm pass
conda run -n label pytest tests/test_text_service.py -v

Expected: 6 passed

  • Step 5: Commit
git add app/services/text_service.py tests/test_text_service.py
git commit -m "feat: text service with txt/pdf/docx parsing and triple extraction"
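The skip-incomplete policy in the `extract_triples` loop (drop a bad item, keep the batch) can be demonstrated on plain dicts; `normalize` here is an illustrative stand-in that mirrors the loop without the Pydantic models:

```python
def normalize(items_raw: list) -> list[dict]:
    """Mirror of the extract_triples item loop, using plain dicts."""
    result = []
    for item in items_raw:
        try:
            offset = item.get("source_offset") or {}
            result.append({
                "subject": item["subject"],        # required -> KeyError if absent
                "predicate": item["predicate"],
                "object": item["object"],
                "source_snippet": item.get("source_snippet", ""),
                "start": offset.get("start", 0),   # optional, defaults to 0
                "end": offset.get("end", 0),
            })
        except (KeyError, TypeError):
            continue  # skip the incomplete triple instead of failing the whole batch
    return result


good = {"subject": "A", "predicate": "B", "object": "C"}
bad = {"subject": "A"}  # missing predicate/object -> dropped
assert len(normalize([good, bad])) == 1
```

Only `subject`/`predicate`/`object` are hard requirements; snippet and offsets degrade gracefully, which matches how GLM output tends to be partially well-formed.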

Task 9: Text Router

Files:

  • Create: app/routers/text.py

  • Create: tests/test_text_router.py

  • Step 1: Write the failing tests

tests/test_text_router.py:

import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from app.main import app
from app.core.dependencies import set_clients
from app.models.text_models import TripleItem, SourceOffset


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_text_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=b"变压器额定电压110kV")
    mock_llm.chat = AsyncMock(return_value='[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压110kV","source_offset":{"start":3,"end":10}}]')

    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.txt",
        "file_name": "规范.txt",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["subject"] == "变压器"


def test_text_extract_unsupported_file(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=b"content")
    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.xlsx",
        "file_name": "file.xlsx",
    })
    assert resp.status_code == 400
    assert resp.json()["code"] == "UNSUPPORTED_FILE_TYPE"
  • Step 2: 实现 app/routers/text.py
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.text_models import TextExtractRequest, TextExtractResponse
from app.services import text_service

router = APIRouter(tags=["Text"])


@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]
    prompt = req.prompt_template or text_service.DEFAULT_PROMPT

    items = await text_service.extract_triples(
        file_path=req.file_path,
        file_name=req.file_name,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return TextExtractResponse(items=items)
  • Step 3: 在 app/main.py 注册路由

取消注释以下两行:

from app.routers import text
app.include_router(text.router, prefix="/api/v1")
  • Step 4: 运行测试
conda run -n label pytest tests/test_text_router.py -v

Expected: 2 passed

  • Step 5: Commit
git add app/routers/text.py tests/test_text_router.py app/main.py
git commit -m "feat: text router POST /api/v1/text/extract"

Task 10: Image Models + Service

Files:

  • Create: app/models/image_models.py

  • Create: app/services/image_service.py

  • Create: tests/test_image_service.py

  • Step 1: 实现 app/models/image_models.py

from pydantic import BaseModel


class BBox(BaseModel):
    x: int
    y: int
    w: int
    h: int


class QuadrupleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    bbox: BBox
    cropped_image_path: str


class ImageExtractRequest(BaseModel):
    file_path: str
    task_id: int
    model: str | None = None
    prompt_template: str | None = None


class ImageExtractResponse(BaseModel):
    items: list[QuadrupleItem]
  • Step 2: 编写失败测试

tests/test_image_service.py:

import pytest
import numpy as np
import cv2
from app.services.image_service import extract_quadruples, _crop_image
from app.models.image_models import BBox
from app.core.exceptions import LLMResponseParseError, StorageDownloadError

QUAD_JSON = '[{"subject":"电缆接头","predicate":"位于","object":"配电箱左侧","qualifier":"2024年","bbox":{"x":10,"y":20,"w":50,"h":40}}]'


def _make_test_image_bytes(width=200, height=200) -> bytes:
    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[:] = (100, 150, 200)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


def test_crop_image():
    img_bytes = _make_test_image_bytes(200, 200)
    bbox = BBox(x=10, y=20, w=50, h=40)
    result = _crop_image(img_bytes, bbox)
    assert isinstance(result, bytes)
    arr = np.frombuffer(result, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    assert img.shape[0] == 40  # height
    assert img.shape[1] == 50  # width


@pytest.mark.asyncio
async def test_extract_quadruples_success(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = QUAD_JSON
    mock_storage.upload_bytes.return_value = None

    result = await extract_quadruples(
        file_path="image/202404/1.jpg",
        task_id=789,
        model="glm-4v-flash",
        prompt_template="提取四元组",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1
    assert result[0].subject == "电缆接头"
    assert result[0].cropped_image_path == "crops/789/0.jpg"
    mock_storage.upload_bytes.assert_called_once()


@pytest.mark.asyncio
async def test_extract_quadruples_storage_error(mock_llm, mock_storage):
    mock_storage.download_bytes.side_effect = Exception("timeout")
    with pytest.raises(StorageDownloadError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


@pytest.mark.asyncio
async def test_extract_quadruples_parse_error(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = "不是JSON"
    with pytest.raises(LLMResponseParseError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )
  • Step 3: 运行,确认失败
conda run -n label pytest tests/test_image_service.py -v

Expected: ImportError

  • Step 4: 实现 app/services/image_service.py
import base64
import logging
from pathlib import Path

import cv2
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError
from app.core.json_utils import parse_json_response
from app.models.image_models import BBox, QuadrupleItem

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """请分析这张图片,提取知识四元组。
对每个四元组提供:
- subject主体实体
- predicate关系/属性
- object客体实体
- qualifier修饰信息(时间、条件、场景),无则填空字符串
- bbox边界框 {"x": N, "y": N, "w": N, "h": N}(像素坐标,相对原图)

以 JSON 数组格式返回:
[{"subject":"...","predicate":"...","object":"...","qualifier":"...","bbox":{"x":0,"y":0,"w":100,"h":100}}]
"""


def _crop_image(image_bytes: bytes, bbox: BBox) -> bytes:
    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    x = max(0, bbox.x)
    y = max(0, bbox.y)
    x2 = min(w, bbox.x + bbox.w)
    y2 = min(h, bbox.y + bbox.h)
    cropped = img[y:y2, x:x2]
    _, buf = cv2.imencode(".jpg", cropped, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return buf.tobytes()


async def extract_quadruples(
    file_path: str,
    task_id: int,
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    source_bucket: str = "source-data",
) -> list[QuadrupleItem]:
    try:
        data = await storage.download_bytes(source_bucket, file_path)
    except Exception as e:
        raise StorageDownloadError(f"下载图片失败 {file_path}: {e}") from e

    ext = Path(file_path).suffix.lstrip(".").lower() or "jpeg"
    if ext == "jpg":
        ext = "jpeg"  # data URL 的 MIME 类型应为 image/jpeg
    b64 = base64.b64encode(data).decode()

    messages = [
        {"role": "system", "content": "你是专业的视觉分析助手,擅长从图像中提取结构化知识四元组。"},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}},
            {"type": "text", "text": prompt_template or DEFAULT_PROMPT},
        ]},
    ]

    try:
        raw = await llm.chat_vision(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM-4V 调用失败: {e}") from e

    logger.info(f"image_extract file={file_path} model={model}")
    items_raw = parse_json_response(raw)

    result = []
    for i, item in enumerate(items_raw):
        try:
            bbox = BBox(**item["bbox"])
            cropped = _crop_image(data, bbox)
            crop_path = f"crops/{task_id}/{i}.jpg"
            await storage.upload_bytes(source_bucket, crop_path, cropped, "image/jpeg")
            result.append(QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier", ""),
                bbox=bbox,
                cropped_image_path=crop_path,
            ))
        except Exception as e:
            logger.warning(f"跳过不完整四元组 index={i}: {e}")

    return result
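`_crop_image` 的关键在于越界 bbox 的钳制:用 max/min 把裁剪区域收缩到图内,避免负坐标或超出图像边界的切片。下面是该钳制逻辑的独立示意(`clamp_bbox` 为说明用的假设函数名,不依赖 OpenCV与上面实现的 max/min 逻辑一致):

```python
def clamp_bbox(x: int, y: int, w: int, h: int, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    # 与 _crop_image 相同的钳制:左上角不小于 0右下角不超过图像尺寸
    x1, y1 = max(0, x), max(0, y)
    x2, y2 = min(img_w, x + w), min(img_h, y + h)
    return x1, y1, x2 - x1, y2 - y1
```

例如在 200x200 的图上,`clamp_bbox(180, 190, 50, 40, 200, 200)` 会把 50x40 的框收缩为 20x10负坐标同理被截到 0。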
  • Step 5: 运行,确认通过
conda run -n label pytest tests/test_image_service.py -v

Expected: 4 passed

  • Step 6: Commit
git add app/models/image_models.py app/services/image_service.py tests/test_image_service.py
git commit -m "feat: image models, service with bbox crop and quadruple extraction"

Task 11: Image Router

Files:

  • Create: app/routers/image.py

  • Create: tests/test_image_router.py

  • Step 1: 编写失败测试

tests/test_image_router.py:

import numpy as np
import cv2
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


def _make_image_bytes() -> bytes:
    img = np.zeros((100, 100, 3), dtype=np.uint8)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_image_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=_make_image_bytes())
    mock_storage.upload_bytes = AsyncMock(return_value=None)
    mock_llm.chat_vision = AsyncMock(return_value='[{"subject":"A","predicate":"B","object":"C","qualifier":"","bbox":{"x":0,"y":0,"w":10,"h":10}}]')

    resp = client.post("/api/v1/image/extract", json={
        "file_path": "image/202404/1.jpg",
        "task_id": 42,
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["cropped_image_path"] == "crops/42/0.jpg"
  • Step 2: 实现 app/routers/image.py
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or image_service.DEFAULT_PROMPT

    items = await image_service.extract_quadruples(
        file_path=req.file_path,
        task_id=req.task_id,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        source_bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageExtractResponse(items=items)
  • Step 3: 在 app/main.py 注册路由
from app.routers import text, image
app.include_router(image.router, prefix="/api/v1")
  • Step 4: 运行测试
conda run -n label pytest tests/test_image_router.py -v

Expected: 1 passed

  • Step 5: Commit
git add app/routers/image.py tests/test_image_router.py app/main.py
git commit -m "feat: image router POST /api/v1/image/extract"

Task 12: Video Models + Service

Files:

  • Create: app/models/video_models.py

  • Create: app/services/video_service.py

  • Create: tests/test_video_service.py

  • Step 1: 实现 app/models/video_models.py

from pydantic import BaseModel


class ExtractFramesRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    mode: str = "interval"      # interval | keyframe
    frame_interval: int = 30


class ExtractFramesResponse(BaseModel):
    message: str
    job_id: int


class FrameInfo(BaseModel):
    frame_index: int
    time_sec: float
    frame_path: str


class VideoToTextRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    start_sec: float = 0.0
    end_sec: float
    model: str | None = None
    prompt_template: str | None = None


class VideoToTextResponse(BaseModel):
    message: str
    job_id: int


class VideoJobCallback(BaseModel):
    job_id: int
    status: str                         # SUCCESS | FAILED
    frames: list[FrameInfo] | None = None
    output_path: str | None = None
    error_message: str | None = None
  • Step 2: 编写失败测试

tests/test_video_service.py:

import numpy as np
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.services.video_service import _is_scene_change, extract_frames_background


def test_is_scene_change_different_frames():
    prev = np.zeros((100, 100), dtype=np.uint8)
    curr = np.full((100, 100), 200, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is True


def test_is_scene_change_similar_frames():
    prev = np.full((100, 100), 100, dtype=np.uint8)
    curr = np.full((100, 100), 101, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is False


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_success(mock_storage):
    import cv2
    import tempfile, os

    # 创建一个有效的真实测试视频5 帧10x10
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        tmp_path = f.name

    out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"mp4v"), 10, (10, 10))
    for _ in range(5):
        out.write(np.zeros((10, 10, 3), dtype=np.uint8))
    out.release()

    with open(tmp_path, "rb") as f:
        video_bytes = f.read()
    os.unlink(tmp_path)

    mock_storage.download_bytes.return_value = video_bytes
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=42,
            mode="interval",
            frame_interval=1,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

        mock_client.post.assert_called_once()
        call_kwargs = mock_client.post.call_args
        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
        assert payload["job_id"] == 42
        assert payload["status"] == "SUCCESS"


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_failure(mock_storage):
    mock_storage.download_bytes.side_effect = Exception("storage error")

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=99,
            mode="interval",
            frame_interval=30,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

        mock_client.post.assert_called_once()
        call_kwargs = mock_client.post.call_args
        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
        assert payload["status"] == "FAILED"
        assert payload["job_id"] == 99
  • Step 3: 运行,确认失败
conda run -n label pytest tests/test_video_service.py -v

Expected: ImportError

  • Step 4: 实现 app/services/video_service.py
import base64
import logging
import tempfile
import time
from pathlib import Path

import cv2
import httpx
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError
from app.models.video_models import FrameInfo, VideoJobCallback

logger = logging.getLogger(__name__)

DEFAULT_VIDEO_TO_TEXT_PROMPT = """请分析这段视频的帧序列,用中文详细描述:
1. 视频中出现的主要对象、设备、人物
2. 发生的主要动作、操作步骤
3. 场景的整体情况

请输出结构化的文字描述,适合作为知识图谱构建的文本素材。"""


def _is_scene_change(prev: np.ndarray, curr: np.ndarray, threshold: float = 30.0) -> bool:
    """通过帧差分均值判断是否发生场景切换。"""
    diff = cv2.absdiff(prev, curr)
    return float(diff.mean()) > threshold


def _extract_frames(
    video_path: str, mode: str, frame_interval: int
) -> list[tuple[int, float, bytes]]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    results = []
    prev_gray = None
    idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        time_sec = idx / fps
        if mode == "interval":
            if idx % frame_interval == 0:
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is None or _is_scene_change(prev_gray, gray):
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
            prev_gray = gray
        idx += 1

    cap.release()
    return results


def _sample_frames_as_base64(
    video_path: str, start_sec: float, end_sec: float, count: int
) -> list[str]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    start_frame = int(start_sec * fps)
    end_frame = int(end_sec * fps)
    total = max(1, end_frame - start_frame)
    step = max(1, total // count)
    results = []
    for i in range(count):
        frame_pos = start_frame + i * step
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
        ret, frame = cap.read()
        if ret:
            _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
            results.append(base64.b64encode(buf.tobytes()).decode())
    cap.release()
    return results


async def _send_callback(url: str, payload: VideoJobCallback) -> None:
    async with httpx.AsyncClient(timeout=10) as client:
        try:
            await client.post(url, json=payload.model_dump())
        except Exception as e:
            logger.warning(f"回调失败 url={url}: {e}")


async def extract_frames_background(
    file_path: str,
    source_id: int,
    job_id: int,
    mode: str,
    frame_interval: int,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames = _extract_frames(tmp_path, mode, frame_interval)
        frame_infos = []
        for i, (frame_idx, time_sec, frame_data) in enumerate(frames):
            frame_path = f"frames/{source_id}/{i}.jpg"
            await storage.upload_bytes(bucket, frame_path, frame_data, "image/jpeg")
            frame_infos.append(FrameInfo(
                frame_index=frame_idx,
                time_sec=round(time_sec, 3),
                frame_path=frame_path,
            ))
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="SUCCESS", frames=frame_infos
        ))
        logger.info(f"extract_frames job_id={job_id} frames={len(frame_infos)}")
    except Exception as e:
        logger.exception(f"extract_frames failed job_id={job_id}")
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
    finally:
        Path(tmp_path).unlink(missing_ok=True)


async def video_to_text_background(
    file_path: str,
    source_id: int,
    job_id: int,
    start_sec: float,
    end_sec: float,
    model: str,
    prompt_template: str,
    frame_sample_count: int,
    llm: LLMClient,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames_b64 = _sample_frames_as_base64(tmp_path, start_sec, end_sec, frame_sample_count)
        content: list = []
        for b64 in frames_b64:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})
        content.append({
            "type": "text",
            "text": f"以上是视频第{start_sec}秒至{end_sec}秒的均匀采样帧。\n{prompt_template}",
        })

        messages = [
            {"role": "system", "content": "你是专业的视频内容分析助手。"},
            {"role": "user", "content": content},
        ]

        try:
            description = await llm.chat_vision(messages, model)
        except Exception as e:
            raise LLMCallError(f"GLM-4V 调用失败: {e}") from e

        timestamp = int(time.time())
        output_path = f"video-text/{source_id}/{timestamp}.txt"
        await storage.upload_bytes(bucket, output_path, description.encode("utf-8"), "text/plain")

        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="SUCCESS", output_path=output_path
        ))
        logger.info(f"video_to_text job_id={job_id} output={output_path}")
    except Exception as e:
        logger.exception(f"video_to_text failed job_id={job_id}")
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
    finally:
        Path(tmp_path).unlink(missing_ok=True)
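`_sample_frames_as_base64` 的采样位置由整除步长决定:区间换算为帧号后,步长取 `total // count`(最小为 1再等距取 count 个位置。下面用一个独立的示意函数(`sample_positions` 为说明用的假设名)复现这一算术:

```python
def sample_positions(start_sec: float, end_sec: float, fps: float, count: int) -> list[int]:
    # 与实现一致:秒换算为帧号,等距步长采样 count 个位置
    start_frame = int(start_sec * fps)
    end_frame = int(end_sec * fps)
    total = max(1, end_frame - start_frame)
    step = max(1, total // count)
    return [start_frame + i * step for i in range(count)]
```

以 0~60 秒、25fps、采 8 帧为例:共 1500 帧,步长 187采样位置均落在区间内由于整除取整尾部可能留有未覆盖的残余帧。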
  • Step 5: 运行,确认通过
conda run -n label pytest tests/test_video_service.py -v

Expected: 4 passed

  • Step 6: Commit
git add app/models/video_models.py app/services/video_service.py tests/test_video_service.py
git commit -m "feat: video models and service with frame extraction and video-to-text"

Task 13: Video Router

Files:

  • Create: app/routers/video.py

  • Create: tests/test_video_router.py

  • Step 1: 编写失败测试

tests/test_video_router.py:

import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_extract_frames_returns_202(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10MB
    resp = client.post("/api/v1/video/extract-frames", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 42,
        "mode": "interval",
        "frame_interval": 30,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 42
    assert "后台处理中" in resp.json()["message"]


def test_video_to_text_returns_202(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10MB
    resp = client.post("/api/v1/video/to-text", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 43,
        "start_sec": 0,
        "end_sec": 60,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 43


def test_extract_frames_rejects_oversized_video(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=300 * 1024 * 1024)  # 300MB > 200MB limit
    resp = client.post("/api/v1/video/extract-frames", json={
        "file_path": "video/202404/big.mp4",
        "source_id": 10,
        "job_id": 99,
        "mode": "interval",
        "frame_interval": 30,
    })
    assert resp.status_code == 400
    assert "大小" in resp.json()["detail"]
  • Step 2: 实现 app/routers/video.py
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.video_models import (
    ExtractFramesRequest,
    ExtractFramesResponse,
    VideoToTextRequest,
    VideoToTextResponse,
)
from app.services import video_service

router = APIRouter(tags=["Video"])


async def _check_video_size(storage: StorageClient, bucket: str, file_path: str, max_mb: int) -> None:
    """在触发后台任务前校验视频文件大小,超限时抛出 HTTP 400。"""
    size_bytes = await storage.get_object_size(bucket, file_path)
    if size_bytes > max_mb * 1024 * 1024:
        raise HTTPException(
            status_code=400,
            detail=f"视频文件大小超出限制(最大 {max_mb}MB当前 {size_bytes // 1024 // 1024}MB",
        )


@router.post("/video/extract-frames", response_model=ExtractFramesResponse, status_code=202)
async def extract_frames(
    req: ExtractFramesRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    await _check_video_size(storage, bucket, req.file_path, cfg["video"]["max_file_size_mb"])
    background_tasks.add_task(
        video_service.extract_frames_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        mode=req.mode,
        frame_interval=req.frame_interval,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=bucket,
    )
    return ExtractFramesResponse(message="任务已接受,后台处理中", job_id=req.job_id)


@router.post("/video/to-text", response_model=VideoToTextResponse, status_code=202)
async def video_to_text(
    req: VideoToTextRequest,
    background_tasks: BackgroundTasks,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    await _check_video_size(storage, bucket, req.file_path, cfg["video"]["max_file_size_mb"])
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or video_service.DEFAULT_VIDEO_TO_TEXT_PROMPT
    background_tasks.add_task(
        video_service.video_to_text_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        start_sec=req.start_sec,
        end_sec=req.end_sec,
        model=model,
        prompt_template=prompt,
        frame_sample_count=cfg["video"]["frame_sample_count"],
        llm=llm,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=bucket,
    )
    return VideoToTextResponse(message="任务已接受,后台处理中", job_id=req.job_id)
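`_check_video_size` 的判断只是一个字节换算MB 上限乘以 1024*1024 后与对象字节数比较。独立示意(`exceeds_limit` 为说明用的假设名,阈值相等时不视为超限,与上面 `>` 的判断一致):

```python
def exceeds_limit(size_bytes: int, max_mb: int) -> bool:
    # 超过 max_mb 兆字节则应拒绝并返回 HTTP 400
    return size_bytes > max_mb * 1024 * 1024
```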
  • Step 3: 在 app/main.py 注册路由
from app.routers import text, image, video
app.include_router(video.router, prefix="/api/v1")
  • Step 4: 运行测试
conda run -n label pytest tests/test_video_router.py -v

Expected: 3 passed

  • Step 5: Commit
git add app/routers/video.py tests/test_video_router.py app/main.py
git commit -m "feat: video router POST /api/v1/video/extract-frames and /to-text"

Task 14: QA Models + Service

Files:

  • Create: app/models/qa_models.py

  • Create: app/services/qa_service.py

  • Create: tests/test_qa_service.py

  • Step 1: 实现 app/models/qa_models.py

from pydantic import BaseModel


class TextTripleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str


class TextQARequest(BaseModel):
    items: list[TextTripleForQA]
    model: str | None = None
    prompt_template: str | None = None


class QAPair(BaseModel):
    question: str
    answer: str


class TextQAResponse(BaseModel):
    pairs: list[QAPair]


class ImageQuadrupleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    cropped_image_path: str


class ImageQARequest(BaseModel):
    items: list[ImageQuadrupleForQA]
    model: str | None = None
    prompt_template: str | None = None


class ImageQAPair(BaseModel):
    question: str
    answer: str
    image_path: str


class ImageQAResponse(BaseModel):
    pairs: list[ImageQAPair]
  • Step 2: 编写失败测试

tests/test_qa_service.py:

import pytest
from app.services.qa_service import gen_text_qa, gen_image_qa, _parse_qa_pairs
from app.models.qa_models import TextTripleForQA, ImageQuadrupleForQA
from app.core.exceptions import LLMResponseParseError, LLMCallError

QA_JSON = '[{"question":"变压器额定电压是多少?","answer":"110kV"}]'


def test_parse_qa_pairs_plain_json():
    result = _parse_qa_pairs(QA_JSON)
    assert len(result) == 1
    assert result[0].question == "变压器额定电压是多少?"


def test_parse_qa_pairs_markdown_wrapped():
    result = _parse_qa_pairs(f"```json\n{QA_JSON}\n```")
    assert len(result) == 1


def test_parse_qa_pairs_invalid_raises():
    with pytest.raises(LLMResponseParseError):
        _parse_qa_pairs("这不是JSON")


@pytest.mark.asyncio
async def test_gen_text_qa(mock_llm):
    mock_llm.chat.return_value = QA_JSON
    items = [TextTripleForQA(subject="变压器", predicate="额定电压", object="110kV", source_snippet="额定电压为110kV")]

    result = await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)
    assert len(result) == 1
    assert result[0].answer == "110kV"


@pytest.mark.asyncio
async def test_gen_text_qa_llm_error(mock_llm):
    mock_llm.chat.side_effect = Exception("network error")
    items = [TextTripleForQA(subject="A", predicate="B", object="C", source_snippet="ABC")]

    with pytest.raises(LLMCallError):
        await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)


@pytest.mark.asyncio
async def test_gen_image_qa(mock_llm, mock_storage):
    mock_llm.chat_vision.return_value = '[{"question":"图中是什么?","answer":"电缆接头"}]'
    mock_storage.download_bytes.return_value = b"fake-image-bytes"
    items = [ImageQuadrupleForQA(
        subject="电缆接头", predicate="位于", object="配电箱", qualifier="", cropped_image_path="crops/1/0.jpg"
    )]

    result = await gen_image_qa(items=items, model="glm-4v-flash", prompt_template="", llm=mock_llm, storage=mock_storage)
    assert len(result) == 1
    assert result[0].image_path == "crops/1/0.jpg"
    # 验证使用 download_bytesbase64而非 presigned URL
    mock_storage.download_bytes.assert_called_once_with("source-data", "crops/1/0.jpg")
    # 验证发送给 GLM-4V 的消息包含 base64 data URL
    call_messages = mock_llm.chat_vision.call_args[0][0]
    image_content = call_messages[1]["content"][0]
    assert image_content["image_url"]["url"].startswith("data:image/jpeg;base64,")
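测试中 `_parse_qa_pairs` 依赖 `parse_json_response` 容忍 markdown 代码围栏包裹的 JSON。其核心思路可以这样示意`strip_md_fence` 为说明用的假设函数名,真实实现在 app/core/json_utils.py此处为简化版

```python
import re

FENCE = "`" * 3  # markdown 代码围栏标记

def strip_md_fence(raw: str) -> str:
    # 去掉可能的 ```json ... ``` 包裹,无包裹时原样返回
    m = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, raw, re.S)
    return (m.group(1) if m else raw).strip()
```

剥离后再交给 `json.loads`,纯 JSON 与围栏包裹两种返回格式即可统一处理。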
  • Step 3: 运行,确认失败
conda run -n label pytest tests/test_qa_service.py -v

Expected: ImportError

  • Step 4: 实现 app/services/qa_service.py
import base64
import json
import logging

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError
from app.core.json_utils import parse_json_response
from app.models.qa_models import (
    ImageQAPair,
    ImageQuadrupleForQA,
    QAPair,
    TextTripleForQA,
)

logger = logging.getLogger(__name__)

DEFAULT_TEXT_QA_PROMPT = """基于以下知识三元组和原文证据片段,生成高质量问答对。
要求:
1. 问题自然、具体,不能过于宽泛
2. 答案基于原文片段,语言流畅
3. 每个三元组生成1-2个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

三元组数据:
"""

DEFAULT_IMAGE_QA_PROMPT = """基于图片内容和以下四元组信息,生成高质量图文问答对。
要求:
1. 问题需要结合图片才能回答
2. 答案基于图片中的实际内容
3. 每个四元组生成1个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

四元组信息:
"""


def _parse_qa_pairs(raw: str) -> list[QAPair]:
    items_raw = parse_json_response(raw)
    result = []
    for item in items_raw:
        try:
            result.append(QAPair(question=item["question"], answer=item["answer"]))
        except (KeyError, TypeError) as e:  # 缺字段或条目非 dict 时跳过,不中断整批解析
            logger.warning(f"跳过不完整问答对: {item}, error: {e}")
    return result


async def gen_text_qa(
    items: list[TextTripleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
) -> list[QAPair]:
    triples_text = json.dumps([i.model_dump() for i in items], ensure_ascii=False, indent=2)
    messages = [
        {"role": "system", "content": "你是专业的知识问答对生成助手。"},
        {"role": "user", "content": (prompt_template or DEFAULT_TEXT_QA_PROMPT) + triples_text},
    ]
    try:
        raw = await llm.chat(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM 调用失败: {e}") from e
    logger.info(f"gen_text_qa model={model} items={len(items)}")
    return _parse_qa_pairs(raw)


async def gen_image_qa(
    items: list[ImageQuadrupleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    bucket: str = "source-data",
) -> list[ImageQAPair]:
    result = []
    prompt = prompt_template or DEFAULT_IMAGE_QA_PROMPT
    for item in items:
        # 下载裁剪图并 base64 编码RustFS 为内网部署presigned URL 无法被云端 GLM-4V 访问
        try:
            image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)
        except Exception as e:
            raise StorageDownloadError(f"下载裁剪图失败 {item.cropped_image_path}: {e}") from e
        b64 = base64.b64encode(image_bytes).decode()
        quad_text = json.dumps(
            {k: v for k, v in item.model_dump().items() if k != "cropped_image_path"},
            ensure_ascii=False,
        )
        messages = [
            {"role": "system", "content": "你是专业的视觉问答对生成助手。"},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
                {"type": "text", "text": prompt + quad_text},
            ]},
        ]
        try:
            raw = await llm.chat_vision(messages, model)
        except Exception as e:
            raise LLMCallError(f"GLM-4V 调用失败: {e}") from e
        for pair in _parse_qa_pairs(raw):
            result.append(ImageQAPair(question=pair.question, answer=pair.answer, image_path=item.cropped_image_path))
    logger.info(f"gen_image_qa model={model} items={len(items)} pairs={len(result)}")
    return result
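上面 gen_image_qa 的核心是"下载字节 → base64 → data URL"的内联编码路径(内网 RustFS 的 presigned URL 云端 GLM-4V 访问不到)。以下是一个不依赖正文代码的极简 sketch仅演示该编码的构造与可逆性

```python
import base64

# 模拟 storage.download_bytes 返回的裁剪图字节(此处为占位数据)
image_bytes = b"fake-image-bytes"
b64 = base64.b64encode(image_bytes).decode()
data_url = f"data:image/jpeg;base64,{b64}"

# 与测试断言一致data URL 前缀固定
assert data_url.startswith("data:image/jpeg;base64,")
# 编码可逆:云端模型最终拿到的就是原始字节
assert base64.b64decode(b64) == image_bytes
```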
  • Step 5: 运行,确认通过
conda run -n label pytest tests/test_qa_service.py -v

Expected: 6 passed
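_parse_qa_pairs 的容错行为可以用一个简化版独立小例演证(不依赖 pydantic 与正文代码parse_pairs 为此处假设的简化函数):缺字段的条目KeyError和非 dict 条目TypeError都只记 warning 并跳过,不中断整批解析。

```python
import logging

logger = logging.getLogger(__name__)


def parse_pairs(items: list) -> list[dict]:
    """简化版 sketch逐条解析问答对脏数据跳过而非抛出。"""
    result = []
    for item in items:
        try:
            result.append({"question": item["question"], "answer": item["answer"]})
        except (KeyError, TypeError) as e:
            logger.warning(f"跳过不完整问答对: {item!r}, error: {e}")
    return result


pairs = parse_pairs([
    {"question": "Q1", "answer": "A1"},
    {"question": "只有问题"},   # 缺 answerKeyError被跳过
    "not-a-dict",               # 非 dictTypeError被跳过
])
assert pairs == [{"question": "Q1", "answer": "A1"}]
```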

  • Step 6: Commit
git add app/models/qa_models.py app/services/qa_service.py tests/test_qa_service.py
git commit -m "feat: QA models and service for text and image QA generation"

Task 15: QA Router

Files:

  • Create: app/routers/qa.py

  • Create: tests/test_qa_router.py

  • Step 1: 编写失败测试

tests/test_qa_router.py:

import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_gen_text_qa_success(client, mock_llm):
    mock_llm.chat = AsyncMock(return_value='[{"question":"额定电压?","answer":"110kV"}]')
    resp = client.post("/api/v1/qa/gen-text", json={
        "items": [{"subject": "变压器", "predicate": "额定电压", "object": "110kV", "source_snippet": "额定电压为110kV"}],
    })
    assert resp.status_code == 200
    assert resp.json()["pairs"][0]["question"] == "额定电压?"


def test_gen_image_qa_success(client, mock_llm, mock_storage):
    mock_llm.chat_vision = AsyncMock(return_value='[{"question":"图中是什么?","answer":"接头"}]')
    mock_storage.download_bytes.return_value = b"fake-image-bytes"  # 内网 RustFS走 base64不再用 presigned URL
    resp = client.post("/api/v1/qa/gen-image", json={
        "items": [{"subject": "A", "predicate": "B", "object": "C", "qualifier": "", "cropped_image_path": "crops/1/0.jpg"}],
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["pairs"][0]["image_path"] == "crops/1/0.jpg"
  • Step 2: 实现 app/routers/qa.py
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.qa_models import ImageQARequest, ImageQAResponse, TextQARequest, TextQAResponse
from app.services import qa_service

router = APIRouter(tags=["QA"])


@router.post("/qa/gen-text", response_model=TextQAResponse)
async def gen_text_qa(
    req: TextQARequest,
    llm: LLMClient = Depends(get_llm_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_text_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_text"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_TEXT_QA_PROMPT,
        llm=llm,
    )
    return TextQAResponse(pairs=pairs)


@router.post("/qa/gen-image", response_model=ImageQAResponse)
async def gen_image_qa(
    req: ImageQARequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_image_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_vision"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_IMAGE_QA_PROMPT,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageQAResponse(pairs=pairs)
  • Step 3: 在 app/main.py 注册路由
from app.routers import text, image, video, qa
app.include_router(qa.router, prefix="/api/v1")
  • Step 4: 运行测试
conda run -n label pytest tests/test_qa_router.py -v

Expected: 2 passed

  • Step 5: Commit
git add app/routers/qa.py tests/test_qa_router.py app/main.py
git commit -m "feat: QA router POST /api/v1/qa/gen-text and /gen-image"

Task 16: Finetune Models + Service + Router

Files:

  • Create: app/models/finetune_models.py

  • Create: app/services/finetune_service.py

  • Create: app/routers/finetune.py

  • Create: tests/test_finetune_service.py

  • Create: tests/test_finetune_router.py

  • Step 1: 实现 app/models/finetune_models.py

from pydantic import BaseModel


class FinetuneHyperparams(BaseModel):
    learning_rate: float = 1e-4
    epochs: int = 3


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: FinetuneHyperparams = FinetuneHyperparams()


class FinetuneStartResponse(BaseModel):
    job_id: str


class FinetuneStatusResponse(BaseModel):
    job_id: str
    status: str             # RUNNING | SUCCESS | FAILED
    progress: int | None = None
    error_message: str | None = None
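FinetuneHyperparams 的两个字段在 service 层要映射为 SDK 入参名learning_rate → learning_rate_multiplierepochs → n_epochs见后文 start_finetune。用一个独立 sketch 固定这一映射to_sdk_hyperparams 为此处假设的演示函数):

```python
def to_sdk_hyperparams(learning_rate: float, epochs: int) -> dict:
    # 字段名映射与正文 start_finetune 中传给 SDK 的 hyperparameters 一致
    return {
        "learning_rate_multiplier": learning_rate,
        "n_epochs": epochs,
    }


assert to_sdk_hyperparams(1e-4, 3) == {
    "learning_rate_multiplier": 0.0001,
    "n_epochs": 3,
}
```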
  • Step 2: 编写失败测试

tests/test_finetune_service.py:

import pytest
from unittest.mock import MagicMock
from app.services.finetune_service import start_finetune, get_finetune_status
from app.models.finetune_models import FinetuneHyperparams


@pytest.mark.asyncio
async def test_start_finetune():
    mock_job = MagicMock()
    mock_job.id = "glm-ft-abc123"
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.create.return_value = mock_job

    result = await start_finetune(
        jsonl_url="https://example.com/export.jsonl",
        base_model="glm-4-flash",
        hyperparams=FinetuneHyperparams(learning_rate=1e-4, epochs=3),
        client=mock_zhipuai,
    )
    assert result == "glm-ft-abc123"
    mock_zhipuai.fine_tuning.jobs.create.assert_called_once()


@pytest.mark.asyncio
async def test_get_finetune_status_running():
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 50
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "RUNNING"
    assert result.progress == 50
    assert result.job_id == "glm-ft-abc123"


@pytest.mark.asyncio
async def test_get_finetune_status_success():
    mock_job = MagicMock()
    mock_job.status = "succeeded"
    mock_job.progress = 100
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "SUCCESS"
  • Step 3: 运行,确认失败
conda run -n label pytest tests/test_finetune_service.py -v

Expected: ImportError

  • Step 4: 实现 app/services/finetune_service.py
import logging

from app.models.finetune_models import FinetuneHyperparams, FinetuneStatusResponse

logger = logging.getLogger(__name__)

_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


async def start_finetune(
    jsonl_url: str,
    base_model: str,
    hyperparams: FinetuneHyperparams,
    client,  # ZhipuAI SDK client instance
) -> str:
    job = client.fine_tuning.jobs.create(
        training_file=jsonl_url,
        model=base_model,
        hyperparameters={
            "learning_rate_multiplier": hyperparams.learning_rate,
            "n_epochs": hyperparams.epochs,
        },
    )
    logger.info(f"finetune_start job_id={job.id} model={base_model}")
    return job.id


async def get_finetune_status(job_id: str, client) -> FinetuneStatusResponse:
    job = client.fine_tuning.jobs.retrieve(job_id)
    status = _STATUS_MAP.get(job.status, "RUNNING")
    return FinetuneStatusResponse(
        job_id=job_id,
        status=status,
        progress=getattr(job, "progress", None),
        error_message=str(job.error) if getattr(job, "error", None) else None,  # SDK 的 error 可能不是 str转为字符串兜底
    )
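_STATUS_MAP 的回退策略值得单独看一眼:任何未知的中间态(如排队、文件校验等,具体状态名以 ZhipuAI SDK 为准)统一映射为 RUNNING由 Java 侧继续轮询。独立小例:

```python
_STATUS_MAP = {"running": "RUNNING", "succeeded": "SUCCESS", "failed": "FAILED"}


def map_status(sdk_status: str) -> str:
    # dict.get 的默认值即回退策略:未知状态一律视为仍在运行
    return _STATUS_MAP.get(sdk_status, "RUNNING")


assert map_status("succeeded") == "SUCCESS"
assert map_status("failed") == "FAILED"
assert map_status("validating_files") == "RUNNING"  # 假设的中间态示例
```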
  • Step 5: 运行,确认通过
conda run -n label pytest tests/test_finetune_service.py -v

Expected: 3 passed

  • Step 6: 实现 app/routers/finetune.py
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.dependencies import get_llm_client
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)
from app.services import finetune_service

router = APIRouter(tags=["Finetune"])


def _get_zhipuai(llm: LLMClient = Depends(get_llm_client)) -> ZhipuAIClient:
    if not isinstance(llm, ZhipuAIClient):
        raise RuntimeError("微调功能仅支持 ZhipuAI 后端")
    return llm


@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    job_id = await finetune_service.start_finetune(
        jsonl_url=req.jsonl_url,
        base_model=req.base_model,
        hyperparams=req.hyperparams,
        client=llm._client,
    )
    return FinetuneStartResponse(job_id=job_id)


@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_finetune_status(
    job_id: str,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    return await finetune_service.get_finetune_status(job_id, llm._client)
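_get_zhipuai 这一依赖守卫的要点是:微调走的是 ZhipuAI 专有 API不属于 LLMClient 抽象,换后端时该路由应显式失败。用占位类模拟这一守卫LLMClient/ZhipuAIClient 名称与正文一致OtherClient 为假设的其他后端):

```python
class LLMClient: ...


class ZhipuAIClient(LLMClient): ...


class OtherClient(LLMClient): ...


def require_zhipuai(llm: LLMClient) -> ZhipuAIClient:
    # 与路由中 _get_zhipuai 的逻辑一致:非 ZhipuAI 后端直接报错
    if not isinstance(llm, ZhipuAIClient):
        raise RuntimeError("微调功能仅支持 ZhipuAI 后端")
    return llm


assert isinstance(require_zhipuai(ZhipuAIClient()), ZhipuAIClient)
try:
    require_zhipuai(OtherClient())
    raised = False
except RuntimeError:
    raised = True
assert raised
```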
  • Step 7: 编写路由测试

tests/test_finetune_router.py:

import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from app.main import app
from app.core.dependencies import set_clients
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.base import StorageClient


@pytest.fixture
def client(mock_storage):
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        llm = ZhipuAIClient(api_key="test-key")
        llm._mock_sdk = mock_sdk
        set_clients(llm, mock_storage)
        yield TestClient(app), mock_sdk


def test_start_finetune(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.id = "glm-ft-xyz"
    mock_sdk.fine_tuning.jobs.create.return_value = mock_job

    resp = test_client.post("/api/v1/finetune/start", json={
        "jsonl_url": "https://example.com/export.jsonl",
        "base_model": "glm-4-flash",
        "hyperparams": {"learning_rate": 1e-4, "epochs": 3},
    })
    assert resp.status_code == 200
    assert resp.json()["job_id"] == "glm-ft-xyz"


def test_get_finetune_status(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 30
    mock_job.error = None
    mock_sdk.fine_tuning.jobs.retrieve.return_value = mock_job

    resp = test_client.get("/api/v1/finetune/status/glm-ft-xyz")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "RUNNING"
    assert data["progress"] == 30
  • Step 8: 在 app/main.py 注册路由(最终状态)
from app.routers import text, image, video, qa, finetune

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")
  • Step 9: 运行全部测试
conda run -n label pytest tests/ -v

Expected: 所有测试通过,无失败

  • Step 10: Commit
git add app/models/finetune_models.py app/services/finetune_service.py app/routers/finetune.py tests/test_finetune_service.py tests/test_finetune_router.py app/main.py
git commit -m "feat: finetune models, service, and router - complete all endpoints"

Task 17: 部署文件

Files:

  • Create: Dockerfile

  • Create: docker-compose.yml

  • Step 1: 创建 Dockerfile

FROM python:3.12-slim

WORKDIR /app

# OpenCV 系统依赖curl 供 compose healthcheck 使用python:3.12-slim 不自带 curl
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app/ ./app/
COPY config.yaml .
COPY .env .

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
  • Step 2: 创建 docker-compose.yml
version: "3.9"

services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    depends_on:
      - rustfs
    networks:
      - label-net
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s

  rustfs:
    image: minio/minio:latest
    command: server /data --console-address ":9001"
    ports:
      - "9000:9000"
      - "9001:9001"
    environment:
      MINIO_ROOT_USER: minioadmin
      MINIO_ROOT_PASSWORD: minioadmin
    volumes:
      - rustfs-data:/data
    networks:
      - label-net

volumes:
  rustfs-data:

networks:
  label-net:
    driver: bridge
  • Step 3: 验证 Docker 构建
docker build -t label-ai-service:dev .

Expected: 镜像构建成功,无错误

  • Step 4: 运行全量测试,最终确认
conda run -n label pytest tests/ -v --tb=short

Expected: 所有测试通过

  • Step 5: Commit
git add Dockerfile docker-compose.yml
git commit -m "feat: Dockerfile and docker-compose for containerized deployment"

自审检查结果

Spec coverage:

  • 文本三元组提取TXT/PDF/DOCX— Task 8-9
  • 图像四元组提取 + bbox 裁剪 — Task 10-11
  • 视频帧提取interval/keyframe— Task 12-13
  • 视频转文本BackgroundTask— Task 12-13
  • 文本问答对生成 — Task 14-15
  • 图像问答对生成 — Task 14-15
  • 微调任务提交与状态查询 — Task 16
  • LLMClient / StorageClient ABC 适配层 — Task 4-5
  • config.yaml + .env 分层配置 — Task 2
  • 结构化日志 + 请求日志 — Task 3
  • 全局异常处理 — Task 3
  • Swagger 文档FastAPI 自动生成) — Task 6
  • Dockerfile + docker-compose — Task 17
  • pytest 测试覆盖全部 service 和 router — 各 Task

类型一致性: TripleItem.source_offset 在 Task 7 定义Task 8 使用;VideoJobCallback 在 Task 12 定义Task 12 service 使用 — 一致。

占位符: 无 TBD / TODO所有步骤均含完整代码。