Architecture fixes:
- Image QA: presigned URL → base64 (RustFS is internal, GLM-4V is cloud)
- Add GET /health endpoint + Docker healthcheck
- Video size limit: add get_object_size() to StorageClient ABC, check before background task
- Video size configurable via MAX_VIDEO_SIZE_MB env var (no image rebuild needed)
- Fix image_service.py except clause redundancy (Exception absorbs KeyError/TypeError)

Config additions:
- video.max_file_size_mb: 200 in config.yaml
- MAX_VIDEO_SIZE_MB env override in _ENV_OVERRIDES
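The presigned-URL → base64 switch for image QA can be sketched as below. This is illustrative only: `to_data_url` and the message shape are assumptions, not the service's actual code — the point is that the image bytes are inlined so the cloud model never needs network access to the internal RustFS endpoint.

```python
import base64


def to_data_url(image_bytes: bytes, mime: str = "image/jpeg") -> str:
    # Inline the image as a data URL; a presigned RustFS URL would be
    # unreachable from the cloud-hosted GLM-4V service.
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime};base64,{b64}"


# A GLM-4V style message carrying the inlined image (shape is illustrative):
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this crop"},
        {"type": "image_url", "image_url": {"url": to_data_url(b"\xff\xd8\xff")}},
    ],
}
```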
# AI Service Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Implement label_ai_service, a Python FastAPI service that gives the knowledge-graph labeling platform text triple extraction, image quadruple extraction, video processing, QA-pair generation, and GLM fine-tuning management.

**Architecture:** Layered: routers (HTTP entry) → services (business logic) → clients (external adapter layer). LLMClient and StorageClient are both ABCs, currently implemented by ZhipuAIClient and RustFSClient respectively, injected via FastAPI Depends; the services layer never touches concrete implementations. Video jobs run asynchronously via FastAPI BackgroundTasks and call back to the Java backend on completion.

**Tech Stack:** Python 3.12 (conda `label` env), FastAPI, ZhipuAI SDK, boto3 (S3), OpenCV, pdfplumber, python-docx, httpx, pytest

---
## Task 1: Project Scaffolding

**Files:**
- Create: `app/__init__.py`
- Create: `app/core/__init__.py`
- Create: `app/clients/__init__.py`
- Create: `app/clients/llm/__init__.py`
- Create: `app/clients/storage/__init__.py`
- Create: `app/services/__init__.py`
- Create: `app/routers/__init__.py`
- Create: `app/models/__init__.py`
- Create: `tests/__init__.py`
- Create: `tests/conftest.py`
- Create: `config.yaml`
- Create: `.env`
- Create: `requirements.txt`

- [ ] **Step 1: Create the package directory structure**

```bash
mkdir -p app/core app/clients/llm app/clients/storage app/services app/routers app/models tests
touch app/__init__.py app/core/__init__.py
touch app/clients/__init__.py app/clients/llm/__init__.py app/clients/storage/__init__.py
touch app/services/__init__.py app/routers/__init__.py app/models/__init__.py
touch tests/__init__.py
```

- [ ] **Step 2: Create `config.yaml`**

```yaml
server:
  port: 8000
  log_level: INFO

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}

video:
  frame_sample_count: 8
  max_file_size_mb: 200

models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
```

- [ ] **Step 3: Create `.env`**

```ini
ZHIPUAI_API_KEY=your-zhipuai-api-key
STORAGE_ACCESS_KEY=minioadmin
STORAGE_SECRET_KEY=minioadmin
STORAGE_ENDPOINT=http://rustfs:9000
BACKEND_CALLBACK_URL=http://backend:8080/internal/video-job/callback
# MAX_VIDEO_SIZE_MB=200  # optional; overrides the video size limit in config.yaml
```

- [ ] **Step 4: Create `requirements.txt`**

```
fastapi>=0.111
uvicorn[standard]>=0.29
pydantic>=2.7
python-dotenv>=1.0
pyyaml>=6.0
zhipuai>=2.1
boto3>=1.34
pdfplumber>=0.11
python-docx>=1.1
opencv-python-headless>=4.9
numpy>=1.26
httpx>=0.27
pytest>=8.0
pytest-asyncio>=0.23
```

- [ ] **Step 5: Create `tests/conftest.py`**

```python
import pytest
from unittest.mock import AsyncMock, MagicMock
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient


@pytest.fixture
def mock_llm():
    client = MagicMock(spec=LLMClient)
    client.chat = AsyncMock()
    client.chat_vision = AsyncMock()
    return client


@pytest.fixture
def mock_storage():
    client = MagicMock(spec=StorageClient)
    client.download_bytes = AsyncMock()
    client.upload_bytes = AsyncMock()
    client.get_presigned_url = MagicMock(return_value="https://example.com/presigned/crop.jpg")
    client.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10 MB default, below the limit
    return client
```

- [ ] **Step 6: Install dependencies**

```bash
conda run -n label pip install -r requirements.txt
```

Expected: all packages install with no errors

- [ ] **Step 7: Commit**

```bash
git add app/ tests/ config.yaml .env requirements.txt
git commit -m "feat: project scaffold - directory structure and config files"
```

---
## Task 2: Core Config Module

**Files:**
- Create: `app/core/config.py`
- Create: `tests/test_config.py`

- [ ] **Step 1: Write failing tests**

`tests/test_config.py`:

```python
import pytest
from unittest.mock import patch, mock_open
from app.core.config import get_config

MOCK_YAML = """
server:
  port: 8000
  log_level: INFO
storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"
backend: {}
video:
  frame_sample_count: 8
models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
"""


def _fresh_config(monkeypatch, extra_env: dict | None = None):
    """Clear the lru_cache and set env vars before each test."""
    get_config.cache_clear()
    base_env = {
        "ZHIPUAI_API_KEY": "test-key",
        "STORAGE_ACCESS_KEY": "test-access",
        "STORAGE_SECRET_KEY": "test-secret",
        "STORAGE_ENDPOINT": "http://localhost:9000",
        "BACKEND_CALLBACK_URL": "http://localhost:8080/callback",
    }
    if extra_env:
        base_env.update(extra_env)
    for k, v in base_env.items():
        monkeypatch.setenv(k, v)


def test_env_overrides_yaml(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
            assert cfg["zhipuai"]["api_key"] == "test-key"
            assert cfg["storage"]["access_key"] == "test-access"
            assert cfg["storage"]["endpoint"] == "http://localhost:9000"
            assert cfg["backend"]["callback_url"] == "http://localhost:8080/callback"
    get_config.cache_clear()


def test_yaml_values_preserved(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
            assert cfg["models"]["default_text"] == "glm-4-flash"
            assert cfg["video"]["frame_sample_count"] == 8
            assert cfg["storage"]["buckets"]["source_data"] == "source-data"
    get_config.cache_clear()


def test_missing_api_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.delenv("ZHIPUAI_API_KEY", raising=False)
    monkeypatch.setenv("STORAGE_ACCESS_KEY", "a")
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="ZHIPUAI_API_KEY"):
                get_config()
    get_config.cache_clear()


def test_missing_storage_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.setenv("ZHIPUAI_API_KEY", "key")
    monkeypatch.delenv("STORAGE_ACCESS_KEY", raising=False)
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="STORAGE_ACCESS_KEY"):
                get_config()
    get_config.cache_clear()
```

- [ ] **Step 2: Run; confirm they fail**

```bash
conda run -n label pytest tests/test_config.py -v
```

Expected: `ImportError: cannot import name 'get_config'`

- [ ] **Step 3: Implement `app/core/config.py`**

```python
import os
import yaml
from functools import lru_cache
from pathlib import Path
from dotenv import load_dotenv

_ROOT = Path(__file__).parent.parent.parent

_ENV_OVERRIDES = {
    "ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY": ["storage", "access_key"],
    "STORAGE_SECRET_KEY": ["storage", "secret_key"],
    "STORAGE_ENDPOINT": ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL": ["backend", "callback_url"],
    "LOG_LEVEL": ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
}


def _set_nested(d: dict, keys: list[str], value) -> None:
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


@lru_cache(maxsize=1)
def get_config() -> dict:
    load_dotenv(_ROOT / ".env")
    with open(_ROOT / "config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    for env_key, yaml_path in _ENV_OVERRIDES.items():
        val = os.environ.get(env_key)
        if val:
            # env values arrive as strings; coerce digit-only overrides
            # (e.g. MAX_VIDEO_SIZE_MB) to int so size math stays numeric
            _set_nested(cfg, yaml_path, int(val) if val.isdigit() else val)
    _validate(cfg)
    return cfg


def _validate(cfg: dict) -> None:
    checks = [
        (["zhipuai", "api_key"], "ZHIPUAI_API_KEY"),
        (["storage", "access_key"], "STORAGE_ACCESS_KEY"),
        (["storage", "secret_key"], "STORAGE_SECRET_KEY"),
    ]
    for path, name in checks:
        val = cfg
        for k in path:
            val = (val or {}).get(k, "")
        if not val:
            raise RuntimeError(f"Missing required config: {name}")
```
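The layered-override mechanics can be exercised standalone. This sketch re-states `_set_nested` so the snippet runs on its own, and shows the string-to-int coercion an env override such as `MAX_VIDEO_SIZE_MB` needs:

```python
def _set_nested(d: dict, keys: list, value) -> None:
    # Walk to the parent dict, creating levels as needed, then set the leaf.
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


cfg = {"video": {"max_file_size_mb": 200}}
# env values arrive as strings; coerce digits so downstream size math stays numeric
raw = "500"
_set_nested(cfg, ["video", "max_file_size_mb"], int(raw) if raw.isdigit() else raw)
assert cfg["video"]["max_file_size_mb"] == 500
```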

- [ ] **Step 4: Run again; confirm they pass**

```bash
conda run -n label pytest tests/test_config.py -v
```

Expected: `4 passed`

- [ ] **Step 5: Commit**

```bash
git add app/core/config.py tests/test_config.py
git commit -m "feat: core config module with YAML + env layered loading"
```

---
## Task 3: Core Logging, Exceptions, JSON Utils

**Files:**
- Create: `app/core/logging.py`
- Create: `app/core/exceptions.py`
- Create: `app/core/json_utils.py`

- [ ] **Step 1: Implement `app/core/logging.py`**

```python
import json
import logging
import time
from typing import Callable

from fastapi import Request, Response


class _JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        entry: dict = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry, ensure_ascii=False)


def setup_logging(log_level: str = "INFO") -> None:
    handler = logging.StreamHandler()
    handler.setFormatter(_JSONFormatter())
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(getattr(logging, log_level.upper(), logging.INFO))


async def request_logging_middleware(request: Request, call_next: Callable) -> Response:
    start = time.monotonic()
    response = await call_next(request)
    duration_ms = round((time.monotonic() - start) * 1000, 2)
    logging.getLogger("api").info(
        f"method={request.method} path={request.url.path} "
        f"status={response.status_code} duration_ms={duration_ms}"
    )
    return response
```
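The formatter's output can be checked in isolation. This is a minimal re-statement of the JSON formatter above (fewer fields, for brevity), exercised against a hand-built `LogRecord`:

```python
import json
import logging


class JSONFormatter(logging.Formatter):
    # Trimmed-down version of _JSONFormatter above, for standalone testing
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {"level": record.levelname, "message": record.getMessage()},
            ensure_ascii=False,
        )


record = logging.LogRecord("api", logging.INFO, __file__, 1, "hello", None, None)
entry = json.loads(JSONFormatter().format(record))
assert entry == {"level": "INFO", "message": "hello"}
```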

- [ ] **Step 2: Implement `app/core/exceptions.py`**

```python
import logging
from fastapi import Request
from fastapi.responses import JSONResponse


class UnsupportedFileTypeError(Exception):
    def __init__(self, ext: str):
        super().__init__(f"Unsupported file type: {ext}")


class StorageDownloadError(Exception):
    pass


class LLMResponseParseError(Exception):
    pass


class LLMCallError(Exception):
    pass


async def unsupported_file_type_handler(request: Request, exc: UnsupportedFileTypeError):
    return JSONResponse(
        status_code=400,
        content={"code": "UNSUPPORTED_FILE_TYPE", "message": str(exc)},
    )


async def storage_download_handler(request: Request, exc: StorageDownloadError):
    return JSONResponse(
        status_code=502,
        content={"code": "STORAGE_ERROR", "message": str(exc)},
    )


async def llm_parse_handler(request: Request, exc: LLMResponseParseError):
    return JSONResponse(
        status_code=502,
        content={"code": "LLM_PARSE_ERROR", "message": str(exc)},
    )


async def llm_call_handler(request: Request, exc: LLMCallError):
    return JSONResponse(
        status_code=503,
        content={"code": "LLM_CALL_ERROR", "message": str(exc)},
    )


async def generic_error_handler(request: Request, exc: Exception):
    logging.getLogger("error").exception("Unhandled exception")
    return JSONResponse(
        status_code=500,
        content={"code": "INTERNAL_ERROR", "message": "Internal server error"},
    )
```

- [ ] **Step 3: Implement `app/core/json_utils.py`**

```python
import json
from app.core.exceptions import LLMResponseParseError


def parse_json_response(raw: str) -> list | dict:
    """Parse JSON out of a GLM response, tolerating markdown code-fence wrapping."""
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    content = content.strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        raise LLMResponseParseError(
            f"GLM response is not parseable JSON: {raw[:200]}"
        ) from e
```
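A quick standalone check of the fence-stripping behavior (mirroring the logic above, re-stated so the snippet runs on its own):

```python
import json


def strip_fences(raw: str) -> str:
    # Same fence handling as parse_json_response: prefer a ```json block,
    # fall back to any ``` block, otherwise pass the text through.
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    return content.strip()


wrapped = "```json\n[1, 2, 3]\n```"
assert json.loads(strip_fences(wrapped)) == [1, 2, 3]
assert strip_fences("[1]") == "[1]"  # unfenced responses pass through
```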

- [ ] **Step 4: Commit**

```bash
git add app/core/logging.py app/core/exceptions.py app/core/json_utils.py
git commit -m "feat: core logging, exceptions, json utils"
```

---
## Task 4: LLM Adapter Layer

**Files:**
- Create: `app/clients/llm/base.py`
- Create: `app/clients/llm/zhipuai_client.py`
- Create: `tests/test_llm_client.py`

- [ ] **Step 1: Write failing tests**

`tests/test_llm_client.py`:

```python
import asyncio
import pytest
from unittest.mock import MagicMock, patch
from app.clients.llm.zhipuai_client import ZhipuAIClient


@pytest.fixture
def zhipuai_client():
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        client = ZhipuAIClient(api_key="test-key")
        client._mock_sdk = mock_sdk
        yield client


def test_chat_returns_content(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "triple extraction result"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat(
            messages=[{"role": "user", "content": "extract triples"}],
            model="glm-4-flash",
        )
    )
    assert result == "triple extraction result"
    zhipuai_client._mock_sdk.chat.completions.create.assert_called_once()


def test_chat_vision_calls_same_endpoint(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "image analysis result"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat_vision(
            messages=[{"role": "user", "content": [{"type": "text", "text": "analyze"}]}],
            model="glm-4v-flash",
        )
    )
    assert result == "image analysis result"
```

- [ ] **Step 2: Run; confirm they fail**

```bash
conda run -n label pytest tests/test_llm_client.py -v
```

Expected: `ImportError`

- [ ] **Step 3: Implement `app/clients/llm/base.py`**

```python
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        """Text-only chat; returns the model's output text."""

    @abstractmethod
    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        """Multimodal chat (mixed image/text input); returns the model's output text."""
```

- [ ] **Step 4: Implement `app/clients/llm/zhipuai_client.py`**

```python
import asyncio
from zhipuai import ZhipuAI
from app.clients.llm.base import LLMClient


class ZhipuAIClient(LLMClient):
    def __init__(self, api_key: str):
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None,
            lambda: self._client.chat.completions.create(
                model=model, messages=messages, **kwargs
            ),
        )
        return resp.choices[0].message.content

    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        # GLM-4V uses the same endpoint as text chat; image messages are
        # distinguished by the image_url content type
        return await self.chat(messages, model, **kwargs)
```
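The executor-wrapping pattern above — offloading a synchronous SDK call so the event loop stays responsive — can be shown in isolation, with a stub in place of the SDK:

```python
import asyncio


def blocking_sdk_call(x: int) -> int:
    # Stands in for the synchronous ZhipuAI SDK call
    return x * 2


async def main() -> int:
    loop = asyncio.get_running_loop()
    # The lambda runs on a thread-pool worker; the coroutine just awaits it
    return await loop.run_in_executor(None, lambda: blocking_sdk_call(21))


assert asyncio.run(main()) == 42
```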

- [ ] **Step 5: Run again; confirm they pass**

```bash
conda run -n label pytest tests/test_llm_client.py -v
```

Expected: `2 passed`

- [ ] **Step 6: Commit**

```bash
git add app/clients/llm/ tests/test_llm_client.py
git commit -m "feat: LLMClient ABC and ZhipuAI implementation"
```

---
## Task 5: Storage Adapter Layer

**Files:**
- Create: `app/clients/storage/base.py`
- Create: `app/clients/storage/rustfs_client.py`
- Create: `tests/test_storage_client.py`

- [ ] **Step 1: Write failing tests**

`tests/test_storage_client.py`:

```python
import asyncio
import pytest
from unittest.mock import MagicMock, patch
from app.clients.storage.rustfs_client import RustFSClient


@pytest.fixture
def rustfs_client():
    with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
        mock_s3 = MagicMock()
        mock_boto3.client.return_value = mock_s3
        client = RustFSClient(
            endpoint="http://localhost:9000",
            access_key="minioadmin",
            secret_key="minioadmin",
        )
        client._mock_s3 = mock_s3
        yield client


def test_download_bytes(rustfs_client):
    mock_body = MagicMock()
    mock_body.read.return_value = b"file content"
    rustfs_client._mock_s3.get_object.return_value = {"Body": mock_body}

    result = asyncio.run(
        rustfs_client.download_bytes("source-data", "text/202404/1.txt")
    )
    assert result == b"file content"
    rustfs_client._mock_s3.get_object.assert_called_once_with(
        Bucket="source-data", Key="text/202404/1.txt"
    )


def test_upload_bytes(rustfs_client):
    asyncio.run(
        rustfs_client.upload_bytes("source-data", "crops/1/0.jpg", b"img", "image/jpeg")
    )
    rustfs_client._mock_s3.put_object.assert_called_once_with(
        Bucket="source-data", Key="crops/1/0.jpg", Body=b"img", ContentType="image/jpeg"
    )


def test_get_presigned_url(rustfs_client):
    rustfs_client._mock_s3.generate_presigned_url.return_value = "https://example.com/signed"
    url = rustfs_client.get_presigned_url("source-data", "crops/1/0.jpg", expires=3600)
    assert url == "https://example.com/signed"
    rustfs_client._mock_s3.generate_presigned_url.assert_called_once_with(
        "get_object",
        Params={"Bucket": "source-data", "Key": "crops/1/0.jpg"},
        ExpiresIn=3600,
    )


def test_get_object_size(rustfs_client):
    rustfs_client._mock_s3.head_object.return_value = {"ContentLength": 1024 * 1024 * 50}
    size = asyncio.run(rustfs_client.get_object_size("source-data", "video/1.mp4"))
    assert size == 1024 * 1024 * 50
    rustfs_client._mock_s3.head_object.assert_called_once_with(
        Bucket="source-data", Key="video/1.mp4"
    )
```

- [ ] **Step 2: Run; confirm they fail**

```bash
conda run -n label pytest tests/test_storage_client.py -v
```

Expected: `ImportError`

- [ ] **Step 3: Implement `app/clients/storage/base.py`**

```python
from abc import ABC, abstractmethod


class StorageClient(ABC):
    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object from storage and return its bytes."""

    @abstractmethod
    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        """Upload bytes to object storage."""

    @abstractmethod
    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Generate a presigned access URL."""

    @abstractmethod
    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return the object size in bytes, for pre-download size checks."""
```

- [ ] **Step 4: Implement `app/clients/storage/rustfs_client.py`**

```python
import asyncio
import boto3
from app.clients.storage.base import StorageClient


class RustFSClient(StorageClient):
    def __init__(self, endpoint: str, access_key: str, secret_key: str):
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
        )
        return resp["Body"].read()

    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._s3.put_object(
                Bucket=bucket, Key=path, Body=data, ContentType=content_type
            ),
        )

    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": bucket, "Key": path},
            ExpiresIn=expires,
        )

    async def get_object_size(self, bucket: str, path: str) -> int:
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
        )
        return resp["ContentLength"]
```
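A sketch of how the video service (built in a later task) might use `get_object_size()` to gate the background task — the fix list at the top calls for checking size before scheduling. `FakeStorage` and `check_video_size` are hypothetical names for illustration only:

```python
import asyncio


class FakeStorage:
    """Stand-in for StorageClient; returns a fixed object size."""
    def __init__(self, size: int):
        self._size = size

    async def get_object_size(self, bucket: str, path: str) -> int:
        return self._size


async def check_video_size(storage, bucket: str, path: str, max_mb: int) -> bool:
    # HEAD the object first; only schedule the background task if it fits.
    size = await storage.get_object_size(bucket, path)
    return size <= max_mb * 1024 * 1024


ok = asyncio.run(check_video_size(FakeStorage(50 * 1024 * 1024), "source-data", "video/1.mp4", 200))
too_big = asyncio.run(check_video_size(FakeStorage(300 * 1024 * 1024), "source-data", "video/1.mp4", 200))
assert ok and not too_big
```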

- [ ] **Step 5: Run again; confirm they pass**

```bash
conda run -n label pytest tests/test_storage_client.py -v
```

Expected: `4 passed`

- [ ] **Step 6: Commit**

```bash
git add app/clients/storage/ tests/test_storage_client.py
git commit -m "feat: StorageClient ABC and RustFS S3 implementation"
```

---
## Task 6: Dependency Injection + FastAPI App Entry

**Files:**
- Create: `app/core/dependencies.py`
- Create: `app/main.py`

- [ ] **Step 1: Implement `app/core/dependencies.py`**

```python
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient

_llm_client: LLMClient | None = None
_storage_client: StorageClient | None = None


def set_clients(llm: LLMClient, storage: StorageClient) -> None:
    global _llm_client, _storage_client
    _llm_client, _storage_client = llm, storage


def get_llm_client() -> LLMClient:
    assert _llm_client is not None, "set_clients() has not been called"
    return _llm_client


def get_storage_client() -> StorageClient:
    assert _storage_client is not None, "set_clients() has not been called"
    return _storage_client
```

- [ ] **Step 2: Implement `app/main.py`**

Note: the routers are created in later tasks. Keep the include_router lines commented out for now and uncomment each one as its router lands.

```python
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.config import get_config
from app.core.dependencies import set_clients
from app.core.exceptions import (
    LLMCallError,
    LLMResponseParseError,
    StorageDownloadError,
    UnsupportedFileTypeError,
    generic_error_handler,
    llm_call_handler,
    llm_parse_handler,
    storage_download_handler,
    unsupported_file_type_handler,
)
from app.core.logging import request_logging_middleware, setup_logging
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.rustfs_client import RustFSClient


@asynccontextmanager
async def lifespan(app: FastAPI):
    cfg = get_config()
    setup_logging(cfg["server"]["log_level"])
    set_clients(
        llm=ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"]),
        storage=RustFSClient(
            endpoint=cfg["storage"]["endpoint"],
            access_key=cfg["storage"]["access_key"],
            secret_key=cfg["storage"]["secret_key"],
        ),
    )
    logging.getLogger("startup").info("AI service started")
    yield
    logging.getLogger("startup").info("AI service shutting down")


app = FastAPI(title="Label AI Service", version="1.0.0", lifespan=lifespan)

app.middleware("http")(request_logging_middleware)


@app.get("/health", tags=["Health"])
async def health():
    return {"status": "ok"}


app.add_exception_handler(UnsupportedFileTypeError, unsupported_file_type_handler)
app.add_exception_handler(StorageDownloadError, storage_download_handler)
app.add_exception_handler(LLMResponseParseError, llm_parse_handler)
app.add_exception_handler(LLMCallError, llm_call_handler)
app.add_exception_handler(Exception, generic_error_handler)

# Routers registered after each task:
# from app.routers import text, image, video, qa, finetune
# app.include_router(text.router, prefix="/api/v1")
# app.include_router(image.router, prefix="/api/v1")
# app.include_router(video.router, prefix="/api/v1")
# app.include_router(qa.router, prefix="/api/v1")
# app.include_router(finetune.router, prefix="/api/v1")
```

- [ ] **Step 3: Verify the /health endpoint**

```bash
conda run -n label python -c "
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
r = client.get('/health')
assert r.status_code == 200 and r.json() == {'status': 'ok'}, r.json()
print('health check OK')
"
```

Expected: `health check OK`
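The Docker healthcheck from the fix list could invoke a tiny probe script against this endpoint. A sketch, assuming the default port from config.yaml; the file name `healthcheck.py` and the HEALTHCHECK wiring are assumptions, and stdlib `urllib` is used so the probe needs no extra imports inside the image:

```python
# Hypothetical healthcheck.py, runnable as e.g.
#   HEALTHCHECK CMD python healthcheck.py
import sys
import urllib.request


def probe(url: str = "http://localhost:8000/health", timeout: float = 3.0) -> int:
    # Exit code 0 = healthy, 1 = unhealthy (Docker's convention)
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return 0 if resp.status == 200 else 1
    except Exception:
        return 1


if __name__ == "__main__":
    sys.exit(probe())
```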

- [ ] **Step 4: Commit**

```bash
git add app/core/dependencies.py app/main.py
git commit -m "feat: DI dependencies, FastAPI app entry with lifespan and /health endpoint"
```

---
## Task 7: Text Pydantic Models

**Files:**
- Create: `app/models/text_models.py`

- [ ] **Step 1: Implement `app/models/text_models.py`**

```python
from pydantic import BaseModel


class SourceOffset(BaseModel):
    start: int
    end: int


class TripleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str
    source_offset: SourceOffset


class TextExtractRequest(BaseModel):
    file_path: str
    file_name: str
    model: str | None = None
    prompt_template: str | None = None


class TextExtractResponse(BaseModel):
    items: list[TripleItem]
```

- [ ] **Step 2: Quick schema validation**

```bash
conda run -n label python -c "
from app.models.text_models import TextExtractRequest, TextExtractResponse, TripleItem, SourceOffset
req = TextExtractRequest(file_path='text/1.txt', file_name='1.txt')
item = TripleItem(subject='A', predicate='B', object='C', source_snippet='ABC', source_offset=SourceOffset(start=0, end=3))
resp = TextExtractResponse(items=[item])
print(resp.model_dump())
"
```

Expected: prints the full dict with no errors

- [ ] **Step 3: Commit**

```bash
git add app/models/text_models.py
git commit -m "feat: text Pydantic models"
```

---
## Task 8: Text Service
|
||
|
||
**Files:**
|
||
- Create: `app/services/text_service.py`
|
||
- Create: `tests/test_text_service.py`
|
||
|
||
- [ ] **Step 1: 编写失败测试**
|
||
|
||
`tests/test_text_service.py`:
|
||
|
||
```python
|
||
import pytest
|
||
from app.services.text_service import extract_triples, _extract_text_from_bytes
|
||
from app.core.exceptions import UnsupportedFileTypeError, LLMResponseParseError, StorageDownloadError
|
||
|
||
TRIPLE_JSON = '[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压为110kV","source_offset":{"start":0,"end":10}}]'
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extract_triples_txt(mock_llm, mock_storage):
|
||
mock_storage.download_bytes.return_value = b"变压器额定电压为110kV"
|
||
mock_llm.chat.return_value = TRIPLE_JSON
|
||
|
||
result = await extract_triples(
|
||
file_path="text/1.txt",
|
||
file_name="test.txt",
|
||
model="glm-4-flash",
|
||
prompt_template="提取三元组:",
|
||
llm=mock_llm,
|
||
storage=mock_storage,
|
||
)
|
||
assert len(result) == 1
|
||
assert result[0].subject == "变压器"
|
||
assert result[0].predicate == "额定电压"
|
||
assert result[0].object == "110kV"
|
||
assert result[0].source_offset.start == 0
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extract_triples_markdown_wrapped_json(mock_llm, mock_storage):
|
||
mock_storage.download_bytes.return_value = b"some text"
|
||
mock_llm.chat.return_value = f"```json\n{TRIPLE_JSON}\n```"
|
||
|
||
result = await extract_triples(
|
||
file_path="text/1.txt",
|
||
file_name="test.txt",
|
||
model="glm-4-flash",
|
||
prompt_template="",
|
||
llm=mock_llm,
|
||
storage=mock_storage,
|
||
)
|
||
assert len(result) == 1
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extract_triples_storage_error(mock_llm, mock_storage):
|
||
mock_storage.download_bytes.side_effect = Exception("connection refused")
|
||
|
||
with pytest.raises(StorageDownloadError):
|
||
await extract_triples(
|
||
file_path="text/1.txt",
|
||
file_name="test.txt",
|
||
model="glm-4-flash",
|
||
prompt_template="",
|
||
llm=mock_llm,
|
||
storage=mock_storage,
|
||
)
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extract_triples_llm_parse_error(mock_llm, mock_storage):
|
||
mock_storage.download_bytes.return_value = b"some text"
|
||
mock_llm.chat.return_value = "这不是JSON"
|
||
|
||
with pytest.raises(LLMResponseParseError):
|
||
await extract_triples(
|
||
file_path="text/1.txt",
|
||
file_name="test.txt",
|
||
model="glm-4-flash",
|
||
prompt_template="",
|
||
llm=mock_llm,
|
||
storage=mock_storage,
|
||
)
|
||
|
||
|
||
def test_unsupported_file_type_raises():
|
||
with pytest.raises(UnsupportedFileTypeError):
|
||
_extract_text_from_bytes(b"content", "doc.xlsx")
|
||
|
||
|
||
def test_parse_txt_bytes():
|
||
result = _extract_text_from_bytes("你好世界".encode("utf-8"), "file.txt")
|
||
assert result == "你好世界"
|
||
```
|
||
|
||
- [ ] **Step 2: 运行,确认失败**

```bash
conda run -n label pytest tests/test_text_service.py -v
```

Expected: `ImportError`

- [ ] **Step 3: 实现 `app/services/text_service.py`**

```python
import logging
from pathlib import Path

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError, UnsupportedFileTypeError
from app.core.json_utils import parse_json_response
from app.models.text_models import SourceOffset, TripleItem

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """请从以下文本中提取知识三元组。
对每个三元组提供:
- subject:主语实体
- predicate:谓语关系
- object:宾语实体
- source_snippet:原文中的证据片段(直接引用原文)
- source_offset:证据片段字符偏移 {"start": N, "end": M}

以 JSON 数组格式返回,例如:
[{"subject":"...","predicate":"...","object":"...","source_snippet":"...","source_offset":{"start":0,"end":50}}]

文本内容:
"""


def _parse_txt(data: bytes) -> str:
    return data.decode("utf-8")


def _parse_pdf(data: bytes) -> str:
    import io
    import pdfplumber
    with pdfplumber.open(io.BytesIO(data)) as pdf:
        return "\n".join(page.extract_text() or "" for page in pdf.pages)


def _parse_docx(data: bytes) -> str:
    import io
    import docx
    doc = docx.Document(io.BytesIO(data))
    return "\n".join(p.text for p in doc.paragraphs if p.text.strip())


_PARSERS = {
    ".txt": _parse_txt,
    ".pdf": _parse_pdf,
    ".docx": _parse_docx,
}


def _extract_text_from_bytes(data: bytes, filename: str) -> str:
    ext = Path(filename).suffix.lower()
    parser = _PARSERS.get(ext)
    if parser is None:
        raise UnsupportedFileTypeError(ext)
    return parser(data)


async def extract_triples(
    file_path: str,
    file_name: str,
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    bucket: str = "source-data",
) -> list[TripleItem]:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        raise StorageDownloadError(f"下载文件失败 {file_path}: {e}") from e

    text = _extract_text_from_bytes(data, file_name)
    prompt = prompt_template or DEFAULT_PROMPT

    messages = [
        {"role": "system", "content": "你是专业的知识图谱构建助手,擅长从文本中提取结构化知识三元组。"},
        {"role": "user", "content": prompt + text},
    ]

    try:
        raw = await llm.chat(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM 调用失败: {e}") from e

    logger.info(f"text_extract file={file_path} model={model}")

    items_raw = parse_json_response(raw)

    result = []
    for item in items_raw:
        try:
            offset = item.get("source_offset", {})
            result.append(TripleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                source_snippet=item.get("source_snippet", ""),
                source_offset=SourceOffset(
                    start=offset.get("start", 0),
                    end=offset.get("end", 0),
                ),
            ))
        except (KeyError, TypeError) as e:
            logger.warning(f"跳过不完整三元组: {item}, error: {e}")

    return result
```

- [ ] **Step 4: 运行,确认通过**

```bash
conda run -n label pytest tests/test_text_service.py -v
```

Expected: `6 passed`

- [ ] **Step 5: Commit**

```bash
git add app/services/text_service.py tests/test_text_service.py
git commit -m "feat: text service with txt/pdf/docx parsing and triple extraction"
```

---

## Task 9: Text Router

**Files:**
- Create: `app/routers/text.py`
- Create: `tests/test_text_router.py`

- [ ] **Step 1: 编写失败测试**

`tests/test_text_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from app.main import app
from app.core.dependencies import set_clients
from app.models.text_models import TripleItem, SourceOffset


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_text_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value="变压器额定电压110kV".encode("utf-8"))
    mock_llm.chat = AsyncMock(return_value='[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压110kV","source_offset":{"start":3,"end":10}}]')

    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.txt",
        "file_name": "规范.txt",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["subject"] == "变压器"


def test_text_extract_unsupported_file(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=b"content")
    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.xlsx",
        "file_name": "file.xlsx",
    })
    assert resp.status_code == 400
    assert resp.json()["code"] == "UNSUPPORTED_FILE_TYPE"
```

- [ ] **Step 2: 实现 `app/routers/text.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.text_models import TextExtractRequest, TextExtractResponse
from app.services import text_service

router = APIRouter(tags=["Text"])


@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]
    prompt = req.prompt_template or text_service.DEFAULT_PROMPT

    items = await text_service.extract_triples(
        file_path=req.file_path,
        file_name=req.file_name,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return TextExtractResponse(items=items)
```

- [ ] **Step 3: 在 `app/main.py` 注册路由**

取消注释以下两行:

```python
from app.routers import text
app.include_router(text.router, prefix="/api/v1")
```

- [ ] **Step 4: 运行测试**

```bash
conda run -n label pytest tests/test_text_router.py -v
```

Expected: `2 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/text.py tests/test_text_router.py app/main.py
git commit -m "feat: text router POST /api/v1/text/extract"
```

---

## Task 10: Image Models + Service

**Files:**
- Create: `app/models/image_models.py`
- Create: `app/services/image_service.py`
- Create: `tests/test_image_service.py`

- [ ] **Step 1: 实现 `app/models/image_models.py`**

```python
from pydantic import BaseModel


class BBox(BaseModel):
    x: int
    y: int
    w: int
    h: int


class QuadrupleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    bbox: BBox
    cropped_image_path: str


class ImageExtractRequest(BaseModel):
    file_path: str
    task_id: int
    model: str | None = None
    prompt_template: str | None = None


class ImageExtractResponse(BaseModel):
    items: list[QuadrupleItem]
```

- [ ] **Step 2: 编写失败测试**

`tests/test_image_service.py`:

```python
import pytest
import numpy as np
import cv2
from app.services.image_service import extract_quadruples, _crop_image
from app.models.image_models import BBox
from app.core.exceptions import LLMResponseParseError, StorageDownloadError

QUAD_JSON = '[{"subject":"电缆接头","predicate":"位于","object":"配电箱左侧","qualifier":"2024年","bbox":{"x":10,"y":20,"w":50,"h":40}}]'


def _make_test_image_bytes(width=200, height=200) -> bytes:
    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[:] = (100, 150, 200)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


def test_crop_image():
    img_bytes = _make_test_image_bytes(200, 200)
    bbox = BBox(x=10, y=20, w=50, h=40)
    result = _crop_image(img_bytes, bbox)
    assert isinstance(result, bytes)
    arr = np.frombuffer(result, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    assert img.shape[0] == 40  # height
    assert img.shape[1] == 50  # width


@pytest.mark.asyncio
async def test_extract_quadruples_success(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = QUAD_JSON
    mock_storage.upload_bytes.return_value = None

    result = await extract_quadruples(
        file_path="image/202404/1.jpg",
        task_id=789,
        model="glm-4v-flash",
        prompt_template="提取四元组",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1
    assert result[0].subject == "电缆接头"
    assert result[0].cropped_image_path == "crops/789/0.jpg"
    mock_storage.upload_bytes.assert_called_once()


@pytest.mark.asyncio
async def test_extract_quadruples_storage_error(mock_llm, mock_storage):
    mock_storage.download_bytes.side_effect = Exception("timeout")
    with pytest.raises(StorageDownloadError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


@pytest.mark.asyncio
async def test_extract_quadruples_parse_error(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = "不是JSON"
    with pytest.raises(LLMResponseParseError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )
```

- [ ] **Step 3: 运行,确认失败**

```bash
conda run -n label pytest tests/test_image_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: 实现 `app/services/image_service.py`**

```python
import base64
import logging
from pathlib import Path

import cv2
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError
from app.core.json_utils import parse_json_response
from app.models.image_models import BBox, QuadrupleItem

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """请分析这张图片,提取知识四元组。
对每个四元组提供:
- subject:主体实体
- predicate:关系/属性
- object:客体实体
- qualifier:修饰信息(时间、条件、场景,无则填空字符串)
- bbox:边界框 {"x": N, "y": N, "w": N, "h": N}(像素坐标,相对原图)

以 JSON 数组格式返回:
[{"subject":"...","predicate":"...","object":"...","qualifier":"...","bbox":{"x":0,"y":0,"w":100,"h":100}}]
"""


def _crop_image(image_bytes: bytes, bbox: BBox) -> bytes:
    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    x = max(0, bbox.x)
    y = max(0, bbox.y)
    x2 = min(w, bbox.x + bbox.w)
    y2 = min(h, bbox.y + bbox.h)
    cropped = img[y:y2, x:x2]
    _, buf = cv2.imencode(".jpg", cropped, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return buf.tobytes()


async def extract_quadruples(
    file_path: str,
    task_id: int,
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    source_bucket: str = "source-data",
) -> list[QuadrupleItem]:
    try:
        data = await storage.download_bytes(source_bucket, file_path)
    except Exception as e:
        raise StorageDownloadError(f"下载图片失败 {file_path}: {e}") from e

    ext = Path(file_path).suffix.lstrip(".") or "jpeg"
    b64 = base64.b64encode(data).decode()

    messages = [
        {"role": "system", "content": "你是专业的视觉分析助手,擅长从图像中提取结构化知识四元组。"},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}},
            {"type": "text", "text": prompt_template or DEFAULT_PROMPT},
        ]},
    ]

    try:
        raw = await llm.chat_vision(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM-4V 调用失败: {e}") from e

    logger.info(f"image_extract file={file_path} model={model}")
    items_raw = parse_json_response(raw)

    result = []
    for i, item in enumerate(items_raw):
        try:
            bbox = BBox(**item["bbox"])
            cropped = _crop_image(data, bbox)
            crop_path = f"crops/{task_id}/{i}.jpg"
            await storage.upload_bytes(source_bucket, crop_path, cropped, "image/jpeg")
            result.append(QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier", ""),
                bbox=bbox,
                cropped_image_path=crop_path,
            ))
        except Exception as e:
            logger.warning(f"跳过不完整四元组 index={i}: {e}")

    return result
```
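
补充说明(非计划步骤):`_crop_image` 对越界 bbox 的处理核心是 `max`/`min` 钳制。下面用纯 Python(不依赖 OpenCV)演示这段钳制逻辑;`clamp_bbox` 为示意用的假设性函数名,与服务中的内联写法等价。

```python
def clamp_bbox(x: int, y: int, w: int, h: int, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    """把 bbox 钳制到图像范围内,返回 (x1, y1, x2, y2)。"""
    x1 = max(0, x)
    y1 = max(0, y)
    x2 = min(img_w, x + w)
    y2 = min(img_h, y + h)
    return x1, y1, x2, y2


# 正常框:完全落在 200x200 图像内
assert clamp_bbox(10, 20, 50, 40, 200, 200) == (10, 20, 60, 60)
# 越界框:右下角超出图像,被钳制到边界
assert clamp_bbox(180, 190, 50, 40, 200, 200) == (180, 190, 200, 200)
```

注意钳制后可能得到零面积区域(例如 bbox 完全在图像外),生产代码可在上传裁剪图前校验 `x2 > x1 and y2 > y1`。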

- [ ] **Step 5: 运行,确认通过**

```bash
conda run -n label pytest tests/test_image_service.py -v
```

Expected: `4 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/image_models.py app/services/image_service.py tests/test_image_service.py
git commit -m "feat: image models, service with bbox crop and quadruple extraction"
```

---

## Task 11: Image Router

**Files:**
- Create: `app/routers/image.py`
- Create: `tests/test_image_router.py`

- [ ] **Step 1: 编写失败测试**

`tests/test_image_router.py`:

```python
import numpy as np
import cv2
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


def _make_image_bytes() -> bytes:
    img = np.zeros((100, 100, 3), dtype=np.uint8)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_image_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=_make_image_bytes())
    mock_storage.upload_bytes = AsyncMock(return_value=None)
    mock_llm.chat_vision = AsyncMock(return_value='[{"subject":"A","predicate":"B","object":"C","qualifier":"","bbox":{"x":0,"y":0,"w":10,"h":10}}]')

    resp = client.post("/api/v1/image/extract", json={
        "file_path": "image/202404/1.jpg",
        "task_id": 42,
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["cropped_image_path"] == "crops/42/0.jpg"
```

- [ ] **Step 2: 实现 `app/routers/image.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or image_service.DEFAULT_PROMPT

    items = await image_service.extract_quadruples(
        file_path=req.file_path,
        task_id=req.task_id,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        source_bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageExtractResponse(items=items)
```

- [ ] **Step 3: 在 `app/main.py` 注册路由**

```python
from app.routers import text, image
app.include_router(image.router, prefix="/api/v1")
```

- [ ] **Step 4: 运行测试**

```bash
conda run -n label pytest tests/test_image_router.py -v
```

Expected: `1 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/image.py tests/test_image_router.py app/main.py
git commit -m "feat: image router POST /api/v1/image/extract"
```

---

## Task 12: Video Models + Service

**Files:**
- Create: `app/models/video_models.py`
- Create: `app/services/video_service.py`
- Create: `tests/test_video_service.py`

- [ ] **Step 1: 实现 `app/models/video_models.py`**

```python
from pydantic import BaseModel


class ExtractFramesRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    mode: str = "interval"  # interval | keyframe
    frame_interval: int = 30


class ExtractFramesResponse(BaseModel):
    message: str
    job_id: int


class FrameInfo(BaseModel):
    frame_index: int
    time_sec: float
    frame_path: str


class VideoToTextRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    start_sec: float = 0.0
    end_sec: float
    model: str | None = None
    prompt_template: str | None = None


class VideoToTextResponse(BaseModel):
    message: str
    job_id: int


class VideoJobCallback(BaseModel):
    job_id: int
    status: str  # SUCCESS | FAILED
    frames: list[FrameInfo] | None = None
    output_path: str | None = None
    error_message: str | None = None
```
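
补充说明(非计划步骤):回调体最终以 JSON POST 给 Java 后端。下面用标准库示意成功/失败两种 payload 的形状,方便后端对齐反序列化字段;这里的手写 dict 仅作示意,字段与上面 `VideoJobCallback` 经 `model_dump()` 序列化后的形状对应,未设置的可选字段为 `null`。

```python
import json

# 成功回调:携带帧列表,output_path / error_message 为 null
success = {
    "job_id": 42,
    "status": "SUCCESS",
    "frames": [{"frame_index": 0, "time_sec": 0.0, "frame_path": "frames/10/0.jpg"}],
    "output_path": None,
    "error_message": None,
}

# 失败回调:只携带错误信息
failed = {
    "job_id": 42,
    "status": "FAILED",
    "frames": None,
    "output_path": None,
    "error_message": "storage error",
}

payload = json.loads(json.dumps(success))  # 模拟一次 JSON 往返
assert payload["status"] == "SUCCESS"
assert payload["frames"][0]["frame_path"] == "frames/10/0.jpg"
```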

- [ ] **Step 2: 编写失败测试**

`tests/test_video_service.py`:

```python
import numpy as np
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from app.services.video_service import _is_scene_change, extract_frames_background


def test_is_scene_change_different_frames():
    prev = np.zeros((100, 100), dtype=np.uint8)
    curr = np.full((100, 100), 200, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is True


def test_is_scene_change_similar_frames():
    prev = np.full((100, 100), 100, dtype=np.uint8)
    curr = np.full((100, 100), 101, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is False


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_success(mock_storage):
    import cv2
    import tempfile, os

    # 创建一个有效的真实测试视频(5帧,10x10)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        tmp_path = f.name

    out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"mp4v"), 10, (10, 10))
    for _ in range(5):
        out.write(np.zeros((10, 10, 3), dtype=np.uint8))
    out.release()

    with open(tmp_path, "rb") as f:
        video_bytes = f.read()
    os.unlink(tmp_path)

    mock_storage.download_bytes.return_value = video_bytes
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=42,
            mode="interval",
            frame_interval=1,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

    mock_client.post.assert_called_once()
    call_kwargs = mock_client.post.call_args
    payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
    assert payload["job_id"] == 42
    assert payload["status"] == "SUCCESS"


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_failure(mock_storage):
    mock_storage.download_bytes.side_effect = Exception("storage error")

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=99,
            mode="interval",
            frame_interval=30,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

    mock_client.post.assert_called_once()
    call_kwargs = mock_client.post.call_args
    payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
    assert payload["status"] == "FAILED"
    assert payload["job_id"] == 99
```

- [ ] **Step 3: 运行,确认失败**

```bash
conda run -n label pytest tests/test_video_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: 实现 `app/services/video_service.py`**

```python
import base64
import logging
import tempfile
import time
from pathlib import Path

import cv2
import httpx
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError
from app.models.video_models import FrameInfo, VideoJobCallback

logger = logging.getLogger(__name__)

DEFAULT_VIDEO_TO_TEXT_PROMPT = """请分析这段视频的帧序列,用中文详细描述:
1. 视频中出现的主要对象、设备、人物
2. 发生的主要动作、操作步骤
3. 场景的整体情况

请输出结构化的文字描述,适合作为知识图谱构建的文本素材。"""


def _is_scene_change(prev: np.ndarray, curr: np.ndarray, threshold: float = 30.0) -> bool:
    """通过帧差分均值判断是否发生场景切换。"""
    diff = cv2.absdiff(prev, curr)
    return float(diff.mean()) > threshold


def _extract_frames(
    video_path: str, mode: str, frame_interval: int
) -> list[tuple[int, float, bytes]]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    results = []
    prev_gray = None
    idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        time_sec = idx / fps
        if mode == "interval":
            if idx % frame_interval == 0:
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is None or _is_scene_change(prev_gray, gray):
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
            prev_gray = gray
        idx += 1

    cap.release()
    return results


def _sample_frames_as_base64(
    video_path: str, start_sec: float, end_sec: float, count: int
) -> list[str]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    start_frame = int(start_sec * fps)
    end_frame = int(end_sec * fps)
    total = max(1, end_frame - start_frame)
    step = max(1, total // count)
    results = []
    for i in range(count):
        frame_pos = start_frame + i * step
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
        ret, frame = cap.read()
        if ret:
            _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
            results.append(base64.b64encode(buf.tobytes()).decode())
    cap.release()
    return results


async def _send_callback(url: str, payload: VideoJobCallback) -> None:
    async with httpx.AsyncClient(timeout=10) as client:
        try:
            await client.post(url, json=payload.model_dump())
        except Exception as e:
            logger.warning(f"回调失败 url={url}: {e}")


async def extract_frames_background(
    file_path: str,
    source_id: int,
    job_id: int,
    mode: str,
    frame_interval: int,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames = _extract_frames(tmp_path, mode, frame_interval)
        frame_infos = []
        for i, (frame_idx, time_sec, frame_data) in enumerate(frames):
            frame_path = f"frames/{source_id}/{i}.jpg"
            await storage.upload_bytes(bucket, frame_path, frame_data, "image/jpeg")
            frame_infos.append(FrameInfo(
                frame_index=frame_idx,
                time_sec=round(time_sec, 3),
                frame_path=frame_path,
            ))
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="SUCCESS", frames=frame_infos
        ))
        logger.info(f"extract_frames job_id={job_id} frames={len(frame_infos)}")
    except Exception as e:
        logger.exception(f"extract_frames failed job_id={job_id}")
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
    finally:
        Path(tmp_path).unlink(missing_ok=True)


async def video_to_text_background(
    file_path: str,
    source_id: int,
    job_id: int,
    start_sec: float,
    end_sec: float,
    model: str,
    prompt_template: str,
    frame_sample_count: int,
    llm: LLMClient,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames_b64 = _sample_frames_as_base64(tmp_path, start_sec, end_sec, frame_sample_count)
        content: list = []
        for b64 in frames_b64:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})
        content.append({
            "type": "text",
            "text": f"以上是视频第{start_sec}秒至{end_sec}秒的均匀采样帧。\n{prompt_template}",
        })

        messages = [
            {"role": "system", "content": "你是专业的视频内容分析助手。"},
            {"role": "user", "content": content},
        ]

        try:
            description = await llm.chat_vision(messages, model)
        except Exception as e:
            raise LLMCallError(f"GLM-4V 调用失败: {e}") from e

        timestamp = int(time.time())
        output_path = f"video-text/{source_id}/{timestamp}.txt"
        await storage.upload_bytes(bucket, output_path, description.encode("utf-8"), "text/plain")

        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="SUCCESS", output_path=output_path
        ))
        logger.info(f"video_to_text job_id={job_id} output={output_path}")
    except Exception as e:
        logger.exception(f"video_to_text failed job_id={job_id}")
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
    finally:
        Path(tmp_path).unlink(missing_ok=True)
```
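
补充说明(非计划步骤):`_is_scene_change` 的判定可以脱离 OpenCV 验证——对 uint8 灰度帧,`cv2.absdiff` 等价于先提升到有符号类型再取绝对差。下面用 NumPy 复现该判定(示意,`is_scene_change` 为本示例自定义的等价函数),两组断言对应上面测试里的期望结果。

```python
import numpy as np


def is_scene_change(prev: np.ndarray, curr: np.ndarray, threshold: float = 30.0) -> bool:
    """帧差绝对值的均值超过阈值即认为场景切换(cv2.absdiff 的 NumPy 等价写法)。"""
    diff = np.abs(prev.astype(np.int16) - curr.astype(np.int16))
    return float(diff.mean()) > threshold


prev = np.zeros((100, 100), dtype=np.uint8)
curr = np.full((100, 100), 200, dtype=np.uint8)
assert is_scene_change(prev, curr) is True   # 均值差 200 > 阈值 30

prev = np.full((100, 100), 100, dtype=np.uint8)
curr = np.full((100, 100), 101, dtype=np.uint8)
assert is_scene_change(prev, curr) is False  # 均值差 1 < 阈值 30
```

注意先转 `int16` 再相减:若直接对 uint8 相减会发生无符号回绕,得到错误的差值。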

- [ ] **Step 5: 运行,确认通过**

```bash
conda run -n label pytest tests/test_video_service.py -v
```

Expected: `4 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/video_models.py app/services/video_service.py tests/test_video_service.py
git commit -m "feat: video models and service with frame extraction and video-to-text"
```

---

## Task 13: Video Router

**Files:**
- Create: `app/routers/video.py`
- Create: `tests/test_video_router.py`

- [ ] **Step 1: 编写失败测试**

`tests/test_video_router.py`:
```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_extract_frames_returns_202(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10MB
    resp = client.post("/api/v1/video/extract-frames", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 42,
        "mode": "interval",
        "frame_interval": 30,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 42
    assert "后台处理中" in resp.json()["message"]


def test_video_to_text_returns_202(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10MB
    resp = client.post("/api/v1/video/to-text", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 43,
        "start_sec": 0,
        "end_sec": 60,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 43


def test_extract_frames_rejects_oversized_video(client, mock_storage):
    mock_storage.get_object_size = AsyncMock(return_value=300 * 1024 * 1024)  # 300MB > 200MB limit
    resp = client.post("/api/v1/video/extract-frames", json={
        "file_path": "video/202404/big.mp4",
        "source_id": 10,
        "job_id": 99,
        "mode": "interval",
        "frame_interval": 30,
    })
    assert resp.status_code == 400
    assert "大小" in resp.json()["detail"]
```
|
||
|
||
- [ ] **Step 2: 实现 `app/routers/video.py`**
|
||
|
||
```python
|
||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||
|
||
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.video_models import (
    ExtractFramesRequest,
    ExtractFramesResponse,
    VideoToTextRequest,
    VideoToTextResponse,
)
from app.services import video_service

router = APIRouter(tags=["Video"])


async def _check_video_size(storage: StorageClient, bucket: str, file_path: str, max_mb: int) -> None:
    """Validate the video file size before scheduling the background task; raise HTTP 400 when over the limit."""
    size_bytes = await storage.get_object_size(bucket, file_path)
    if size_bytes > max_mb * 1024 * 1024:
        raise HTTPException(
            status_code=400,
            detail=f"视频文件大小超出限制(最大 {max_mb}MB,当前 {size_bytes // 1024 // 1024}MB)",
        )


@router.post("/video/extract-frames", response_model=ExtractFramesResponse, status_code=202)
async def extract_frames(
    req: ExtractFramesRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    await _check_video_size(storage, bucket, req.file_path, cfg["video"]["max_file_size_mb"])
    background_tasks.add_task(
        video_service.extract_frames_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        mode=req.mode,
        frame_interval=req.frame_interval,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=bucket,
    )
    return ExtractFramesResponse(message="任务已接受,后台处理中", job_id=req.job_id)


@router.post("/video/to-text", response_model=VideoToTextResponse, status_code=202)
async def video_to_text(
    req: VideoToTextRequest,
    background_tasks: BackgroundTasks,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    await _check_video_size(storage, bucket, req.file_path, cfg["video"]["max_file_size_mb"])
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or video_service.DEFAULT_VIDEO_TO_TEXT_PROMPT
    background_tasks.add_task(
        video_service.video_to_text_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        start_sec=req.start_sec,
        end_sec=req.end_sec,
        model=model,
        prompt_template=prompt,
        frame_sample_count=cfg["video"]["frame_sample_count"],
        llm=llm,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=bucket,
    )
    return VideoToTextResponse(message="任务已接受,后台处理中", job_id=req.job_id)
```
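
The size check uses a strict greater-than comparison, so a file exactly at the configured limit is still accepted; only files over the limit are rejected. A standalone sketch of that boundary behavior (the 200 MB value mirrors `video.max_file_size_mb` from config.yaml; the function name is illustrative):

```python
def exceeds_limit(size_bytes: int, max_mb: int) -> bool:
    # Same comparison as _check_video_size: strictly greater than max_mb in bytes
    return size_bytes > max_mb * 1024 * 1024


limit_mb = 200  # mirrors video.max_file_size_mb
print(exceeds_limit(limit_mb * 1024 * 1024, limit_mb))      # False: exactly at the limit passes
print(exceeds_limit(limit_mb * 1024 * 1024 + 1, limit_mb))  # True: one byte over is rejected
```
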

- [ ] **Step 3: Register the router in `app/main.py`**

```python
from app.routers import text, image, video
app.include_router(video.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_video_router.py -v
```

Expected: `3 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/video.py tests/test_video_router.py app/main.py
git commit -m "feat: video router POST /api/v1/video/extract-frames and /to-text"
```

---

## Task 14: QA Models + Service

**Files:**
- Create: `app/models/qa_models.py`
- Create: `app/services/qa_service.py`
- Create: `tests/test_qa_service.py`

- [ ] **Step 1: Implement `app/models/qa_models.py`**

```python
from pydantic import BaseModel


class TextTripleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str


class TextQARequest(BaseModel):
    items: list[TextTripleForQA]
    model: str | None = None
    prompt_template: str | None = None


class QAPair(BaseModel):
    question: str
    answer: str


class TextQAResponse(BaseModel):
    pairs: list[QAPair]


class ImageQuadrupleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    cropped_image_path: str


class ImageQARequest(BaseModel):
    items: list[ImageQuadrupleForQA]
    model: str | None = None
    prompt_template: str | None = None


class ImageQAPair(BaseModel):
    question: str
    answer: str
    image_path: str


class ImageQAResponse(BaseModel):
    pairs: list[ImageQAPair]
```

- [ ] **Step 2: Write failing tests**

`tests/test_qa_service.py`:

```python
import pytest
from app.services.qa_service import gen_text_qa, gen_image_qa, _parse_qa_pairs
from app.models.qa_models import TextTripleForQA, ImageQuadrupleForQA
from app.core.exceptions import LLMResponseParseError, LLMCallError

QA_JSON = '[{"question":"变压器额定电压是多少?","answer":"110kV"}]'


def test_parse_qa_pairs_plain_json():
    result = _parse_qa_pairs(QA_JSON)
    assert len(result) == 1
    assert result[0].question == "变压器额定电压是多少?"


def test_parse_qa_pairs_markdown_wrapped():
    result = _parse_qa_pairs(f"```json\n{QA_JSON}\n```")
    assert len(result) == 1


def test_parse_qa_pairs_invalid_raises():
    with pytest.raises(LLMResponseParseError):
        _parse_qa_pairs("这不是JSON")


@pytest.mark.asyncio
async def test_gen_text_qa(mock_llm):
    mock_llm.chat.return_value = QA_JSON
    items = [TextTripleForQA(subject="变压器", predicate="额定电压", object="110kV", source_snippet="额定电压为110kV")]

    result = await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)
    assert len(result) == 1
    assert result[0].answer == "110kV"


@pytest.mark.asyncio
async def test_gen_text_qa_llm_error(mock_llm):
    mock_llm.chat.side_effect = Exception("network error")
    items = [TextTripleForQA(subject="A", predicate="B", object="C", source_snippet="ABC")]

    with pytest.raises(LLMCallError):
        await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)


@pytest.mark.asyncio
async def test_gen_image_qa(mock_llm, mock_storage):
    mock_llm.chat_vision.return_value = '[{"question":"图中是什么?","answer":"电缆接头"}]'
    mock_storage.download_bytes.return_value = b"fake-image-bytes"
    items = [ImageQuadrupleForQA(
        subject="电缆接头", predicate="位于", object="配电箱", qualifier="", cropped_image_path="crops/1/0.jpg"
    )]

    result = await gen_image_qa(items=items, model="glm-4v-flash", prompt_template="", llm=mock_llm, storage=mock_storage)
    assert len(result) == 1
    assert result[0].image_path == "crops/1/0.jpg"
    # Verify the service uses download_bytes (base64) rather than a presigned URL
    mock_storage.download_bytes.assert_called_once_with("source-data", "crops/1/0.jpg")
    # Verify the message sent to GLM-4V carries a base64 data URL
    call_messages = mock_llm.chat_vision.call_args[0][0]
    image_content = call_messages[1]["content"][0]
    assert image_content["image_url"]["url"].startswith("data:image/jpeg;base64,")
```

- [ ] **Step 3: Run the tests; confirm they fail**

```bash
conda run -n label pytest tests/test_qa_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/qa_service.py`**

```python
import base64
import json
import logging

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError
from app.core.json_utils import parse_json_response
from app.models.qa_models import (
    ImageQAPair,
    ImageQuadrupleForQA,
    QAPair,
    TextTripleForQA,
)

logger = logging.getLogger(__name__)

DEFAULT_TEXT_QA_PROMPT = """基于以下知识三元组和原文证据片段,生成高质量问答对。
要求:
1. 问题自然、具体,不能过于宽泛
2. 答案基于原文片段,语言流畅
3. 每个三元组生成1-2个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

三元组数据:
"""

DEFAULT_IMAGE_QA_PROMPT = """基于图片内容和以下四元组信息,生成高质量图文问答对。
要求:
1. 问题需要结合图片才能回答
2. 答案基于图片中的实际内容
3. 每个四元组生成1个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

四元组信息:
"""


def _parse_qa_pairs(raw: str) -> list[QAPair]:
    items_raw = parse_json_response(raw)
    result = []
    for item in items_raw:
        try:
            result.append(QAPair(question=item["question"], answer=item["answer"]))
        except (KeyError, TypeError) as e:
            # TypeError covers non-dict array elements returned by the LLM
            logger.warning(f"跳过不完整问答对: {item}, error: {e}")
    return result


async def gen_text_qa(
    items: list[TextTripleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
) -> list[QAPair]:
    triples_text = json.dumps([i.model_dump() for i in items], ensure_ascii=False, indent=2)
    messages = [
        {"role": "system", "content": "你是专业的知识问答对生成助手。"},
        {"role": "user", "content": (prompt_template or DEFAULT_TEXT_QA_PROMPT) + triples_text},
    ]
    try:
        raw = await llm.chat(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM 调用失败: {e}") from e
    logger.info(f"gen_text_qa model={model} items={len(items)}")
    return _parse_qa_pairs(raw)


async def gen_image_qa(
    items: list[ImageQuadrupleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    bucket: str = "source-data",
) -> list[ImageQAPair]:
    result = []
    prompt = prompt_template or DEFAULT_IMAGE_QA_PROMPT
    for item in items:
        # Download the cropped image and base64-encode it: RustFS is deployed on the
        # internal network, so a presigned URL would be unreachable from the cloud-hosted GLM-4V.
        try:
            image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)
        except Exception as e:
            raise StorageDownloadError(f"下载裁剪图失败 {item.cropped_image_path}: {e}") from e
        b64 = base64.b64encode(image_bytes).decode()
        quad_text = json.dumps(
            {k: v for k, v in item.model_dump().items() if k != "cropped_image_path"},
            ensure_ascii=False,
        )
        messages = [
            {"role": "system", "content": "你是专业的视觉问答对生成助手。"},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
                {"type": "text", "text": prompt + quad_text},
            ]},
        ]
        try:
            raw = await llm.chat_vision(messages, model)
        except Exception as e:
            raise LLMCallError(f"GLM-4V 调用失败: {e}") from e
        for pair in _parse_qa_pairs(raw):
            result.append(ImageQAPair(question=pair.question, answer=pair.answer, image_path=item.cropped_image_path))
    logger.info(f"gen_image_qa model={model} items={len(items)} pairs={len(result)}")
    return result
```
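
`parse_json_response` lives in `app.core.json_utils` (implemented in an earlier task). The behavior the tests rely on, accepting both plain JSON and JSON wrapped in a markdown fence, can be sketched roughly as follows. This is illustrative only: the real helper raises `LLMResponseParseError`, while this standalone version raises `ValueError`.

```python
import json
import re


def parse_json_response_sketch(raw: str):
    """Strip an optional markdown code fence around the payload, then parse as JSON."""
    text = raw.strip()
    # Accept both a bare payload and one wrapped in ```json ... ``` (or plain ``` ... ```)
    match = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
    if match:
        text = match.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise ValueError(f"not valid JSON: {e}") from e


qa_json = '[{"question":"额定电压?","answer":"110kV"}]'
print(parse_json_response_sketch(qa_json))                     # parsed list of dicts
print(parse_json_response_sketch(f"```json\n{qa_json}\n```"))  # same result
```
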

- [ ] **Step 5: Run the tests; confirm they pass**

```bash
conda run -n label pytest tests/test_qa_service.py -v
```

Expected: `6 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/qa_models.py app/services/qa_service.py tests/test_qa_service.py
git commit -m "feat: QA models and service for text and image QA generation"
```

---

## Task 15: QA Router

**Files:**
- Create: `app/routers/qa.py`
- Create: `tests/test_qa_router.py`

- [ ] **Step 1: Write failing tests**

`tests/test_qa_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_gen_text_qa_success(client, mock_llm):
    mock_llm.chat = AsyncMock(return_value='[{"question":"额定电压?","answer":"110kV"}]')
    resp = client.post("/api/v1/qa/gen-text", json={
        "items": [{"subject": "变压器", "predicate": "额定电压", "object": "110kV", "source_snippet": "额定电压为110kV"}],
    })
    assert resp.status_code == 200
    assert resp.json()["pairs"][0]["question"] == "额定电压?"


def test_gen_image_qa_success(client, mock_llm, mock_storage):
    mock_llm.chat_vision = AsyncMock(return_value='[{"question":"图中是什么?","answer":"接头"}]')
    # The service downloads bytes and base64-encodes them (RustFS is internal),
    # so mock download_bytes rather than get_presigned_url.
    mock_storage.download_bytes = AsyncMock(return_value=b"fake-image-bytes")
    resp = client.post("/api/v1/qa/gen-image", json={
        "items": [{"subject": "A", "predicate": "B", "object": "C", "qualifier": "", "cropped_image_path": "crops/1/0.jpg"}],
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["pairs"][0]["image_path"] == "crops/1/0.jpg"
```

- [ ] **Step 2: Implement `app/routers/qa.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.qa_models import ImageQARequest, ImageQAResponse, TextQARequest, TextQAResponse
from app.services import qa_service

router = APIRouter(tags=["QA"])


@router.post("/qa/gen-text", response_model=TextQAResponse)
async def gen_text_qa(
    req: TextQARequest,
    llm: LLMClient = Depends(get_llm_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_text_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_text"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_TEXT_QA_PROMPT,
        llm=llm,
    )
    return TextQAResponse(pairs=pairs)


@router.post("/qa/gen-image", response_model=ImageQAResponse)
async def gen_image_qa(
    req: ImageQARequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_image_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_vision"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_IMAGE_QA_PROMPT,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageQAResponse(pairs=pairs)
```

- [ ] **Step 3: Register the router in `app/main.py`**

```python
from app.routers import text, image, video, qa
app.include_router(qa.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_qa_router.py -v
```

Expected: `2 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/qa.py tests/test_qa_router.py app/main.py
git commit -m "feat: QA router POST /api/v1/qa/gen-text and /gen-image"
```

---

## Task 16: Finetune Models + Service + Router

**Files:**
- Create: `app/models/finetune_models.py`
- Create: `app/services/finetune_service.py`
- Create: `app/routers/finetune.py`
- Create: `tests/test_finetune_service.py`
- Create: `tests/test_finetune_router.py`

- [ ] **Step 1: Implement `app/models/finetune_models.py`**

```python
from pydantic import BaseModel


class FinetuneHyperparams(BaseModel):
    learning_rate: float = 1e-4
    epochs: int = 3


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: FinetuneHyperparams = FinetuneHyperparams()


class FinetuneStartResponse(BaseModel):
    job_id: str


class FinetuneStatusResponse(BaseModel):
    job_id: str
    status: str  # RUNNING | SUCCESS | FAILED
    progress: int | None = None
    error_message: str | None = None
```
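
Note the nested default: Pydantic copies field defaults per instance, so sharing a `FinetuneHyperparams()` instance as the default is safe, and a request that omits `hyperparams` entirely (or partially) gets the defaults filled in. A standalone sketch with the models inlined:

```python
from pydantic import BaseModel


class FinetuneHyperparams(BaseModel):
    learning_rate: float = 1e-4
    epochs: int = 3


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: FinetuneHyperparams = FinetuneHyperparams()


# hyperparams omitted entirely -> nested defaults apply
req = FinetuneStartRequest.model_validate(
    {"jsonl_url": "https://example.com/export.jsonl", "base_model": "glm-4-flash"}
)
print(req.hyperparams.epochs)         # 3
print(req.hyperparams.learning_rate)  # 0.0001

# hyperparams given partially -> the remaining field still defaults
req2 = FinetuneStartRequest.model_validate(
    {"jsonl_url": "u", "base_model": "m", "hyperparams": {"epochs": 5}}
)
print(req2.hyperparams.learning_rate)  # 0.0001
```
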

- [ ] **Step 2: Write failing tests**

`tests/test_finetune_service.py`:

```python
import pytest
from unittest.mock import MagicMock
from app.services.finetune_service import start_finetune, get_finetune_status
from app.models.finetune_models import FinetuneHyperparams


@pytest.mark.asyncio
async def test_start_finetune():
    mock_job = MagicMock()
    mock_job.id = "glm-ft-abc123"
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.create.return_value = mock_job

    result = await start_finetune(
        jsonl_url="https://example.com/export.jsonl",
        base_model="glm-4-flash",
        hyperparams=FinetuneHyperparams(learning_rate=1e-4, epochs=3),
        client=mock_zhipuai,
    )
    assert result == "glm-ft-abc123"
    mock_zhipuai.fine_tuning.jobs.create.assert_called_once()


@pytest.mark.asyncio
async def test_get_finetune_status_running():
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 50
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "RUNNING"
    assert result.progress == 50
    assert result.job_id == "glm-ft-abc123"


@pytest.mark.asyncio
async def test_get_finetune_status_success():
    mock_job = MagicMock()
    mock_job.status = "succeeded"
    mock_job.progress = 100
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "SUCCESS"
```

- [ ] **Step 3: Run the tests; confirm they fail**

```bash
conda run -n label pytest tests/test_finetune_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/finetune_service.py`**

```python
import logging

from app.models.finetune_models import FinetuneHyperparams, FinetuneStatusResponse

logger = logging.getLogger(__name__)

_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


async def start_finetune(
    jsonl_url: str,
    base_model: str,
    hyperparams: FinetuneHyperparams,
    client,  # ZhipuAI SDK client instance
) -> str:
    job = client.fine_tuning.jobs.create(
        training_file=jsonl_url,
        model=base_model,
        hyperparameters={
            "learning_rate_multiplier": hyperparams.learning_rate,
            "n_epochs": hyperparams.epochs,
        },
    )
    logger.info(f"finetune_start job_id={job.id} model={base_model}")
    return job.id


async def get_finetune_status(job_id: str, client) -> FinetuneStatusResponse:
    job = client.fine_tuning.jobs.retrieve(job_id)
    status = _STATUS_MAP.get(job.status, "RUNNING")
    return FinetuneStatusResponse(
        job_id=job_id,
        status=status,
        progress=getattr(job, "progress", None),
        error_message=getattr(job, "error", None),
    )
```
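
The `.get(..., "RUNNING")` fallback means any SDK status not listed in `_STATUS_MAP` is surfaced to the Java backend as RUNNING rather than an error, so the backend simply keeps polling through intermediate states. (The exact intermediate status strings depend on the ZhipuAI SDK version; `"validating_files"` below is only an illustrative unknown value.) A sketch of the mapping behavior:

```python
_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


def map_status(sdk_status: str) -> str:
    # Unknown or intermediate SDK states collapse to RUNNING so callers keep polling
    return _STATUS_MAP.get(sdk_status, "RUNNING")


print(map_status("succeeded"))         # SUCCESS
print(map_status("validating_files"))  # RUNNING (not in the map)
```
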

- [ ] **Step 5: Run the tests; confirm they pass**

```bash
conda run -n label pytest tests/test_finetune_service.py -v
```

Expected: `3 passed`

- [ ] **Step 6: Implement `app/routers/finetune.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.dependencies import get_llm_client
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)
from app.services import finetune_service

router = APIRouter(tags=["Finetune"])


def _get_zhipuai(llm: LLMClient = Depends(get_llm_client)) -> ZhipuAIClient:
    if not isinstance(llm, ZhipuAIClient):
        raise RuntimeError("微调功能仅支持 ZhipuAI 后端")
    return llm


@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    job_id = await finetune_service.start_finetune(
        jsonl_url=req.jsonl_url,
        base_model=req.base_model,
        hyperparams=req.hyperparams,
        client=llm._client,
    )
    return FinetuneStartResponse(job_id=job_id)


@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_finetune_status(
    job_id: str,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    return await finetune_service.get_finetune_status(job_id, llm._client)
```

- [ ] **Step 7: Write the router tests**

`tests/test_finetune_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from app.main import app
from app.core.dependencies import set_clients
from app.clients.llm.zhipuai_client import ZhipuAIClient


@pytest.fixture
def client(mock_storage):
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        llm = ZhipuAIClient(api_key="test-key")
        set_clients(llm, mock_storage)
        yield TestClient(app), mock_sdk


def test_start_finetune(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.id = "glm-ft-xyz"
    mock_sdk.fine_tuning.jobs.create.return_value = mock_job

    resp = test_client.post("/api/v1/finetune/start", json={
        "jsonl_url": "https://example.com/export.jsonl",
        "base_model": "glm-4-flash",
        "hyperparams": {"learning_rate": 1e-4, "epochs": 3},
    })
    assert resp.status_code == 200
    assert resp.json()["job_id"] == "glm-ft-xyz"


def test_get_finetune_status(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 30
    mock_job.error = None
    mock_sdk.fine_tuning.jobs.retrieve.return_value = mock_job

    resp = test_client.get("/api/v1/finetune/status/glm-ft-xyz")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "RUNNING"
    assert data["progress"] == 30
```

- [ ] **Step 8: Register routers in `app/main.py` (final state)**

```python
from app.routers import text, image, video, qa, finetune

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")
```

- [ ] **Step 9: Run the full test suite**

```bash
conda run -n label pytest tests/ -v
```

Expected: all tests pass, no failures

- [ ] **Step 10: Commit**

```bash
git add app/models/finetune_models.py app/services/finetune_service.py app/routers/finetune.py tests/test_finetune_service.py tests/test_finetune_router.py app/main.py
git commit -m "feat: finetune models, service, and router - complete all endpoints"
```

---

## Task 17: Deployment files

**Files:**
- Create: `Dockerfile`
- Create: `docker-compose.yml`

- [ ] **Step 1: Create the `Dockerfile`**

```dockerfile
FROM python:3.12-slim

WORKDIR /app

# System dependencies: OpenCV runtime libs, plus curl for the compose healthcheck
# (python:3.12-slim does not ship curl)
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app/ ./app/
COPY config.yaml .
COPY .env .

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
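
Alternatively, since `python:3.12-slim` does not ship curl by default, the healthcheck can be baked into the image using the Python interpreter that is already there (a sketch; it assumes the `GET /health` endpoint added in the earlier task):

```dockerfile
# Curl-free healthcheck defined in the image instead of docker-compose.yml
HEALTHCHECK --interval=30s --timeout=5s --retries=3 --start-period=10s \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
```

If you use this form, drop the `healthcheck:` block from the compose file so the two definitions do not drift apart.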

- [ ] **Step 2: Create `docker-compose.yml`**

```yaml
version: "3.9"

services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    depends_on:
      - rustfs
    networks:
      - label-net
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s

  rustfs:
    image: minio/minio:latest
    command: server /data --console-address ":9001"
    ports:
      - "9000:9000"
      - "9001:9001"
    environment:
      MINIO_ROOT_USER: minioadmin
      MINIO_ROOT_PASSWORD: minioadmin
    volumes:
      - rustfs-data:/data
    networks:
      - label-net

volumes:
  rustfs-data:

networks:
  label-net:
    driver: bridge
```

- [ ] **Step 3: Verify the Docker build**

```bash
docker build -t label-ai-service:dev .
```

Expected: the image builds successfully with no errors

- [ ] **Step 4: Run the full test suite as a final check**

```bash
conda run -n label pytest tests/ -v --tb=short
```

Expected: all tests pass

- [ ] **Step 5: Commit**

```bash
git add Dockerfile docker-compose.yml
git commit -m "feat: Dockerfile and docker-compose for containerized deployment"
```

---

## Self-review results

**Spec coverage:**
- ✅ Text triple extraction (TXT/PDF/DOCX) — Task 8-9
- ✅ Image quadruple extraction + bbox cropping — Task 10-11
- ✅ Video frame extraction (interval/keyframe) — Task 12-13
- ✅ Video-to-text (BackgroundTask) — Task 12-13
- ✅ Text QA pair generation — Task 14-15
- ✅ Image QA pair generation — Task 14-15
- ✅ Finetune job submission and status query — Task 16
- ✅ LLMClient / StorageClient ABC adapter layer — Task 4-5
- ✅ config.yaml + .env layered configuration — Task 2
- ✅ Structured logging + request logging — Task 3
- ✅ Global exception handling — Task 3
- ✅ Swagger docs (auto-generated by FastAPI) — Task 6
- ✅ Dockerfile + docker-compose — Task 17
- ✅ pytest coverage of every service and router — individual tasks

**Type consistency:** `TripleItem.source_offset` is defined in Task 7 and used in Task 8; `VideoJobCallback` is defined in Task 12 and used by the Task 12 service — consistent.

**Placeholders:** no TBD / TODO; every step contains complete code.