From f9f84937db26a7e6644b2a91d06f986ec8d5419c Mon Sep 17 00:00:00 2001 From: wh Date: Fri, 10 Apr 2026 14:15:38 +0800 Subject: [PATCH] docs: add AI service implementation plan 22 tasks covering all 8 endpoints with TDD. Each task includes test code, implementation code, run commands, and commit step. --- .../plans/2026-04-10-ai-service-impl.md | 2884 +++++++++++++++++ 1 file changed, 2884 insertions(+) create mode 100644 docs/superpowers/plans/2026-04-10-ai-service-impl.md diff --git a/docs/superpowers/plans/2026-04-10-ai-service-impl.md b/docs/superpowers/plans/2026-04-10-ai-service-impl.md new file mode 100644 index 0000000..7db598b --- /dev/null +++ b/docs/superpowers/plans/2026-04-10-ai-service-impl.md @@ -0,0 +1,2884 @@ +# AI Service Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** 实现 label_ai_service,一个 Python FastAPI 服务,为知识图谱标注平台提供文本三元组提取、图像四元组提取、视频处理、问答对生成和 GLM 微调管理能力。 + +**Architecture:** 分层架构:routers(HTTP 入口)→ services(业务逻辑)→ clients(外部适配层)。LLMClient 和 StorageClient 均为 ABC,当前分别实现 ZhipuAIClient 和 RustFSClient,通过 FastAPI Depends 注入,services 层不感知具体实现。视频任务用 FastAPI BackgroundTasks 异步执行,完成后回调 Java 后端。 + +**Tech Stack:** Python 3.12(conda `label` 环境),FastAPI,ZhipuAI SDK,boto3(S3),OpenCV,pdfplumber,python-docx,httpx,pytest + +--- + +## Task 1: 项目脚手架 + +**Files:** +- Create: `app/__init__.py` +- Create: `app/core/__init__.py` +- Create: `app/clients/__init__.py` +- Create: `app/clients/llm/__init__.py` +- Create: `app/clients/storage/__init__.py` +- Create: `app/services/__init__.py` +- Create: `app/routers/__init__.py` +- Create: `app/models/__init__.py` +- Create: `tests/__init__.py` +- Create: `tests/conftest.py` +- Create: `config.yaml` +- Create: `.env` +- Create: `requirements.txt` + +- [ ] **Step 1: 创建包目录结构** + +```bash +mkdir -p app/core 
app/clients/llm app/clients/storage app/services app/routers app/models tests +touch app/__init__.py app/core/__init__.py +touch app/clients/__init__.py app/clients/llm/__init__.py app/clients/storage/__init__.py +touch app/services/__init__.py app/routers/__init__.py app/models/__init__.py +touch tests/__init__.py +``` + +- [ ] **Step 2: 创建 `config.yaml`** + +```yaml +server: + port: 8000 + log_level: INFO + +storage: + buckets: + source_data: "source-data" + finetune_export: "finetune-export" + +backend: {} + +video: + frame_sample_count: 8 + +models: + default_text: "glm-4-flash" + default_vision: "glm-4v-flash" +``` + +- [ ] **Step 3: 创建 `.env`** + +```ini +ZHIPUAI_API_KEY=your-zhipuai-api-key +STORAGE_ACCESS_KEY=minioadmin +STORAGE_SECRET_KEY=minioadmin +STORAGE_ENDPOINT=http://rustfs:9000 +BACKEND_CALLBACK_URL=http://backend:8080/internal/video-job/callback +``` + +- [ ] **Step 4: 创建 `requirements.txt`** + +``` +fastapi>=0.111 +uvicorn[standard]>=0.29 +pydantic>=2.7 +python-dotenv>=1.0 +pyyaml>=6.0 +zhipuai>=2.1 +boto3>=1.34 +pdfplumber>=0.11 +python-docx>=1.1 +opencv-python-headless>=4.9 +numpy>=1.26 +httpx>=0.27 +pytest>=8.0 +pytest-asyncio>=0.23 +``` + +- [ ] **Step 5: 创建 `tests/conftest.py`** + +```python +import pytest +from unittest.mock import AsyncMock, MagicMock +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient + + +@pytest.fixture +def mock_llm(): + client = MagicMock(spec=LLMClient) + client.chat = AsyncMock() + client.chat_vision = AsyncMock() + return client + + +@pytest.fixture +def mock_storage(): + client = MagicMock(spec=StorageClient) + client.download_bytes = AsyncMock() + client.upload_bytes = AsyncMock() + client.get_presigned_url = MagicMock(return_value="https://example.com/presigned/crop.jpg") + return client +``` + +- [ ] **Step 6: 安装依赖** + +```bash +conda run -n label pip install -r requirements.txt +``` + +Expected: 所有包安装成功,无错误 + +- [ ] **Step 7: Commit** + +```bash +git add app/ tests/ 
config.yaml requirements.txt
+git commit -m "feat: project scaffold - directory structure and config files"
+```
+
+---
+
+## Task 2: Core Config 模块
+
+**Files:**
+- Create: `app/core/config.py`
+- Create: `tests/test_config.py`
+
+- [ ] **Step 1: 编写失败测试**
+
+`tests/test_config.py`:
+
+```python
+import pytest
+from unittest.mock import patch, mock_open
+from app.core.config import get_config
+
+MOCK_YAML = """
+server:
+  port: 8000
+  log_level: INFO
+storage:
+  buckets:
+    source_data: "source-data"
+    finetune_export: "finetune-export"
+backend: {}
+video:
+  frame_sample_count: 8
+models:
+  default_text: "glm-4-flash"
+  default_vision: "glm-4v-flash"
+"""
+
+
+def _fresh_config(monkeypatch, extra_env: dict = None):
+    """每次测试前清除 lru_cache,设置环境变量。"""
+    get_config.cache_clear()
+    base_env = {
+        "ZHIPUAI_API_KEY": "test-key",
+        "STORAGE_ACCESS_KEY": "test-access",
+        "STORAGE_SECRET_KEY": "test-secret",
+        "STORAGE_ENDPOINT": "http://localhost:9000",
+        "BACKEND_CALLBACK_URL": "http://localhost:8080/callback",
+    }
+    if extra_env:
+        base_env.update(extra_env)
+    for k, v in base_env.items():
+        monkeypatch.setenv(k, v)
+
+
+def test_env_overrides_yaml(monkeypatch):
+    _fresh_config(monkeypatch)
+    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
+        with patch("app.core.config.load_dotenv"):
+            cfg = get_config()
+    assert cfg["zhipuai"]["api_key"] == "test-key"
+    assert cfg["storage"]["access_key"] == "test-access"
+    assert cfg["storage"]["endpoint"] == "http://localhost:9000"
+    assert cfg["backend"]["callback_url"] == "http://localhost:8080/callback"
+    get_config.cache_clear()
+
+
+def test_yaml_values_preserved(monkeypatch):
+    _fresh_config(monkeypatch)
+    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
+        with patch("app.core.config.load_dotenv"):
+            cfg = get_config()
+    assert cfg["models"]["default_text"] == "glm-4-flash"
+    assert cfg["video"]["frame_sample_count"] == 8
+    assert cfg["storage"]["buckets"]["source_data"] == "source-data"
+    
get_config.cache_clear() + + +def test_missing_api_key_raises(monkeypatch): + get_config.cache_clear() + monkeypatch.delenv("ZHIPUAI_API_KEY", raising=False) + monkeypatch.setenv("STORAGE_ACCESS_KEY", "a") + monkeypatch.setenv("STORAGE_SECRET_KEY", "b") + with patch("builtins.open", mock_open(read_data=MOCK_YAML)): + with patch("app.core.config.load_dotenv"): + with pytest.raises(RuntimeError, match="ZHIPUAI_API_KEY"): + get_config() + get_config.cache_clear() + + +def test_missing_storage_key_raises(monkeypatch): + get_config.cache_clear() + monkeypatch.setenv("ZHIPUAI_API_KEY", "key") + monkeypatch.delenv("STORAGE_ACCESS_KEY", raising=False) + monkeypatch.setenv("STORAGE_SECRET_KEY", "b") + with patch("builtins.open", mock_open(read_data=MOCK_YAML)): + with patch("app.core.config.load_dotenv"): + with pytest.raises(RuntimeError, match="STORAGE_ACCESS_KEY"): + get_config() + get_config.cache_clear() +``` + +- [ ] **Step 2: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_config.py -v +``` + +Expected: `ImportError: cannot import name 'get_config'` + +- [ ] **Step 3: 实现 `app/core/config.py`** + +```python +import os +import yaml +from functools import lru_cache +from pathlib import Path +from dotenv import load_dotenv + +_ROOT = Path(__file__).parent.parent.parent + +_ENV_OVERRIDES = { + "ZHIPUAI_API_KEY": ["zhipuai", "api_key"], + "STORAGE_ACCESS_KEY": ["storage", "access_key"], + "STORAGE_SECRET_KEY": ["storage", "secret_key"], + "STORAGE_ENDPOINT": ["storage", "endpoint"], + "BACKEND_CALLBACK_URL": ["backend", "callback_url"], + "LOG_LEVEL": ["server", "log_level"], +} + + +def _set_nested(d: dict, keys: list[str], value: str) -> None: + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +@lru_cache(maxsize=1) +def get_config() -> dict: + load_dotenv(_ROOT / ".env") + with open(_ROOT / "config.yaml", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + for env_key, yaml_path in _ENV_OVERRIDES.items(): + val = os.environ.get(env_key) 
+ if val: + _set_nested(cfg, yaml_path, val) + _validate(cfg) + return cfg + + +def _validate(cfg: dict) -> None: + checks = [ + (["zhipuai", "api_key"], "ZHIPUAI_API_KEY"), + (["storage", "access_key"], "STORAGE_ACCESS_KEY"), + (["storage", "secret_key"], "STORAGE_SECRET_KEY"), + ] + for path, name in checks: + val = cfg + for k in path: + val = (val or {}).get(k, "") + if not val: + raise RuntimeError(f"缺少必要配置项:{name}") +``` + +- [ ] **Step 4: 运行,确认通过** + +```bash +conda run -n label pytest tests/test_config.py -v +``` + +Expected: `4 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/core/config.py tests/test_config.py +git commit -m "feat: core config module with YAML + env layered loading" +``` + +--- + +## Task 3: Core Logging、Exceptions、JSON Utils + +**Files:** +- Create: `app/core/logging.py` +- Create: `app/core/exceptions.py` +- Create: `app/core/json_utils.py` + +- [ ] **Step 1: 实现 `app/core/logging.py`** + +```python +import json +import logging +import time +from typing import Callable + +from fastapi import Request, Response + + +class _JSONFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + entry: dict = { + "time": self.formatTime(record), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + if record.exc_info: + entry["exception"] = self.formatException(record.exc_info) + return json.dumps(entry, ensure_ascii=False) + + +def setup_logging(log_level: str = "INFO") -> None: + handler = logging.StreamHandler() + handler.setFormatter(_JSONFormatter()) + root = logging.getLogger() + root.handlers.clear() + root.addHandler(handler) + root.setLevel(getattr(logging, log_level.upper(), logging.INFO)) + + +async def request_logging_middleware(request: Request, call_next: Callable) -> Response: + start = time.monotonic() + response = await call_next(request) + duration_ms = round((time.monotonic() - start) * 1000, 2) + logging.getLogger("api").info( + f"method={request.method} 
path={request.url.path} " + f"status={response.status_code} duration_ms={duration_ms}" + ) + return response +``` + +- [ ] **Step 2: 实现 `app/core/exceptions.py`** + +```python +import logging +from fastapi import Request +from fastapi.responses import JSONResponse + + +class UnsupportedFileTypeError(Exception): + def __init__(self, ext: str): + super().__init__(f"不支持的文件类型:{ext}") + + +class StorageDownloadError(Exception): + pass + + +class LLMResponseParseError(Exception): + pass + + +class LLMCallError(Exception): + pass + + +async def unsupported_file_type_handler(request: Request, exc: UnsupportedFileTypeError): + return JSONResponse( + status_code=400, + content={"code": "UNSUPPORTED_FILE_TYPE", "message": str(exc)}, + ) + + +async def storage_download_handler(request: Request, exc: StorageDownloadError): + return JSONResponse( + status_code=502, + content={"code": "STORAGE_ERROR", "message": str(exc)}, + ) + + +async def llm_parse_handler(request: Request, exc: LLMResponseParseError): + return JSONResponse( + status_code=502, + content={"code": "LLM_PARSE_ERROR", "message": str(exc)}, + ) + + +async def llm_call_handler(request: Request, exc: LLMCallError): + return JSONResponse( + status_code=503, + content={"code": "LLM_CALL_ERROR", "message": str(exc)}, + ) + + +async def generic_error_handler(request: Request, exc: Exception): + logging.getLogger("error").exception("未捕获异常") + return JSONResponse( + status_code=500, + content={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, + ) +``` + +- [ ] **Step 3: 实现 `app/core/json_utils.py`** + +```python +import json +from app.core.exceptions import LLMResponseParseError + + +def parse_json_response(raw: str) -> list | dict: + """从 GLM 响应中解析 JSON,兼容 markdown 代码块包裹格式。""" + content = raw.strip() + if "```json" in content: + content = content.split("```json")[1].split("```")[0] + elif "```" in content: + content = content.split("```")[1].split("```")[0] + content = content.strip() + try: + return json.loads(content) + 
except json.JSONDecodeError as e: + raise LLMResponseParseError( + f"GLM 返回内容无法解析为 JSON: {raw[:200]}" + ) from e +``` + +- [ ] **Step 4: Commit** + +```bash +git add app/core/logging.py app/core/exceptions.py app/core/json_utils.py +git commit -m "feat: core logging, exceptions, json utils" +``` + +--- + +## Task 4: LLM 适配层 + +**Files:** +- Create: `app/clients/llm/base.py` +- Create: `app/clients/llm/zhipuai_client.py` +- Create: `tests/test_llm_client.py` + +- [ ] **Step 1: 编写失败测试** + +`tests/test_llm_client.py`: + +```python +import asyncio +import pytest +from unittest.mock import MagicMock, patch +from app.clients.llm.zhipuai_client import ZhipuAIClient + + +@pytest.fixture +def zhipuai_client(): + with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI: + mock_sdk = MagicMock() + MockZhipuAI.return_value = mock_sdk + client = ZhipuAIClient(api_key="test-key") + client._mock_sdk = mock_sdk + yield client + + +def test_chat_returns_content(zhipuai_client): + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "三元组提取结果" + zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp + + result = asyncio.run( + zhipuai_client.chat( + messages=[{"role": "user", "content": "提取三元组"}], + model="glm-4-flash", + ) + ) + assert result == "三元组提取结果" + zhipuai_client._mock_sdk.chat.completions.create.assert_called_once() + + +def test_chat_vision_calls_same_endpoint(zhipuai_client): + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "图像分析结果" + zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp + + result = asyncio.run( + zhipuai_client.chat_vision( + messages=[{"role": "user", "content": [{"type": "text", "text": "分析"}]}], + model="glm-4v-flash", + ) + ) + assert result == "图像分析结果" +``` + +- [ ] **Step 2: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_llm_client.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 3: 实现 `app/clients/llm/base.py`** + +```python +from abc import ABC, 
abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        """纯文本对话,返回模型输出文本。"""

    @abstractmethod
    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        """多模态对话(图文混合输入),返回模型输出文本。"""
+```
+
+- [ ] **Step 4: 实现 `app/clients/llm/zhipuai_client.py`**
+
+```python
+import asyncio
+from zhipuai import ZhipuAI
+from app.clients.llm.base import LLMClient
+
+
+class ZhipuAIClient(LLMClient):
+    def __init__(self, api_key: str):
+        self._client = ZhipuAI(api_key=api_key)
+
+    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
+        # get_running_loop:协程内应使用该 API,get_event_loop 在协程中已弃用
+        loop = asyncio.get_running_loop()
+        resp = await loop.run_in_executor(
+            None,
+            lambda: self._client.chat.completions.create(
+                model=model, messages=messages, **kwargs
+            ),
+        )
+        return resp.choices[0].message.content
+
+    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
+        # GLM-4V 与文本接口相同,通过 image_url type 区分图文消息
+        return await self.chat(messages, model, **kwargs)
+```
+
+- [ ] **Step 5: 运行,确认通过**
+
+```bash
+conda run -n label pytest tests/test_llm_client.py -v
+```
+
+Expected: `2 passed`
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add app/clients/llm/ tests/test_llm_client.py
+git commit -m "feat: LLMClient ABC and ZhipuAI implementation"
+```
+
+---
+
+## Task 5: Storage 适配层
+
+**Files:**
+- Create: `app/clients/storage/base.py`
+- Create: `app/clients/storage/rustfs_client.py`
+- Create: `tests/test_storage_client.py`
+
+- [ ] **Step 1: 编写失败测试**
+
+`tests/test_storage_client.py`:
+
+```python
+import asyncio
+import pytest
+from unittest.mock import MagicMock, patch
+from app.clients.storage.rustfs_client import RustFSClient
+
+
+@pytest.fixture
+def rustfs_client():
+    with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
+        mock_s3 = MagicMock()
+        mock_boto3.client.return_value = mock_s3
+        client = RustFSClient(
+            endpoint="http://localhost:9000",
+            access_key="minioadmin",
+ secret_key="minioadmin", + ) + client._mock_s3 = mock_s3 + yield client + + +def test_download_bytes(rustfs_client): + mock_body = MagicMock() + mock_body.read.return_value = b"file content" + rustfs_client._mock_s3.get_object.return_value = {"Body": mock_body} + + result = asyncio.run( + rustfs_client.download_bytes("source-data", "text/202404/1.txt") + ) + assert result == b"file content" + rustfs_client._mock_s3.get_object.assert_called_once_with( + Bucket="source-data", Key="text/202404/1.txt" + ) + + +def test_upload_bytes(rustfs_client): + asyncio.run( + rustfs_client.upload_bytes("source-data", "crops/1/0.jpg", b"img", "image/jpeg") + ) + rustfs_client._mock_s3.put_object.assert_called_once_with( + Bucket="source-data", Key="crops/1/0.jpg", Body=b"img", ContentType="image/jpeg" + ) + + +def test_get_presigned_url(rustfs_client): + rustfs_client._mock_s3.generate_presigned_url.return_value = "https://example.com/signed" + url = rustfs_client.get_presigned_url("source-data", "crops/1/0.jpg", expires=3600) + assert url == "https://example.com/signed" + rustfs_client._mock_s3.generate_presigned_url.assert_called_once_with( + "get_object", + Params={"Bucket": "source-data", "Key": "crops/1/0.jpg"}, + ExpiresIn=3600, + ) +``` + +- [ ] **Step 2: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_storage_client.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 3: 实现 `app/clients/storage/base.py`** + +```python +from abc import ABC, abstractmethod + + +class StorageClient(ABC): + @abstractmethod + async def download_bytes(self, bucket: str, path: str) -> bytes: + """从对象存储下载文件,返回字节内容。""" + + @abstractmethod + async def upload_bytes( + self, + bucket: str, + path: str, + data: bytes, + content_type: str = "application/octet-stream", + ) -> None: + """上传字节内容到对象存储。""" + + @abstractmethod + def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str: + """生成预签名访问 URL。""" +``` + +- [ ] **Step 4: 实现 `app/clients/storage/rustfs_client.py`** 
+
+```python
+import asyncio
+import boto3
+from app.clients.storage.base import StorageClient
+
+
+class RustFSClient(StorageClient):
+    def __init__(self, endpoint: str, access_key: str, secret_key: str):
+        self._s3 = boto3.client(
+            "s3",
+            endpoint_url=endpoint,
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_key,
+        )
+
+    async def download_bytes(self, bucket: str, path: str) -> bytes:
+        # get_running_loop:协程内应使用该 API,get_event_loop 在协程中已弃用
+        loop = asyncio.get_running_loop()
+        resp = await loop.run_in_executor(
+            None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
+        )
+        return resp["Body"].read()
+
+    async def upload_bytes(
+        self,
+        bucket: str,
+        path: str,
+        data: bytes,
+        content_type: str = "application/octet-stream",
+    ) -> None:
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(
+            None,
+            lambda: self._s3.put_object(
+                Bucket=bucket, Key=path, Body=data, ContentType=content_type
+            ),
+        )
+
+    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
+        return self._s3.generate_presigned_url(
+            "get_object",
+            Params={"Bucket": bucket, "Key": path},
+            ExpiresIn=expires,
+        )
+```
+
+- [ ] **Step 5: 运行,确认通过**
+
+```bash
+conda run -n label pytest tests/test_storage_client.py -v
+```
+
+Expected: `3 passed`
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add app/clients/storage/ tests/test_storage_client.py
+git commit -m "feat: StorageClient ABC and RustFS S3 implementation"
+```
+
+---
+
+## Task 6: 依赖注入 + FastAPI 应用入口
+
+**Files:**
+- Create: `app/core/dependencies.py`
+- Create: `app/main.py`
+
+- [ ] **Step 1: 实现 `app/core/dependencies.py`**
+
+```python
+from app.clients.llm.base import LLMClient
+from app.clients.storage.base import StorageClient
+
+_llm_client: LLMClient | None = None
+_storage_client: StorageClient | None = None
+
+
+def set_clients(llm: LLMClient, storage: StorageClient) -> None:
+    global _llm_client, _storage_client
+    _llm_client, _storage_client = llm, storage
+
+
+def get_llm_client() -> LLMClient:
+    return _llm_client
+
+
+def 
get_storage_client() -> StorageClient: + return _storage_client +``` + +- [ ] **Step 2: 实现 `app/main.py`** + +注意:routers 在后续任务中创建,先注释掉 include_router,待各路由实现后逐步取消注释。 + +```python +import logging +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from app.core.config import get_config +from app.core.dependencies import set_clients +from app.core.exceptions import ( + LLMCallError, + LLMResponseParseError, + StorageDownloadError, + UnsupportedFileTypeError, + generic_error_handler, + llm_call_handler, + llm_parse_handler, + storage_download_handler, + unsupported_file_type_handler, +) +from app.core.logging import request_logging_middleware, setup_logging +from app.clients.llm.zhipuai_client import ZhipuAIClient +from app.clients.storage.rustfs_client import RustFSClient + + +@asynccontextmanager +async def lifespan(app: FastAPI): + cfg = get_config() + setup_logging(cfg["server"]["log_level"]) + set_clients( + llm=ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"]), + storage=RustFSClient( + endpoint=cfg["storage"]["endpoint"], + access_key=cfg["storage"]["access_key"], + secret_key=cfg["storage"]["secret_key"], + ), + ) + logging.getLogger("startup").info("AI 服务启动完成") + yield + logging.getLogger("startup").info("AI 服务关闭") + + +app = FastAPI(title="Label AI Service", version="1.0.0", lifespan=lifespan) + +app.middleware("http")(request_logging_middleware) + +app.add_exception_handler(UnsupportedFileTypeError, unsupported_file_type_handler) +app.add_exception_handler(StorageDownloadError, storage_download_handler) +app.add_exception_handler(LLMResponseParseError, llm_parse_handler) +app.add_exception_handler(LLMCallError, llm_call_handler) +app.add_exception_handler(Exception, generic_error_handler) + +# Routers registered after each task: +# from app.routers import text, image, video, qa, finetune +# app.include_router(text.router, prefix="/api/v1") +# app.include_router(image.router, prefix="/api/v1") +# app.include_router(video.router, 
prefix="/api/v1") +# app.include_router(qa.router, prefix="/api/v1") +# app.include_router(finetune.router, prefix="/api/v1") +``` + +- [ ] **Step 3: Commit** + +```bash +git add app/core/dependencies.py app/main.py +git commit -m "feat: DI dependencies and FastAPI app entry with lifespan" +``` + +--- + +## Task 7: Text Pydantic Models + +**Files:** +- Create: `app/models/text_models.py` + +- [ ] **Step 1: 实现 `app/models/text_models.py`** + +```python +from pydantic import BaseModel + + +class SourceOffset(BaseModel): + start: int + end: int + + +class TripleItem(BaseModel): + subject: str + predicate: str + object: str + source_snippet: str + source_offset: SourceOffset + + +class TextExtractRequest(BaseModel): + file_path: str + file_name: str + model: str | None = None + prompt_template: str | None = None + + +class TextExtractResponse(BaseModel): + items: list[TripleItem] +``` + +- [ ] **Step 2: 快速验证 schema** + +```bash +conda run -n label python -c " +from app.models.text_models import TextExtractRequest, TextExtractResponse, TripleItem, SourceOffset +req = TextExtractRequest(file_path='text/1.txt', file_name='1.txt') +item = TripleItem(subject='A', predicate='B', object='C', source_snippet='ABC', source_offset=SourceOffset(start=0, end=3)) +resp = TextExtractResponse(items=[item]) +print(resp.model_dump()) +" +``` + +Expected: 打印出完整字典,无报错 + +- [ ] **Step 3: Commit** + +```bash +git add app/models/text_models.py +git commit -m "feat: text Pydantic models" +``` + +--- + +## Task 8: Text Service + +**Files:** +- Create: `app/services/text_service.py` +- Create: `tests/test_text_service.py` + +- [ ] **Step 1: 编写失败测试** + +`tests/test_text_service.py`: + +```python +import pytest +from app.services.text_service import extract_triples, _extract_text_from_bytes +from app.core.exceptions import UnsupportedFileTypeError, LLMResponseParseError, StorageDownloadError + +TRIPLE_JSON = 
'[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压为110kV","source_offset":{"start":0,"end":10}}]' + + +@pytest.mark.asyncio +async def test_extract_triples_txt(mock_llm, mock_storage): + mock_storage.download_bytes.return_value = b"变压器额定电压为110kV" + mock_llm.chat.return_value = TRIPLE_JSON + + result = await extract_triples( + file_path="text/1.txt", + file_name="test.txt", + model="glm-4-flash", + prompt_template="提取三元组:", + llm=mock_llm, + storage=mock_storage, + ) + assert len(result) == 1 + assert result[0].subject == "变压器" + assert result[0].predicate == "额定电压" + assert result[0].object == "110kV" + assert result[0].source_offset.start == 0 + + +@pytest.mark.asyncio +async def test_extract_triples_markdown_wrapped_json(mock_llm, mock_storage): + mock_storage.download_bytes.return_value = b"some text" + mock_llm.chat.return_value = f"```json\n{TRIPLE_JSON}\n```" + + result = await extract_triples( + file_path="text/1.txt", + file_name="test.txt", + model="glm-4-flash", + prompt_template="", + llm=mock_llm, + storage=mock_storage, + ) + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_extract_triples_storage_error(mock_llm, mock_storage): + mock_storage.download_bytes.side_effect = Exception("connection refused") + + with pytest.raises(StorageDownloadError): + await extract_triples( + file_path="text/1.txt", + file_name="test.txt", + model="glm-4-flash", + prompt_template="", + llm=mock_llm, + storage=mock_storage, + ) + + +@pytest.mark.asyncio +async def test_extract_triples_llm_parse_error(mock_llm, mock_storage): + mock_storage.download_bytes.return_value = b"some text" + mock_llm.chat.return_value = "这不是JSON" + + with pytest.raises(LLMResponseParseError): + await extract_triples( + file_path="text/1.txt", + file_name="test.txt", + model="glm-4-flash", + prompt_template="", + llm=mock_llm, + storage=mock_storage, + ) + + +def test_unsupported_file_type_raises(): + with pytest.raises(UnsupportedFileTypeError): + 
_extract_text_from_bytes(b"content", "doc.xlsx") + + +def test_parse_txt_bytes(): + result = _extract_text_from_bytes("你好世界".encode("utf-8"), "file.txt") + assert result == "你好世界" +``` + +- [ ] **Step 2: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_text_service.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 3: 实现 `app/services/text_service.py`** + +```python +import logging +from pathlib import Path + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError, UnsupportedFileTypeError +from app.core.json_utils import parse_json_response +from app.models.text_models import SourceOffset, TripleItem + +logger = logging.getLogger(__name__) + +DEFAULT_PROMPT = """请从以下文本中提取知识三元组。 +对每个三元组提供: +- subject:主语实体 +- predicate:谓语关系 +- object:宾语实体 +- source_snippet:原文中的证据片段(直接引用原文) +- source_offset:证据片段字符偏移 {"start": N, "end": M} + +以 JSON 数组格式返回,例如: +[{"subject":"...","predicate":"...","object":"...","source_snippet":"...","source_offset":{"start":0,"end":50}}] + +文本内容: +""" + + +def _parse_txt(data: bytes) -> str: + return data.decode("utf-8") + + +def _parse_pdf(data: bytes) -> str: + import io + import pdfplumber + with pdfplumber.open(io.BytesIO(data)) as pdf: + return "\n".join(page.extract_text() or "" for page in pdf.pages) + + +def _parse_docx(data: bytes) -> str: + import io + import docx + doc = docx.Document(io.BytesIO(data)) + return "\n".join(p.text for p in doc.paragraphs if p.text.strip()) + + +_PARSERS = { + ".txt": _parse_txt, + ".pdf": _parse_pdf, + ".docx": _parse_docx, +} + + +def _extract_text_from_bytes(data: bytes, filename: str) -> str: + ext = Path(filename).suffix.lower() + parser = _PARSERS.get(ext) + if parser is None: + raise UnsupportedFileTypeError(ext) + return parser(data) + + +async def extract_triples( + file_path: str, + file_name: str, + model: str, + prompt_template: str, + llm: LLMClient, + 
storage: StorageClient, + bucket: str = "source-data", +) -> list[TripleItem]: + try: + data = await storage.download_bytes(bucket, file_path) + except Exception as e: + raise StorageDownloadError(f"下载文件失败 {file_path}: {e}") from e + + text = _extract_text_from_bytes(data, file_name) + prompt = prompt_template or DEFAULT_PROMPT + + messages = [ + {"role": "system", "content": "你是专业的知识图谱构建助手,擅长从文本中提取结构化知识三元组。"}, + {"role": "user", "content": prompt + text}, + ] + + try: + raw = await llm.chat(messages, model) + except Exception as e: + raise LLMCallError(f"GLM 调用失败: {e}") from e + + logger.info(f"text_extract file={file_path} model={model}") + + items_raw = parse_json_response(raw) + + result = [] + for item in items_raw: + try: + offset = item.get("source_offset", {}) + result.append(TripleItem( + subject=item["subject"], + predicate=item["predicate"], + object=item["object"], + source_snippet=item.get("source_snippet", ""), + source_offset=SourceOffset( + start=offset.get("start", 0), + end=offset.get("end", 0), + ), + )) + except (KeyError, TypeError) as e: + logger.warning(f"跳过不完整三元组: {item}, error: {e}") + + return result +``` + +- [ ] **Step 4: 运行,确认通过** + +```bash +conda run -n label pytest tests/test_text_service.py -v +``` + +Expected: `6 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/services/text_service.py tests/test_text_service.py +git commit -m "feat: text service with txt/pdf/docx parsing and triple extraction" +``` + +--- + +## Task 9: Text Router + +**Files:** +- Create: `app/routers/text.py` +- Create: `tests/test_text_router.py` + +- [ ] **Step 1: 编写失败测试** + +`tests/test_text_router.py`: + +```python +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch +from app.main import app +from app.core.dependencies import set_clients +from app.models.text_models import TripleItem, SourceOffset + + +@pytest.fixture +def client(mock_llm, mock_storage): + set_clients(mock_llm, mock_storage) + 
return TestClient(app) + + +def test_text_extract_success(client, mock_llm, mock_storage): + mock_storage.download_bytes = AsyncMock(return_value=b"变压器额定电压110kV") + mock_llm.chat = AsyncMock(return_value='[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压110kV","source_offset":{"start":3,"end":10}}]') + + resp = client.post("/api/v1/text/extract", json={ + "file_path": "text/202404/1.txt", + "file_name": "规范.txt", + }) + assert resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["subject"] == "变压器" + + +def test_text_extract_unsupported_file(client, mock_llm, mock_storage): + mock_storage.download_bytes = AsyncMock(return_value=b"content") + resp = client.post("/api/v1/text/extract", json={ + "file_path": "text/202404/1.xlsx", + "file_name": "file.xlsx", + }) + assert resp.status_code == 400 + assert resp.json()["code"] == "UNSUPPORTED_FILE_TYPE" +``` + +- [ ] **Step 2: 实现 `app/routers/text.py`** + +```python +from fastapi import APIRouter, Depends + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.config import get_config +from app.core.dependencies import get_llm_client, get_storage_client +from app.models.text_models import TextExtractRequest, TextExtractResponse +from app.services import text_service + +router = APIRouter(tags=["Text"]) + + +@router.post("/text/extract", response_model=TextExtractResponse) +async def extract_text( + req: TextExtractRequest, + llm: LLMClient = Depends(get_llm_client), + storage: StorageClient = Depends(get_storage_client), +): + cfg = get_config() + model = req.model or cfg["models"]["default_text"] + prompt = req.prompt_template or text_service.DEFAULT_PROMPT + + items = await text_service.extract_triples( + file_path=req.file_path, + file_name=req.file_name, + model=model, + prompt_template=prompt, + llm=llm, + storage=storage, + bucket=cfg["storage"]["buckets"]["source_data"], + ) + 
return TextExtractResponse(items=items) +``` + +- [ ] **Step 3: 在 `app/main.py` 注册路由** + +取消注释以下两行: + +```python +from app.routers import text +app.include_router(text.router, prefix="/api/v1") +``` + +- [ ] **Step 4: 运行测试** + +```bash +conda run -n label pytest tests/test_text_router.py -v +``` + +Expected: `2 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/routers/text.py tests/test_text_router.py app/main.py +git commit -m "feat: text router POST /api/v1/text/extract" +``` + +--- + +## Task 10: Image Models + Service + +**Files:** +- Create: `app/models/image_models.py` +- Create: `app/services/image_service.py` +- Create: `tests/test_image_service.py` + +- [ ] **Step 1: 实现 `app/models/image_models.py`** + +```python +from pydantic import BaseModel + + +class BBox(BaseModel): + x: int + y: int + w: int + h: int + + +class QuadrupleItem(BaseModel): + subject: str + predicate: str + object: str + qualifier: str + bbox: BBox + cropped_image_path: str + + +class ImageExtractRequest(BaseModel): + file_path: str + task_id: int + model: str | None = None + prompt_template: str | None = None + + +class ImageExtractResponse(BaseModel): + items: list[QuadrupleItem] +``` + +- [ ] **Step 2: 编写失败测试** + +`tests/test_image_service.py`: + +```python +import pytest +import numpy as np +import cv2 +from app.services.image_service import extract_quadruples, _crop_image +from app.models.image_models import BBox +from app.core.exceptions import LLMResponseParseError, StorageDownloadError + +QUAD_JSON = '[{"subject":"电缆接头","predicate":"位于","object":"配电箱左侧","qualifier":"2024年","bbox":{"x":10,"y":20,"w":50,"h":40}}]' + + +def _make_test_image_bytes(width=200, height=200) -> bytes: + img = np.zeros((height, width, 3), dtype=np.uint8) + img[:] = (100, 150, 200) + _, buf = cv2.imencode(".jpg", img) + return buf.tobytes() + + +def test_crop_image(): + img_bytes = _make_test_image_bytes(200, 200) + bbox = BBox(x=10, y=20, w=50, h=40) + result = _crop_image(img_bytes, bbox) + 
assert isinstance(result, bytes) + arr = np.frombuffer(result, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + assert img.shape[0] == 40 # height + assert img.shape[1] == 50 # width + + +@pytest.mark.asyncio +async def test_extract_quadruples_success(mock_llm, mock_storage): + mock_storage.download_bytes.return_value = _make_test_image_bytes() + mock_llm.chat_vision.return_value = QUAD_JSON + mock_storage.upload_bytes.return_value = None + + result = await extract_quadruples( + file_path="image/202404/1.jpg", + task_id=789, + model="glm-4v-flash", + prompt_template="提取四元组", + llm=mock_llm, + storage=mock_storage, + ) + assert len(result) == 1 + assert result[0].subject == "电缆接头" + assert result[0].cropped_image_path == "crops/789/0.jpg" + mock_storage.upload_bytes.assert_called_once() + + +@pytest.mark.asyncio +async def test_extract_quadruples_storage_error(mock_llm, mock_storage): + mock_storage.download_bytes.side_effect = Exception("timeout") + with pytest.raises(StorageDownloadError): + await extract_quadruples( + file_path="image/1.jpg", + task_id=1, + model="glm-4v-flash", + prompt_template="", + llm=mock_llm, + storage=mock_storage, + ) + + +@pytest.mark.asyncio +async def test_extract_quadruples_parse_error(mock_llm, mock_storage): + mock_storage.download_bytes.return_value = _make_test_image_bytes() + mock_llm.chat_vision.return_value = "不是JSON" + with pytest.raises(LLMResponseParseError): + await extract_quadruples( + file_path="image/1.jpg", + task_id=1, + model="glm-4v-flash", + prompt_template="", + llm=mock_llm, + storage=mock_storage, + ) +``` + +- [ ] **Step 3: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_image_service.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 4: 实现 `app/services/image_service.py`** + +```python +import base64 +import logging +from pathlib import Path + +import cv2 +import numpy as np + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from 
app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError +from app.core.json_utils import parse_json_response +from app.models.image_models import BBox, QuadrupleItem + +logger = logging.getLogger(__name__) + +DEFAULT_PROMPT = """请分析这张图片,提取知识四元组。 +对每个四元组提供: +- subject:主体实体 +- predicate:关系/属性 +- object:客体实体 +- qualifier:修饰信息(时间、条件、场景,无则填空字符串) +- bbox:边界框 {"x": N, "y": N, "w": N, "h": N}(像素坐标,相对原图) + +以 JSON 数组格式返回: +[{"subject":"...","predicate":"...","object":"...","qualifier":"...","bbox":{"x":0,"y":0,"w":100,"h":100}}] +""" + + +def _crop_image(image_bytes: bytes, bbox: BBox) -> bytes: + arr = np.frombuffer(image_bytes, dtype=np.uint8) + img = cv2.imdecode(arr, cv2.IMREAD_COLOR) + h, w = img.shape[:2] + x = max(0, bbox.x) + y = max(0, bbox.y) + x2 = min(w, bbox.x + bbox.w) + y2 = min(h, bbox.y + bbox.h) + cropped = img[y:y2, x:x2] + _, buf = cv2.imencode(".jpg", cropped, [cv2.IMWRITE_JPEG_QUALITY, 90]) + return buf.tobytes() + + +async def extract_quadruples( + file_path: str, + task_id: int, + model: str, + prompt_template: str, + llm: LLMClient, + storage: StorageClient, + source_bucket: str = "source-data", +) -> list[QuadrupleItem]: + try: + data = await storage.download_bytes(source_bucket, file_path) + except Exception as e: + raise StorageDownloadError(f"下载图片失败 {file_path}: {e}") from e + + ext = Path(file_path).suffix.lstrip(".") or "jpeg" + b64 = base64.b64encode(data).decode() + + messages = [ + {"role": "system", "content": "你是专业的视觉分析助手,擅长从图像中提取结构化知识四元组。"}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}}, + {"type": "text", "text": prompt_template or DEFAULT_PROMPT}, + ]}, + ] + + try: + raw = await llm.chat_vision(messages, model) + except Exception as e: + raise LLMCallError(f"GLM-4V 调用失败: {e}") from e + + logger.info(f"image_extract file={file_path} model={model}") + items_raw = parse_json_response(raw) + + result = [] + for i, item in enumerate(items_raw): + 
try:
+            bbox = BBox(**item["bbox"])
+            cropped = _crop_image(data, bbox)
+            crop_path = f"crops/{task_id}/{i}.jpg"
+            await storage.upload_bytes(source_bucket, crop_path, cropped, "image/jpeg")
+            result.append(QuadrupleItem(
+                subject=item["subject"],
+                predicate=item["predicate"],
+                object=item["object"],
+                qualifier=item.get("qualifier", ""),
+                bbox=bbox,
+                cropped_image_path=crop_path,
+            ))
+        except Exception as e:
+            # broad on purpose: skip any single malformed quadruple without failing the batch
+            logger.warning(f"跳过不完整四元组 index={i}: {e}")
+
+    return result
+```
+
+- [ ] **Step 5: 运行,确认通过**
+
+```bash
+conda run -n label pytest tests/test_image_service.py -v
+```
+
+Expected: `4 passed`
+
+- [ ] **Step 6: Commit**
+
+```bash
+git add app/models/image_models.py app/services/image_service.py tests/test_image_service.py
+git commit -m "feat: image models, service with bbox crop and quadruple extraction"
+```
+
+---
+
+## Task 11: Image Router
+
+**Files:**
+- Create: `app/routers/image.py`
+- Create: `tests/test_image_router.py`
+
+- [ ] **Step 1: 编写失败测试**
+
+`tests/test_image_router.py`:
+
+```python
+import numpy as np
+import cv2
+import pytest
+from fastapi.testclient import TestClient
+from unittest.mock import AsyncMock
+from app.main import app
+from app.core.dependencies import set_clients
+
+
+def _make_image_bytes() -> bytes:
+    img = np.zeros((100, 100, 3), dtype=np.uint8)
+    _, buf = cv2.imencode(".jpg", img)
+    return buf.tobytes()
+
+
+@pytest.fixture
+def client(mock_llm, mock_storage):
+    set_clients(mock_llm, mock_storage)
+    return TestClient(app)
+
+
+def test_image_extract_success(client, mock_llm, mock_storage):
+    mock_storage.download_bytes = AsyncMock(return_value=_make_image_bytes())
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+    mock_llm.chat_vision = AsyncMock(return_value='[{"subject":"A","predicate":"B","object":"C","qualifier":"","bbox":{"x":0,"y":0,"w":10,"h":10}}]')
+
+    resp = client.post("/api/v1/image/extract", json={
+        "file_path": "image/202404/1.jpg",
+        "task_id": 42,
+    })
+    assert 
resp.status_code == 200 + data = resp.json() + assert len(data["items"]) == 1 + assert data["items"][0]["cropped_image_path"] == "crops/42/0.jpg" +``` + +- [ ] **Step 2: 实现 `app/routers/image.py`** + +```python +from fastapi import APIRouter, Depends + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.config import get_config +from app.core.dependencies import get_llm_client, get_storage_client +from app.models.image_models import ImageExtractRequest, ImageExtractResponse +from app.services import image_service + +router = APIRouter(tags=["Image"]) + + +@router.post("/image/extract", response_model=ImageExtractResponse) +async def extract_image( + req: ImageExtractRequest, + llm: LLMClient = Depends(get_llm_client), + storage: StorageClient = Depends(get_storage_client), +): + cfg = get_config() + model = req.model or cfg["models"]["default_vision"] + prompt = req.prompt_template or image_service.DEFAULT_PROMPT + + items = await image_service.extract_quadruples( + file_path=req.file_path, + task_id=req.task_id, + model=model, + prompt_template=prompt, + llm=llm, + storage=storage, + source_bucket=cfg["storage"]["buckets"]["source_data"], + ) + return ImageExtractResponse(items=items) +``` + +- [ ] **Step 3: 在 `app/main.py` 注册路由** + +```python +from app.routers import text, image +app.include_router(image.router, prefix="/api/v1") +``` + +- [ ] **Step 4: 运行测试** + +```bash +conda run -n label pytest tests/test_image_router.py -v +``` + +Expected: `1 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/routers/image.py tests/test_image_router.py app/main.py +git commit -m "feat: image router POST /api/v1/image/extract" +``` + +--- + +## Task 12: Video Models + Service + +**Files:** +- Create: `app/models/video_models.py` +- Create: `app/services/video_service.py` +- Create: `tests/test_video_service.py` + +- [ ] **Step 1: 实现 `app/models/video_models.py`** + +```python +from pydantic import BaseModel + 
+ +class ExtractFramesRequest(BaseModel): + file_path: str + source_id: int + job_id: int + mode: str = "interval" # interval | keyframe + frame_interval: int = 30 + + +class ExtractFramesResponse(BaseModel): + message: str + job_id: int + + +class FrameInfo(BaseModel): + frame_index: int + time_sec: float + frame_path: str + + +class VideoToTextRequest(BaseModel): + file_path: str + source_id: int + job_id: int + start_sec: float = 0.0 + end_sec: float + model: str | None = None + prompt_template: str | None = None + + +class VideoToTextResponse(BaseModel): + message: str + job_id: int + + +class VideoJobCallback(BaseModel): + job_id: int + status: str # SUCCESS | FAILED + frames: list[FrameInfo] | None = None + output_path: str | None = None + error_message: str | None = None +``` + +- [ ] **Step 2: 编写失败测试** + +`tests/test_video_service.py`: + +```python +import numpy as np +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.services.video_service import _is_scene_change, extract_frames_background + + +def test_is_scene_change_different_frames(): + prev = np.zeros((100, 100), dtype=np.uint8) + curr = np.full((100, 100), 200, dtype=np.uint8) + assert _is_scene_change(prev, curr, threshold=30.0) is True + + +def test_is_scene_change_similar_frames(): + prev = np.full((100, 100), 100, dtype=np.uint8) + curr = np.full((100, 100), 101, dtype=np.uint8) + assert _is_scene_change(prev, curr, threshold=30.0) is False + + +@pytest.mark.asyncio +async def test_extract_frames_background_calls_callback_on_success(mock_storage): + import cv2 + import tempfile, os + + # 创建一个有效的真实测试视频(5帧,10x10) + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + tmp_path = f.name + + out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"mp4v"), 10, (10, 10)) + for _ in range(5): + out.write(np.zeros((10, 10, 3), dtype=np.uint8)) + out.release() + + with open(tmp_path, "rb") as f: + video_bytes = f.read() + os.unlink(tmp_path) + + 
mock_storage.download_bytes.return_value = video_bytes
+    mock_storage.upload_bytes = AsyncMock(return_value=None)
+
+    with patch("app.services.video_service.httpx") as mock_httpx:
+        mock_client = AsyncMock()
+        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
+        mock_client.post = AsyncMock()
+
+        await extract_frames_background(
+            file_path="video/1.mp4",
+            source_id=10,
+            job_id=42,
+            mode="interval",
+            frame_interval=1,
+            storage=mock_storage,
+            callback_url="http://backend/callback",
+        )
+
+        mock_client.post.assert_called_once()
+        call_kwargs = mock_client.post.call_args
+        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
+        assert payload["job_id"] == 42
+        assert payload["status"] == "SUCCESS"
+
+
+@pytest.mark.asyncio
+async def test_extract_frames_background_calls_callback_on_failure(mock_storage):
+    mock_storage.download_bytes.side_effect = Exception("storage error")
+
+    with patch("app.services.video_service.httpx") as mock_httpx:
+        mock_client = AsyncMock()
+        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
+        mock_client.post = AsyncMock()
+
+        await extract_frames_background(
+            file_path="video/1.mp4",
+            source_id=10,
+            job_id=99,
+            mode="interval",
+            frame_interval=30,
+            storage=mock_storage,
+            callback_url="http://backend/callback",
+        )
+
+        mock_client.post.assert_called_once()
+        call_kwargs = mock_client.post.call_args
+        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
+        assert payload["status"] == "FAILED"
+        assert payload["job_id"] == 99
+```
+
+- [ ] **Step 3: 运行,确认失败**
+
+```bash
+conda run -n label pytest tests/test_video_service.py -v
+```
+
+Expected: `ImportError`
+
+- [ ] **Step 4: 实现 
`app/services/video_service.py`** + +```python +import base64 +import logging +import tempfile +import time +from pathlib import Path + +import cv2 +import httpx +import numpy as np + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.exceptions import LLMCallError +from app.models.video_models import FrameInfo, VideoJobCallback + +logger = logging.getLogger(__name__) + +DEFAULT_VIDEO_TO_TEXT_PROMPT = """请分析这段视频的帧序列,用中文详细描述: +1. 视频中出现的主要对象、设备、人物 +2. 发生的主要动作、操作步骤 +3. 场景的整体情况 + +请输出结构化的文字描述,适合作为知识图谱构建的文本素材。""" + + +def _is_scene_change(prev: np.ndarray, curr: np.ndarray, threshold: float = 30.0) -> bool: + """通过帧差分均值判断是否发生场景切换。""" + diff = cv2.absdiff(prev, curr) + return float(diff.mean()) > threshold + + +def _extract_frames( + video_path: str, mode: str, frame_interval: int +) -> list[tuple[int, float, bytes]]: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + results = [] + prev_gray = None + idx = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + time_sec = idx / fps + if mode == "interval": + if idx % frame_interval == 0: + _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90]) + results.append((idx, time_sec, buf.tobytes())) + else: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + if prev_gray is None or _is_scene_change(prev_gray, gray): + _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90]) + results.append((idx, time_sec, buf.tobytes())) + prev_gray = gray + idx += 1 + + cap.release() + return results + + +def _sample_frames_as_base64( + video_path: str, start_sec: float, end_sec: float, count: int +) -> list[str]: + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 + start_frame = int(start_sec * fps) + end_frame = int(end_sec * fps) + total = max(1, end_frame - start_frame) + step = max(1, total // count) + results = [] + for i in range(count): + frame_pos = start_frame + i * step + 
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos) + ret, frame = cap.read() + if ret: + _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) + results.append(base64.b64encode(buf.tobytes()).decode()) + cap.release() + return results + + +async def _send_callback(url: str, payload: VideoJobCallback) -> None: + async with httpx.AsyncClient(timeout=10) as client: + try: + await client.post(url, json=payload.model_dump()) + except Exception as e: + logger.warning(f"回调失败 url={url}: {e}") + + +async def extract_frames_background( + file_path: str, + source_id: int, + job_id: int, + mode: str, + frame_interval: int, + storage: StorageClient, + callback_url: str, + bucket: str = "source-data", +) -> None: + try: + data = await storage.download_bytes(bucket, file_path) + except Exception as e: + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="FAILED", error_message=str(e) + )) + return + + suffix = Path(file_path).suffix or ".mp4" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(data) + tmp_path = tmp.name + + try: + frames = _extract_frames(tmp_path, mode, frame_interval) + frame_infos = [] + for i, (frame_idx, time_sec, frame_data) in enumerate(frames): + frame_path = f"frames/{source_id}/{i}.jpg" + await storage.upload_bytes(bucket, frame_path, frame_data, "image/jpeg") + frame_infos.append(FrameInfo( + frame_index=frame_idx, + time_sec=round(time_sec, 3), + frame_path=frame_path, + )) + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="SUCCESS", frames=frame_infos + )) + logger.info(f"extract_frames job_id={job_id} frames={len(frame_infos)}") + except Exception as e: + logger.exception(f"extract_frames failed job_id={job_id}") + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="FAILED", error_message=str(e) + )) + finally: + Path(tmp_path).unlink(missing_ok=True) + + +async def video_to_text_background( + file_path: str, + source_id: int, + 
job_id: int, + start_sec: float, + end_sec: float, + model: str, + prompt_template: str, + frame_sample_count: int, + llm: LLMClient, + storage: StorageClient, + callback_url: str, + bucket: str = "source-data", +) -> None: + try: + data = await storage.download_bytes(bucket, file_path) + except Exception as e: + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="FAILED", error_message=str(e) + )) + return + + suffix = Path(file_path).suffix or ".mp4" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(data) + tmp_path = tmp.name + + try: + frames_b64 = _sample_frames_as_base64(tmp_path, start_sec, end_sec, frame_sample_count) + content: list = [] + for b64 in frames_b64: + content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}) + content.append({ + "type": "text", + "text": f"以上是视频第{start_sec}秒至{end_sec}秒的均匀采样帧。\n{prompt_template}", + }) + + messages = [ + {"role": "system", "content": "你是专业的视频内容分析助手。"}, + {"role": "user", "content": content}, + ] + + try: + description = await llm.chat_vision(messages, model) + except Exception as e: + raise LLMCallError(f"GLM-4V 调用失败: {e}") from e + + timestamp = int(time.time()) + output_path = f"video-text/{source_id}/{timestamp}.txt" + await storage.upload_bytes(bucket, output_path, description.encode("utf-8"), "text/plain") + + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="SUCCESS", output_path=output_path + )) + logger.info(f"video_to_text job_id={job_id} output={output_path}") + except Exception as e: + logger.exception(f"video_to_text failed job_id={job_id}") + await _send_callback(callback_url, VideoJobCallback( + job_id=job_id, status="FAILED", error_message=str(e) + )) + finally: + Path(tmp_path).unlink(missing_ok=True) +``` + +- [ ] **Step 5: 运行,确认通过** + +```bash +conda run -n label pytest tests/test_video_service.py -v +``` + +Expected: `4 passed` + +- [ ] **Step 6: Commit** + +```bash +git 
add app/models/video_models.py app/services/video_service.py tests/test_video_service.py +git commit -m "feat: video models and service with frame extraction and video-to-text" +``` + +--- + +## Task 13: Video Router + +**Files:** +- Create: `app/routers/video.py` +- Create: `tests/test_video_router.py` + +- [ ] **Step 1: 编写失败测试** + +`tests/test_video_router.py`: + +```python +import pytest +from fastapi.testclient import TestClient +from app.main import app +from app.core.dependencies import set_clients + + +@pytest.fixture +def client(mock_llm, mock_storage): + set_clients(mock_llm, mock_storage) + return TestClient(app) + + +def test_extract_frames_returns_202(client): + resp = client.post("/api/v1/video/extract-frames", json={ + "file_path": "video/202404/1.mp4", + "source_id": 10, + "job_id": 42, + "mode": "interval", + "frame_interval": 30, + }) + assert resp.status_code == 202 + assert resp.json()["job_id"] == 42 + assert "后台处理中" in resp.json()["message"] + + +def test_video_to_text_returns_202(client): + resp = client.post("/api/v1/video/to-text", json={ + "file_path": "video/202404/1.mp4", + "source_id": 10, + "job_id": 43, + "start_sec": 0, + "end_sec": 60, + }) + assert resp.status_code == 202 + assert resp.json()["job_id"] == 43 +``` + +- [ ] **Step 2: 实现 `app/routers/video.py`** + +```python +from fastapi import APIRouter, BackgroundTasks, Depends + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.config import get_config +from app.core.dependencies import get_llm_client, get_storage_client +from app.models.video_models import ( + ExtractFramesRequest, + ExtractFramesResponse, + VideoToTextRequest, + VideoToTextResponse, +) +from app.services import video_service + +router = APIRouter(tags=["Video"]) + + +@router.post("/video/extract-frames", response_model=ExtractFramesResponse, status_code=202) +async def extract_frames( + req: ExtractFramesRequest, + background_tasks: BackgroundTasks, + 
storage: StorageClient = Depends(get_storage_client), +): + cfg = get_config() + background_tasks.add_task( + video_service.extract_frames_background, + file_path=req.file_path, + source_id=req.source_id, + job_id=req.job_id, + mode=req.mode, + frame_interval=req.frame_interval, + storage=storage, + callback_url=cfg["backend"]["callback_url"], + bucket=cfg["storage"]["buckets"]["source_data"], + ) + return ExtractFramesResponse(message="任务已接受,后台处理中", job_id=req.job_id) + + +@router.post("/video/to-text", response_model=VideoToTextResponse, status_code=202) +async def video_to_text( + req: VideoToTextRequest, + background_tasks: BackgroundTasks, + llm: LLMClient = Depends(get_llm_client), + storage: StorageClient = Depends(get_storage_client), +): + cfg = get_config() + model = req.model or cfg["models"]["default_vision"] + prompt = req.prompt_template or video_service.DEFAULT_VIDEO_TO_TEXT_PROMPT + background_tasks.add_task( + video_service.video_to_text_background, + file_path=req.file_path, + source_id=req.source_id, + job_id=req.job_id, + start_sec=req.start_sec, + end_sec=req.end_sec, + model=model, + prompt_template=prompt, + frame_sample_count=cfg["video"]["frame_sample_count"], + llm=llm, + storage=storage, + callback_url=cfg["backend"]["callback_url"], + bucket=cfg["storage"]["buckets"]["source_data"], + ) + return VideoToTextResponse(message="任务已接受,后台处理中", job_id=req.job_id) +``` + +- [ ] **Step 3: 在 `app/main.py` 注册路由** + +```python +from app.routers import text, image, video +app.include_router(video.router, prefix="/api/v1") +``` + +- [ ] **Step 4: 运行测试** + +```bash +conda run -n label pytest tests/test_video_router.py -v +``` + +Expected: `2 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/routers/video.py tests/test_video_router.py app/main.py +git commit -m "feat: video router POST /api/v1/video/extract-frames and /to-text" +``` + +--- + +## Task 14: QA Models + Service + +**Files:** +- Create: `app/models/qa_models.py` +- Create: 
`app/services/qa_service.py` +- Create: `tests/test_qa_service.py` + +- [ ] **Step 1: 实现 `app/models/qa_models.py`** + +```python +from pydantic import BaseModel + + +class TextTripleForQA(BaseModel): + subject: str + predicate: str + object: str + source_snippet: str + + +class TextQARequest(BaseModel): + items: list[TextTripleForQA] + model: str | None = None + prompt_template: str | None = None + + +class QAPair(BaseModel): + question: str + answer: str + + +class TextQAResponse(BaseModel): + pairs: list[QAPair] + + +class ImageQuadrupleForQA(BaseModel): + subject: str + predicate: str + object: str + qualifier: str + cropped_image_path: str + + +class ImageQARequest(BaseModel): + items: list[ImageQuadrupleForQA] + model: str | None = None + prompt_template: str | None = None + + +class ImageQAPair(BaseModel): + question: str + answer: str + image_path: str + + +class ImageQAResponse(BaseModel): + pairs: list[ImageQAPair] +``` + +- [ ] **Step 2: 编写失败测试** + +`tests/test_qa_service.py`: + +```python +import pytest +from app.services.qa_service import gen_text_qa, gen_image_qa, _parse_qa_pairs +from app.models.qa_models import TextTripleForQA, ImageQuadrupleForQA +from app.core.exceptions import LLMResponseParseError, LLMCallError + +QA_JSON = '[{"question":"变压器额定电压是多少?","answer":"110kV"}]' + + +def test_parse_qa_pairs_plain_json(): + result = _parse_qa_pairs(QA_JSON) + assert len(result) == 1 + assert result[0].question == "变压器额定电压是多少?" 
+ + +def test_parse_qa_pairs_markdown_wrapped(): + result = _parse_qa_pairs(f"```json\n{QA_JSON}\n```") + assert len(result) == 1 + + +def test_parse_qa_pairs_invalid_raises(): + with pytest.raises(LLMResponseParseError): + _parse_qa_pairs("这不是JSON") + + +@pytest.mark.asyncio +async def test_gen_text_qa(mock_llm): + mock_llm.chat.return_value = QA_JSON + items = [TextTripleForQA(subject="变压器", predicate="额定电压", object="110kV", source_snippet="额定电压为110kV")] + + result = await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm) + assert len(result) == 1 + assert result[0].answer == "110kV" + + +@pytest.mark.asyncio +async def test_gen_text_qa_llm_error(mock_llm): + mock_llm.chat.side_effect = Exception("network error") + items = [TextTripleForQA(subject="A", predicate="B", object="C", source_snippet="ABC")] + + with pytest.raises(LLMCallError): + await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm) + + +@pytest.mark.asyncio +async def test_gen_image_qa(mock_llm, mock_storage): + mock_llm.chat_vision.return_value = '[{"question":"图中是什么?","answer":"电缆接头"}]' + items = [ImageQuadrupleForQA( + subject="电缆接头", predicate="位于", object="配电箱", qualifier="", cropped_image_path="crops/1/0.jpg" + )] + + result = await gen_image_qa(items=items, model="glm-4v-flash", prompt_template="", llm=mock_llm, storage=mock_storage) + assert len(result) == 1 + assert result[0].image_path == "crops/1/0.jpg" + mock_storage.get_presigned_url.assert_called_once_with("source-data", "crops/1/0.jpg") +``` + +- [ ] **Step 3: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_qa_service.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 4: 实现 `app/services/qa_service.py`** + +```python +import json +import logging + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.exceptions import LLMCallError, LLMResponseParseError +from app.core.json_utils import parse_json_response 
+from app.models.qa_models import ( + ImageQAPair, + ImageQuadrupleForQA, + QAPair, + TextTripleForQA, +) + +logger = logging.getLogger(__name__) + +DEFAULT_TEXT_QA_PROMPT = """基于以下知识三元组和原文证据片段,生成高质量问答对。 +要求: +1. 问题自然、具体,不能过于宽泛 +2. 答案基于原文片段,语言流畅 +3. 每个三元组生成1-2个问答对 + +以 JSON 数组格式返回:[{"question":"...","answer":"..."}] + +三元组数据: +""" + +DEFAULT_IMAGE_QA_PROMPT = """基于图片内容和以下四元组信息,生成高质量图文问答对。 +要求: +1. 问题需要结合图片才能回答 +2. 答案基于图片中的实际内容 +3. 每个四元组生成1个问答对 + +以 JSON 数组格式返回:[{"question":"...","answer":"..."}] + +四元组信息: +""" + + +def _parse_qa_pairs(raw: str) -> list[QAPair]: + items_raw = parse_json_response(raw) + result = [] + for item in items_raw: + try: + result.append(QAPair(question=item["question"], answer=item["answer"])) + except KeyError as e: + logger.warning(f"跳过不完整问答对: {item}, error: {e}") + return result + + +async def gen_text_qa( + items: list[TextTripleForQA], + model: str, + prompt_template: str, + llm: LLMClient, +) -> list[QAPair]: + triples_text = json.dumps([i.model_dump() for i in items], ensure_ascii=False, indent=2) + messages = [ + {"role": "system", "content": "你是专业的知识问答对生成助手。"}, + {"role": "user", "content": (prompt_template or DEFAULT_TEXT_QA_PROMPT) + triples_text}, + ] + try: + raw = await llm.chat(messages, model) + except Exception as e: + raise LLMCallError(f"GLM 调用失败: {e}") from e + logger.info(f"gen_text_qa model={model} items={len(items)}") + return _parse_qa_pairs(raw) + + +async def gen_image_qa( + items: list[ImageQuadrupleForQA], + model: str, + prompt_template: str, + llm: LLMClient, + storage: StorageClient, + bucket: str = "source-data", +) -> list[ImageQAPair]: + result = [] + prompt = prompt_template or DEFAULT_IMAGE_QA_PROMPT + for item in items: + presigned_url = storage.get_presigned_url(bucket, item.cropped_image_path) + quad_text = json.dumps( + {k: v for k, v in item.model_dump().items() if k != "cropped_image_path"}, + ensure_ascii=False, + ) + messages = [ + {"role": "system", "content": "你是专业的视觉问答对生成助手。"}, + {"role": 
"user", "content": [ + {"type": "image_url", "image_url": {"url": presigned_url}}, + {"type": "text", "text": prompt + quad_text}, + ]}, + ] + try: + raw = await llm.chat_vision(messages, model) + except Exception as e: + raise LLMCallError(f"GLM-4V 调用失败: {e}") from e + for pair in _parse_qa_pairs(raw): + result.append(ImageQAPair(question=pair.question, answer=pair.answer, image_path=item.cropped_image_path)) + logger.info(f"gen_image_qa model={model} items={len(items)} pairs={len(result)}") + return result +``` + +- [ ] **Step 5: 运行,确认通过** + +```bash +conda run -n label pytest tests/test_qa_service.py -v +``` + +Expected: `6 passed` + +- [ ] **Step 6: Commit** + +```bash +git add app/models/qa_models.py app/services/qa_service.py tests/test_qa_service.py +git commit -m "feat: QA models and service for text and image QA generation" +``` + +--- + +## Task 15: QA Router + +**Files:** +- Create: `app/routers/qa.py` +- Create: `tests/test_qa_router.py` + +- [ ] **Step 1: 编写失败测试** + +`tests/test_qa_router.py`: + +```python +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock +from app.main import app +from app.core.dependencies import set_clients + + +@pytest.fixture +def client(mock_llm, mock_storage): + set_clients(mock_llm, mock_storage) + return TestClient(app) + + +def test_gen_text_qa_success(client, mock_llm): + mock_llm.chat = AsyncMock(return_value='[{"question":"额定电压?","answer":"110kV"}]') + resp = client.post("/api/v1/qa/gen-text", json={ + "items": [{"subject": "变压器", "predicate": "额定电压", "object": "110kV", "source_snippet": "额定电压为110kV"}], + }) + assert resp.status_code == 200 + assert resp.json()["pairs"][0]["question"] == "额定电压?" 
+ + +def test_gen_image_qa_success(client, mock_llm, mock_storage): + mock_llm.chat_vision = AsyncMock(return_value='[{"question":"图中是什么?","answer":"接头"}]') + mock_storage.get_presigned_url.return_value = "https://example.com/crop.jpg" + resp = client.post("/api/v1/qa/gen-image", json={ + "items": [{"subject": "A", "predicate": "B", "object": "C", "qualifier": "", "cropped_image_path": "crops/1/0.jpg"}], + }) + assert resp.status_code == 200 + data = resp.json() + assert data["pairs"][0]["image_path"] == "crops/1/0.jpg" +``` + +- [ ] **Step 2: 实现 `app/routers/qa.py`** + +```python +from fastapi import APIRouter, Depends + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.config import get_config +from app.core.dependencies import get_llm_client, get_storage_client +from app.models.qa_models import ImageQARequest, ImageQAResponse, TextQARequest, TextQAResponse +from app.services import qa_service + +router = APIRouter(tags=["QA"]) + + +@router.post("/qa/gen-text", response_model=TextQAResponse) +async def gen_text_qa( + req: TextQARequest, + llm: LLMClient = Depends(get_llm_client), +): + cfg = get_config() + pairs = await qa_service.gen_text_qa( + items=req.items, + model=req.model or cfg["models"]["default_text"], + prompt_template=req.prompt_template or qa_service.DEFAULT_TEXT_QA_PROMPT, + llm=llm, + ) + return TextQAResponse(pairs=pairs) + + +@router.post("/qa/gen-image", response_model=ImageQAResponse) +async def gen_image_qa( + req: ImageQARequest, + llm: LLMClient = Depends(get_llm_client), + storage: StorageClient = Depends(get_storage_client), +): + cfg = get_config() + pairs = await qa_service.gen_image_qa( + items=req.items, + model=req.model or cfg["models"]["default_vision"], + prompt_template=req.prompt_template or qa_service.DEFAULT_IMAGE_QA_PROMPT, + llm=llm, + storage=storage, + bucket=cfg["storage"]["buckets"]["source_data"], + ) + return ImageQAResponse(pairs=pairs) +``` + +- [ ] **Step 
3: 在 `app/main.py` 注册路由** + +```python +from app.routers import text, image, video, qa +app.include_router(qa.router, prefix="/api/v1") +``` + +- [ ] **Step 4: 运行测试** + +```bash +conda run -n label pytest tests/test_qa_router.py -v +``` + +Expected: `2 passed` + +- [ ] **Step 5: Commit** + +```bash +git add app/routers/qa.py tests/test_qa_router.py app/main.py +git commit -m "feat: QA router POST /api/v1/qa/gen-text and /gen-image" +``` + +--- + +## Task 16: Finetune Models + Service + Router + +**Files:** +- Create: `app/models/finetune_models.py` +- Create: `app/services/finetune_service.py` +- Create: `app/routers/finetune.py` +- Create: `tests/test_finetune_service.py` +- Create: `tests/test_finetune_router.py` + +- [ ] **Step 1: 实现 `app/models/finetune_models.py`** + +```python +from pydantic import BaseModel + + +class FinetuneHyperparams(BaseModel): + learning_rate: float = 1e-4 + epochs: int = 3 + + +class FinetuneStartRequest(BaseModel): + jsonl_url: str + base_model: str + hyperparams: FinetuneHyperparams = FinetuneHyperparams() + + +class FinetuneStartResponse(BaseModel): + job_id: str + + +class FinetuneStatusResponse(BaseModel): + job_id: str + status: str # RUNNING | SUCCESS | FAILED + progress: int | None = None + error_message: str | None = None +``` + +- [ ] **Step 2: 编写失败测试** + +`tests/test_finetune_service.py`: + +```python +import pytest +from unittest.mock import MagicMock +from app.services.finetune_service import start_finetune, get_finetune_status +from app.models.finetune_models import FinetuneHyperparams + + +@pytest.mark.asyncio +async def test_start_finetune(): + mock_job = MagicMock() + mock_job.id = "glm-ft-abc123" + mock_zhipuai = MagicMock() + mock_zhipuai.fine_tuning.jobs.create.return_value = mock_job + + result = await start_finetune( + jsonl_url="https://example.com/export.jsonl", + base_model="glm-4-flash", + hyperparams=FinetuneHyperparams(learning_rate=1e-4, epochs=3), + client=mock_zhipuai, + ) + assert result == 
"glm-ft-abc123" + mock_zhipuai.fine_tuning.jobs.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_finetune_status_running(): + mock_job = MagicMock() + mock_job.status = "running" + mock_job.progress = 50 + mock_job.error = None + mock_zhipuai = MagicMock() + mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job + + result = await get_finetune_status("glm-ft-abc123", mock_zhipuai) + assert result.status == "RUNNING" + assert result.progress == 50 + assert result.job_id == "glm-ft-abc123" + + +@pytest.mark.asyncio +async def test_get_finetune_status_success(): + mock_job = MagicMock() + mock_job.status = "succeeded" + mock_job.progress = 100 + mock_job.error = None + mock_zhipuai = MagicMock() + mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job + + result = await get_finetune_status("glm-ft-abc123", mock_zhipuai) + assert result.status == "SUCCESS" +``` + +- [ ] **Step 3: 运行,确认失败** + +```bash +conda run -n label pytest tests/test_finetune_service.py -v +``` + +Expected: `ImportError` + +- [ ] **Step 4: 实现 `app/services/finetune_service.py`** + +```python +import logging + +from app.models.finetune_models import FinetuneHyperparams, FinetuneStatusResponse + +logger = logging.getLogger(__name__) + +_STATUS_MAP = { + "running": "RUNNING", + "succeeded": "SUCCESS", + "failed": "FAILED", +} + + +async def start_finetune( + jsonl_url: str, + base_model: str, + hyperparams: FinetuneHyperparams, + client, # ZhipuAI SDK client instance +) -> str: + job = client.fine_tuning.jobs.create( + training_file=jsonl_url, + model=base_model, + hyperparameters={ + "learning_rate_multiplier": hyperparams.learning_rate, + "n_epochs": hyperparams.epochs, + }, + ) + logger.info(f"finetune_start job_id={job.id} model={base_model}") + return job.id + + +async def get_finetune_status(job_id: str, client) -> FinetuneStatusResponse: + job = client.fine_tuning.jobs.retrieve(job_id) + status = _STATUS_MAP.get(job.status, "RUNNING") + return 
FinetuneStatusResponse( + job_id=job_id, + status=status, + progress=getattr(job, "progress", None), + error_message=getattr(job, "error", None), + ) +``` + +- [ ] **Step 5: 运行,确认通过** + +```bash +conda run -n label pytest tests/test_finetune_service.py -v +``` + +Expected: `3 passed` + +- [ ] **Step 6: 实现 `app/routers/finetune.py`** + +```python +from fastapi import APIRouter, Depends + +from app.clients.llm.base import LLMClient +from app.clients.llm.zhipuai_client import ZhipuAIClient +from app.core.dependencies import get_llm_client +from app.models.finetune_models import ( + FinetuneStartRequest, + FinetuneStartResponse, + FinetuneStatusResponse, +) +from app.services import finetune_service + +router = APIRouter(tags=["Finetune"]) + + +def _get_zhipuai(llm: LLMClient = Depends(get_llm_client)) -> ZhipuAIClient: + if not isinstance(llm, ZhipuAIClient): + raise RuntimeError("微调功能仅支持 ZhipuAI 后端") + return llm + + +@router.post("/finetune/start", response_model=FinetuneStartResponse) +async def start_finetune( + req: FinetuneStartRequest, + llm: ZhipuAIClient = Depends(_get_zhipuai), +): + job_id = await finetune_service.start_finetune( + jsonl_url=req.jsonl_url, + base_model=req.base_model, + hyperparams=req.hyperparams, + client=llm._client, + ) + return FinetuneStartResponse(job_id=job_id) + + +@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse) +async def get_finetune_status( + job_id: str, + llm: ZhipuAIClient = Depends(_get_zhipuai), +): + return await finetune_service.get_finetune_status(job_id, llm._client) +``` + +- [ ] **Step 7: 编写路由测试** + +`tests/test_finetune_router.py`: + +```python +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch +from app.main import app +from app.core.dependencies import set_clients +from app.clients.llm.zhipuai_client import ZhipuAIClient +from app.clients.storage.base import StorageClient + + +@pytest.fixture +def client(mock_storage): + with 
patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI: + mock_sdk = MagicMock() + MockZhipuAI.return_value = mock_sdk + llm = ZhipuAIClient(api_key="test-key") + llm._mock_sdk = mock_sdk + set_clients(llm, mock_storage) + yield TestClient(app), mock_sdk + + +def test_start_finetune(client): + test_client, mock_sdk = client + mock_job = MagicMock() + mock_job.id = "glm-ft-xyz" + mock_sdk.fine_tuning.jobs.create.return_value = mock_job + + resp = test_client.post("/api/v1/finetune/start", json={ + "jsonl_url": "https://example.com/export.jsonl", + "base_model": "glm-4-flash", + "hyperparams": {"learning_rate": 1e-4, "epochs": 3}, + }) + assert resp.status_code == 200 + assert resp.json()["job_id"] == "glm-ft-xyz" + + +def test_get_finetune_status(client): + test_client, mock_sdk = client + mock_job = MagicMock() + mock_job.status = "running" + mock_job.progress = 30 + mock_job.error = None + mock_sdk.fine_tuning.jobs.retrieve.return_value = mock_job + + resp = test_client.get("/api/v1/finetune/status/glm-ft-xyz") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "RUNNING" + assert data["progress"] == 30 +``` + +- [ ] **Step 8: 在 `app/main.py` 注册路由(最终状态)** + +```python +from app.routers import text, image, video, qa, finetune + +app.include_router(text.router, prefix="/api/v1") +app.include_router(image.router, prefix="/api/v1") +app.include_router(video.router, prefix="/api/v1") +app.include_router(qa.router, prefix="/api/v1") +app.include_router(finetune.router, prefix="/api/v1") +``` + +- [ ] **Step 9: 运行全部测试** + +```bash +conda run -n label pytest tests/ -v +``` + +Expected: 所有测试通过,无失败 + +- [ ] **Step 10: Commit** + +```bash +git add app/models/finetune_models.py app/services/finetune_service.py app/routers/finetune.py tests/test_finetune_service.py tests/test_finetune_router.py app/main.py +git commit -m "feat: finetune models, service, and router - complete all endpoints" +``` + +--- + +## Task 17: 部署文件 + +**Files:** +- 
Create: `Dockerfile`
+- Create: `docker-compose.yml`
+
+- [ ] **Step 1: 创建 `Dockerfile`**
+
+```dockerfile
+FROM python:3.12-slim
+
+WORKDIR /app
+
+# OpenCV 系统依赖
+RUN apt-get update && apt-get install -y \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY app/ ./app/
+COPY config.yaml .
+# 注意:不要 COPY .env 进镜像——密钥会被永久烘焙进镜像层;运行时由 docker-compose 的 env_file 注入
+
+EXPOSE 8000
+
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
+```
+
+- [ ] **Step 2: 创建 `docker-compose.yml`**
+
+```yaml
+version: "3.9"
+
+services:
+  ai-service:
+    build: .
+    ports:
+      - "8000:8000"
+    env_file:
+      - .env
+    depends_on:
+      - rustfs
+    networks:
+      - label-net
+    restart: unless-stopped
+
+  rustfs:
+    image: minio/minio:latest
+    command: server /data --console-address ":9001"
+    ports:
+      - "9000:9000"
+      - "9001:9001"
+    environment:
+      MINIO_ROOT_USER: minioadmin
+      MINIO_ROOT_PASSWORD: minioadmin
+    volumes:
+      - rustfs-data:/data
+    networks:
+      - label-net
+
+volumes:
+  rustfs-data:
+
+networks:
+  label-net:
+    driver: bridge
+```
+
+- [ ] **Step 3: 验证 Docker 构建**
+
+```bash
+docker build -t label-ai-service:dev . 
+``` + +Expected: 镜像构建成功,无错误 + +- [ ] **Step 4: 运行全量测试,最终确认** + +```bash +conda run -n label pytest tests/ -v --tb=short +``` + +Expected: 所有测试通过 + +- [ ] **Step 5: Commit** + +```bash +git add Dockerfile docker-compose.yml +git commit -m "feat: Dockerfile and docker-compose for containerized deployment" +``` + +--- + +## 自审检查结果 + +**Spec coverage:** +- ✅ 文本三元组提取(TXT/PDF/DOCX)— Task 8-9 +- ✅ 图像四元组提取 + bbox 裁剪 — Task 10-11 +- ✅ 视频帧提取(interval/keyframe)— Task 12-13 +- ✅ 视频转文本(BackgroundTask)— Task 12-13 +- ✅ 文本问答对生成 — Task 14-15 +- ✅ 图像问答对生成 — Task 14-15 +- ✅ 微调任务提交与状态查询 — Task 16 +- ✅ LLMClient / StorageClient ABC 适配层 — Task 4-5 +- ✅ config.yaml + .env 分层配置 — Task 2 +- ✅ 结构化日志 + 请求日志 — Task 3 +- ✅ 全局异常处理 — Task 3 +- ✅ Swagger 文档(FastAPI 自动生成) — Task 6 +- ✅ Dockerfile + docker-compose — Task 17 +- ✅ pytest 测试覆盖全部 service 和 router — 各 Task + +**类型一致性:** `TripleItem.source_offset` 在 Task 7 定义,Task 8 使用;`VideoJobCallback` 在 Task 12 定义,Task 12 service 使用 — 一致。 + +**占位符:** 无 TBD / TODO,所有步骤均含完整代码。