# AI Service Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Implement `label_ai_service`, a Python FastAPI service that gives the knowledge-graph labeling platform text triple extraction, image quadruple extraction, video processing, QA-pair generation, and GLM fine-tuning management.

**Architecture:** Layered: routers (HTTP entry) → services (business logic) → clients (external adapters). LLMClient and StorageClient are both ABCs, currently implemented by ZhipuAIClient and RustFSClient and injected via FastAPI Depends; the services layer never sees a concrete implementation. Video jobs run asynchronously via FastAPI BackgroundTasks and call back to the Java backend on completion.

**Tech Stack:** Python 3.12 (conda `label` env), FastAPI, ZhipuAI SDK, boto3 (S3), OpenCV, pdfplumber, python-docx, httpx, pytest

---

## Task 1: Project Scaffold

**Files:**

- Create: `app/__init__.py`
- Create: `app/core/__init__.py`
- Create: `app/clients/__init__.py`
- Create: `app/clients/llm/__init__.py`
- Create: `app/clients/storage/__init__.py`
- Create: `app/services/__init__.py`
- Create: `app/routers/__init__.py`
- Create: `app/models/__init__.py`
- Create: `tests/__init__.py`
- Create: `tests/conftest.py`
- Create: `config.yaml`
- Create: `.env`
- Create: `requirements.txt`

- [ ] **Step 1: Create the package directory structure**

```bash
mkdir -p app/core app/clients/llm app/clients/storage app/services app/routers app/models tests
touch app/__init__.py app/core/__init__.py
touch app/clients/__init__.py app/clients/llm/__init__.py app/clients/storage/__init__.py
touch app/services/__init__.py app/routers/__init__.py app/models/__init__.py
touch tests/__init__.py
```

- [ ] **Step 2: Create `config.yaml`**

```yaml
server:
  port: 8000
  log_level: INFO

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}

video:
  frame_sample_count: 8

models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
```

- [ ] **Step 3: Create `.env`**

```ini
ZHIPUAI_API_KEY=your-zhipuai-api-key
STORAGE_ACCESS_KEY=minioadmin
STORAGE_SECRET_KEY=minioadmin
STORAGE_ENDPOINT=http://rustfs:9000
BACKEND_CALLBACK_URL=http://backend:8080/internal/video-job/callback
```

- [ ] **Step 4: Create `requirements.txt`**

```
fastapi>=0.111
uvicorn[standard]>=0.29
pydantic>=2.7
python-dotenv>=1.0
pyyaml>=6.0
zhipuai>=2.1
boto3>=1.34
pdfplumber>=0.11
python-docx>=1.1
opencv-python-headless>=4.9
numpy>=1.26
httpx>=0.27
pytest>=8.0
pytest-asyncio>=0.23
```

- [ ] **Step 5: Create `tests/conftest.py`**

```python
import pytest
from unittest.mock import AsyncMock, MagicMock

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient


@pytest.fixture
def mock_llm():
    client = MagicMock(spec=LLMClient)
    client.chat = AsyncMock()
    client.chat_vision = AsyncMock()
    return client


@pytest.fixture
def mock_storage():
    client = MagicMock(spec=StorageClient)
    client.download_bytes = AsyncMock()
    client.upload_bytes = AsyncMock()
    client.get_presigned_url = MagicMock(return_value="https://example.com/presigned/crop.jpg")
    return client
```
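
The fixtures above return `MagicMock` objects whose async methods are `AsyncMock` instances. As a self-contained sketch (illustrative names only, no pytest machinery), this is how a service-level test drives such a mock:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Build a mock shaped like the mock_llm fixture above
client = MagicMock()
client.chat = AsyncMock(return_value='[{"subject": "A"}]')

# Awaiting the AsyncMock returns the configured value
result = asyncio.run(client.chat(messages=[], model="glm-4-flash"))
assert result == '[{"subject": "A"}]'
client.chat.assert_awaited_once()
```

Because the mocks are built with `spec=LLMClient` / `spec=StorageClient`, calling a method that does not exist on the ABC fails immediately, which keeps tests honest against the interface.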

- [ ] **Step 6: Install dependencies**

```bash
conda run -n label pip install -r requirements.txt
```

Expected: all packages install successfully, no errors

- [ ] **Step 7: Commit**

```bash
git add app/ tests/ config.yaml .env requirements.txt
git commit -m "feat: project scaffold - directory structure and config files"
```

---

## Task 2: Core Config Module

**Files:**

- Create: `app/core/config.py`
- Create: `tests/test_config.py`

- [ ] **Step 1: Write the failing tests**

`tests/test_config.py`:

```python
import pytest
from unittest.mock import patch, mock_open

from app.core.config import get_config

MOCK_YAML = """
server:
  port: 8000
  log_level: INFO
storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"
backend: {}
video:
  frame_sample_count: 8
models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
"""


def _fresh_config(monkeypatch, extra_env: dict | None = None):
    """Clear the lru_cache and set baseline env vars before each test."""
    get_config.cache_clear()
    base_env = {
        "ZHIPUAI_API_KEY": "test-key",
        "STORAGE_ACCESS_KEY": "test-access",
        "STORAGE_SECRET_KEY": "test-secret",
        "STORAGE_ENDPOINT": "http://localhost:9000",
        "BACKEND_CALLBACK_URL": "http://localhost:8080/callback",
    }
    if extra_env:
        base_env.update(extra_env)
    for k, v in base_env.items():
        monkeypatch.setenv(k, v)


def test_env_overrides_yaml(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
            assert cfg["zhipuai"]["api_key"] == "test-key"
            assert cfg["storage"]["access_key"] == "test-access"
            assert cfg["storage"]["endpoint"] == "http://localhost:9000"
            assert cfg["backend"]["callback_url"] == "http://localhost:8080/callback"
    get_config.cache_clear()


def test_yaml_values_preserved(monkeypatch):
    _fresh_config(monkeypatch)
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            cfg = get_config()
            assert cfg["models"]["default_text"] == "glm-4-flash"
            assert cfg["video"]["frame_sample_count"] == 8
            assert cfg["storage"]["buckets"]["source_data"] == "source-data"
    get_config.cache_clear()


def test_missing_api_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.delenv("ZHIPUAI_API_KEY", raising=False)
    monkeypatch.setenv("STORAGE_ACCESS_KEY", "a")
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="ZHIPUAI_API_KEY"):
                get_config()
    get_config.cache_clear()


def test_missing_storage_key_raises(monkeypatch):
    get_config.cache_clear()
    monkeypatch.setenv("ZHIPUAI_API_KEY", "key")
    monkeypatch.delenv("STORAGE_ACCESS_KEY", raising=False)
    monkeypatch.setenv("STORAGE_SECRET_KEY", "b")
    with patch("builtins.open", mock_open(read_data=MOCK_YAML)):
        with patch("app.core.config.load_dotenv"):
            with pytest.raises(RuntimeError, match="STORAGE_ACCESS_KEY"):
                get_config()
    get_config.cache_clear()
```

- [ ] **Step 2: Run the tests, confirm they fail**

```bash
conda run -n label pytest tests/test_config.py -v
```

Expected: `ImportError: cannot import name 'get_config'`

- [ ] **Step 3: Implement `app/core/config.py`**

```python
import os
from functools import lru_cache
from pathlib import Path

import yaml
from dotenv import load_dotenv

_ROOT = Path(__file__).parent.parent.parent

_ENV_OVERRIDES = {
    "ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY": ["storage", "access_key"],
    "STORAGE_SECRET_KEY": ["storage", "secret_key"],
    "STORAGE_ENDPOINT": ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL": ["backend", "callback_url"],
    "LOG_LEVEL": ["server", "log_level"],
}


def _set_nested(d: dict, keys: list[str], value: str) -> None:
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value


@lru_cache(maxsize=1)
def get_config() -> dict:
    load_dotenv(_ROOT / ".env")
    with open(_ROOT / "config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    for env_key, yaml_path in _ENV_OVERRIDES.items():
        val = os.environ.get(env_key)
        if val:
            _set_nested(cfg, yaml_path, val)
    _validate(cfg)
    return cfg


def _validate(cfg: dict) -> None:
    checks = [
        (["zhipuai", "api_key"], "ZHIPUAI_API_KEY"),
        (["storage", "access_key"], "STORAGE_ACCESS_KEY"),
        (["storage", "secret_key"], "STORAGE_SECRET_KEY"),
    ]
    for path, name in checks:
        val = cfg
        for k in path:
            val = (val or {}).get(k, "")
        if not val:
            raise RuntimeError(f"Missing required config: {name}")
```
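
The env-over-YAML layering hinges on `_set_nested` creating intermediate sections on demand. A standalone copy of that logic (illustrative, mirroring the function above) shows how an env var both overrides an existing key and fills an empty section such as `backend: {}`:

```python
# Standalone copy of the _set_nested logic, to show how one env value lands
# at a nested YAML path, creating intermediate dicts as needed
def set_nested(d: dict, keys: list[str], value: str) -> None:
    for k in keys[:-1]:
        d = d.setdefault(k, {})
    d[keys[-1]] = value

cfg = {"storage": {"endpoint": "http://rustfs:9000"}, "backend": {}}
set_nested(cfg, ["storage", "endpoint"], "http://localhost:9000")   # override
set_nested(cfg, ["backend", "callback_url"], "http://localhost:8080/callback")  # fill
assert cfg["storage"]["endpoint"] == "http://localhost:9000"
assert cfg["backend"]["callback_url"] == "http://localhost:8080/callback"
```

This is why `config.yaml` can ship `backend: {}` with no values: every secret arrives from `.env` at load time.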

- [ ] **Step 4: Run the tests, confirm they pass**

```bash
conda run -n label pytest tests/test_config.py -v
```

Expected: `4 passed`

- [ ] **Step 5: Commit**

```bash
git add app/core/config.py tests/test_config.py
git commit -m "feat: core config module with YAML + env layered loading"
```

---

## Task 3: Core Logging, Exceptions, JSON Utils

**Files:**

- Create: `app/core/logging.py`
- Create: `app/core/exceptions.py`
- Create: `app/core/json_utils.py`

- [ ] **Step 1: Implement `app/core/logging.py`**

```python
import json
import logging
import time
from typing import Callable

from fastapi import Request, Response


class _JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        entry: dict = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry, ensure_ascii=False)


def setup_logging(log_level: str = "INFO") -> None:
    handler = logging.StreamHandler()
    handler.setFormatter(_JSONFormatter())
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(getattr(logging, log_level.upper(), logging.INFO))


async def request_logging_middleware(request: Request, call_next: Callable) -> Response:
    start = time.monotonic()
    response = await call_next(request)
    duration_ms = round((time.monotonic() - start) * 1000, 2)
    logging.getLogger("api").info(
        f"method={request.method} path={request.url.path} "
        f"status={response.status_code} duration_ms={duration_ms}"
    )
    return response
```
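
To make the log shape concrete, here is a minimal self-contained reproduction of the JSON-line format produced by `_JSONFormatter` (simplified: no timestamp, and writing to an in-memory stream instead of stderr):

```python
import io
import json
import logging

# Cut-down version of _JSONFormatter: one JSON object per log line
class JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {"level": record.levelname, "logger": record.name, "message": record.getMessage()},
            ensure_ascii=False,
        )

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(JSONFormatter())
logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False  # keep the demo output out of the root logger

logger.info("hello")
entry = json.loads(stream.getvalue())
assert entry == {"level": "INFO", "logger": "demo", "message": "hello"}
```

One JSON object per line keeps the output trivially ingestible by log collectors, and `ensure_ascii=False` preserves any Chinese text in messages.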

- [ ] **Step 2: Implement `app/core/exceptions.py`**

```python
import logging

from fastapi import Request
from fastapi.responses import JSONResponse


class UnsupportedFileTypeError(Exception):
    def __init__(self, ext: str):
        super().__init__(f"Unsupported file type: {ext}")


class StorageDownloadError(Exception):
    pass


class LLMResponseParseError(Exception):
    pass


class LLMCallError(Exception):
    pass


async def unsupported_file_type_handler(request: Request, exc: UnsupportedFileTypeError):
    return JSONResponse(
        status_code=400,
        content={"code": "UNSUPPORTED_FILE_TYPE", "message": str(exc)},
    )


async def storage_download_handler(request: Request, exc: StorageDownloadError):
    return JSONResponse(
        status_code=502,
        content={"code": "STORAGE_ERROR", "message": str(exc)},
    )


async def llm_parse_handler(request: Request, exc: LLMResponseParseError):
    return JSONResponse(
        status_code=502,
        content={"code": "LLM_PARSE_ERROR", "message": str(exc)},
    )


async def llm_call_handler(request: Request, exc: LLMCallError):
    return JSONResponse(
        status_code=503,
        content={"code": "LLM_CALL_ERROR", "message": str(exc)},
    )


async def generic_error_handler(request: Request, exc: Exception):
    logging.getLogger("error").exception("Unhandled exception")
    return JSONResponse(
        status_code=500,
        content={"code": "INTERNAL_ERROR", "message": "Internal server error"},
    )
```

- [ ] **Step 3: Implement `app/core/json_utils.py`**

```python
import json

from app.core.exceptions import LLMResponseParseError


def parse_json_response(raw: str) -> list | dict:
    """Parse JSON from a GLM response, tolerating markdown code-fence wrapping."""
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    content = content.strip()
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        raise LLMResponseParseError(
            f"GLM response is not valid JSON: {raw[:200]}"
        ) from e
```
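
The fence-stripping behavior can be verified in isolation. This is a standalone copy of the logic above (renamed `strip_fences` for the sketch), exercising both the fenced and the bare case:

```python
import json

# Standalone copy of the fence-stripping step inside parse_json_response
def strip_fences(raw: str) -> str:
    content = raw.strip()
    if "```json" in content:
        content = content.split("```json")[1].split("```")[0]
    elif "```" in content:
        content = content.split("```")[1].split("```")[0]
    return content.strip()

wrapped = "```json\n[1, 2, 3]\n```"
assert json.loads(strip_fences(wrapped)) == [1, 2, 3]
# Already-bare JSON passes through untouched
assert json.loads(strip_fences("[1, 2, 3]")) == [1, 2, 3]
```

GLM models frequently wrap JSON answers in markdown fences even when asked not to, so every service that parses model output routes through this helper.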

- [ ] **Step 4: Commit**

```bash
git add app/core/logging.py app/core/exceptions.py app/core/json_utils.py
git commit -m "feat: core logging, exceptions, json utils"
```

---

## Task 4: LLM Adapter Layer

**Files:**

- Create: `app/clients/llm/base.py`
- Create: `app/clients/llm/zhipuai_client.py`
- Create: `tests/test_llm_client.py`

- [ ] **Step 1: Write the failing tests**

`tests/test_llm_client.py`:

```python
import asyncio

import pytest
from unittest.mock import MagicMock, patch

from app.clients.llm.zhipuai_client import ZhipuAIClient


@pytest.fixture
def zhipuai_client():
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        client = ZhipuAIClient(api_key="test-key")
        client._mock_sdk = mock_sdk
        yield client


def test_chat_returns_content(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "三元组提取结果"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat(
            messages=[{"role": "user", "content": "提取三元组"}],
            model="glm-4-flash",
        )
    )
    assert result == "三元组提取结果"
    zhipuai_client._mock_sdk.chat.completions.create.assert_called_once()


def test_chat_vision_calls_same_endpoint(zhipuai_client):
    mock_resp = MagicMock()
    mock_resp.choices[0].message.content = "图像分析结果"
    zhipuai_client._mock_sdk.chat.completions.create.return_value = mock_resp

    result = asyncio.run(
        zhipuai_client.chat_vision(
            messages=[{"role": "user", "content": [{"type": "text", "text": "分析"}]}],
            model="glm-4v-flash",
        )
    )
    assert result == "图像分析结果"
```

- [ ] **Step 2: Run the tests, confirm they fail**

```bash
conda run -n label pytest tests/test_llm_client.py -v
```

Expected: `ImportError`

- [ ] **Step 3: Implement `app/clients/llm/base.py`**

```python
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        """Text-only chat; returns the model's output text."""

    @abstractmethod
    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        """Multimodal chat (mixed image/text input); returns the model's output text."""
```

- [ ] **Step 4: Implement `app/clients/llm/zhipuai_client.py`**

```python
import asyncio

from zhipuai import ZhipuAI

from app.clients.llm.base import LLMClient


class ZhipuAIClient(LLMClient):
    def __init__(self, api_key: str):
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, messages: list[dict], model: str, **kwargs) -> str:
        # The ZhipuAI SDK is synchronous; run it in a worker thread so the
        # event loop is not blocked
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None,
            lambda: self._client.chat.completions.create(
                model=model, messages=messages, **kwargs
            ),
        )
        return resp.choices[0].message.content

    async def chat_vision(self, messages: list[dict], model: str, **kwargs) -> str:
        # GLM-4V uses the same completions endpoint as text chat; image parts
        # are distinguished by the image_url content type in the message payload
        return await self.chat(messages, model, **kwargs)
```
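
The `run_in_executor` pattern above is worth seeing in isolation: a blocking call awaited from a coroutine without stalling the loop. This sketch uses a stand-in blocking function (illustrative names only):

```python
import asyncio

# Stand-in for a blocking SDK call such as chat.completions.create
def blocking_call(x: int) -> int:
    return x * 2

async def chat_like(x: int) -> int:
    # Same pattern as ZhipuAIClient.chat: offload the sync call to the
    # default thread-pool executor and await the result
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: blocking_call(x))

assert asyncio.run(chat_like(21)) == 42
```

`asyncio.get_running_loop()` (rather than the older `get_event_loop()`) is the correct call inside a coroutine; on Python 3.12 the latter is deprecated outside a running loop.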

- [ ] **Step 5: Run the tests, confirm they pass**

```bash
conda run -n label pytest tests/test_llm_client.py -v
```

Expected: `2 passed`

- [ ] **Step 6: Commit**

```bash
git add app/clients/llm/ tests/test_llm_client.py
git commit -m "feat: LLMClient ABC and ZhipuAI implementation"
```

---

## Task 5: Storage Adapter Layer

**Files:**

- Create: `app/clients/storage/base.py`
- Create: `app/clients/storage/rustfs_client.py`
- Create: `tests/test_storage_client.py`

- [ ] **Step 1: Write the failing tests**

`tests/test_storage_client.py`:

```python
import asyncio

import pytest
from unittest.mock import MagicMock, patch

from app.clients.storage.rustfs_client import RustFSClient


@pytest.fixture
def rustfs_client():
    with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
        mock_s3 = MagicMock()
        mock_boto3.client.return_value = mock_s3
        client = RustFSClient(
            endpoint="http://localhost:9000",
            access_key="minioadmin",
            secret_key="minioadmin",
        )
        client._mock_s3 = mock_s3
        yield client


def test_download_bytes(rustfs_client):
    mock_body = MagicMock()
    mock_body.read.return_value = b"file content"
    rustfs_client._mock_s3.get_object.return_value = {"Body": mock_body}

    result = asyncio.run(
        rustfs_client.download_bytes("source-data", "text/202404/1.txt")
    )
    assert result == b"file content"
    rustfs_client._mock_s3.get_object.assert_called_once_with(
        Bucket="source-data", Key="text/202404/1.txt"
    )


def test_upload_bytes(rustfs_client):
    asyncio.run(
        rustfs_client.upload_bytes("source-data", "crops/1/0.jpg", b"img", "image/jpeg")
    )
    rustfs_client._mock_s3.put_object.assert_called_once_with(
        Bucket="source-data", Key="crops/1/0.jpg", Body=b"img", ContentType="image/jpeg"
    )


def test_get_presigned_url(rustfs_client):
    rustfs_client._mock_s3.generate_presigned_url.return_value = "https://example.com/signed"
    url = rustfs_client.get_presigned_url("source-data", "crops/1/0.jpg", expires=3600)
    assert url == "https://example.com/signed"
    rustfs_client._mock_s3.generate_presigned_url.assert_called_once_with(
        "get_object",
        Params={"Bucket": "source-data", "Key": "crops/1/0.jpg"},
        ExpiresIn=3600,
    )
```

- [ ] **Step 2: Run the tests, confirm they fail**

```bash
conda run -n label pytest tests/test_storage_client.py -v
```

Expected: `ImportError`

- [ ] **Step 3: Implement `app/clients/storage/base.py`**

```python
from abc import ABC, abstractmethod


class StorageClient(ABC):
    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object from storage and return its bytes."""

    @abstractmethod
    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        """Upload bytes to object storage."""

    @abstractmethod
    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Generate a presigned access URL."""
```

- [ ] **Step 4: Implement `app/clients/storage/rustfs_client.py`**

```python
import asyncio

import boto3

from app.clients.storage.base import StorageClient


class RustFSClient(StorageClient):
    def __init__(self, endpoint: str, access_key: str, secret_key: str):
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        # boto3 is synchronous; offload blocking calls to a worker thread
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
        )
        return resp["Body"].read()

    async def upload_bytes(
        self,
        bucket: str,
        path: str,
        data: bytes,
        content_type: str = "application/octet-stream",
    ) -> None:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._s3.put_object(
                Bucket=bucket, Key=path, Body=data, ContentType=content_type
            ),
        )

    def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        return self._s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": bucket, "Key": path},
            ExpiresIn=expires,
        )
```

- [ ] **Step 5: Run the tests, confirm they pass**

```bash
conda run -n label pytest tests/test_storage_client.py -v
```

Expected: `3 passed`

- [ ] **Step 6: Commit**

```bash
git add app/clients/storage/ tests/test_storage_client.py
git commit -m "feat: StorageClient ABC and RustFS S3 implementation"
```

---

## Task 6: Dependency Injection + FastAPI App Entry

**Files:**

- Create: `app/core/dependencies.py`
- Create: `app/main.py`

- [ ] **Step 1: Implement `app/core/dependencies.py`**

```python
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient

_llm_client: LLMClient | None = None
_storage_client: StorageClient | None = None


def set_clients(llm: LLMClient, storage: StorageClient) -> None:
    global _llm_client, _storage_client
    _llm_client, _storage_client = llm, storage


def get_llm_client() -> LLMClient:
    if _llm_client is None:
        raise RuntimeError("LLM client not initialized; call set_clients() at startup")
    return _llm_client


def get_storage_client() -> StorageClient:
    if _storage_client is None:
        raise RuntimeError("Storage client not initialized; call set_clients() at startup")
    return _storage_client
```
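
The pattern here is a module-level singleton: wired once at startup, read through a getter that FastAPI can use as a `Depends` target. A standalone sketch (illustrative names; the guard mirrors the fail-fast choice above) shows both halves:

```python
# Illustrative standalone sketch of the setter/getter singleton pattern:
# set once during app startup, read via a getter usable as a dependency.
_client = None

def set_client(client) -> None:
    global _client
    _client = client

def get_client():
    if _client is None:  # fail loudly if startup wiring was skipped
        raise RuntimeError("client not initialized")
    return _client

set_client("fake-llm-client")
assert get_client() == "fake-llm-client"
```

Raising instead of returning `None` turns a misconfigured startup into an immediate, readable error rather than an `AttributeError` deep inside a service call. In tests, `app.dependency_overrides` can replace these getters with the conftest mocks.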

- [ ] **Step 2: Implement `app/main.py`**

Note: the routers are created in later tasks. Keep the include_router lines commented out for now and uncomment each one as its router is implemented.

```python
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.config import get_config
from app.core.dependencies import set_clients
from app.core.exceptions import (
    LLMCallError,
    LLMResponseParseError,
    StorageDownloadError,
    UnsupportedFileTypeError,
    generic_error_handler,
    llm_call_handler,
    llm_parse_handler,
    storage_download_handler,
    unsupported_file_type_handler,
)
from app.core.logging import request_logging_middleware, setup_logging
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.rustfs_client import RustFSClient


@asynccontextmanager
async def lifespan(app: FastAPI):
    cfg = get_config()
    setup_logging(cfg["server"]["log_level"])
    set_clients(
        llm=ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"]),
        storage=RustFSClient(
            endpoint=cfg["storage"]["endpoint"],
            access_key=cfg["storage"]["access_key"],
            secret_key=cfg["storage"]["secret_key"],
        ),
    )
    logging.getLogger("startup").info("AI service started")
    yield
    logging.getLogger("startup").info("AI service shutting down")


app = FastAPI(title="Label AI Service", version="1.0.0", lifespan=lifespan)

app.middleware("http")(request_logging_middleware)

app.add_exception_handler(UnsupportedFileTypeError, unsupported_file_type_handler)
app.add_exception_handler(StorageDownloadError, storage_download_handler)
app.add_exception_handler(LLMResponseParseError, llm_parse_handler)
app.add_exception_handler(LLMCallError, llm_call_handler)
app.add_exception_handler(Exception, generic_error_handler)

# Routers registered after each task:
# from app.routers import text, image, video, qa, finetune
# app.include_router(text.router, prefix="/api/v1")
# app.include_router(image.router, prefix="/api/v1")
# app.include_router(video.router, prefix="/api/v1")
# app.include_router(qa.router, prefix="/api/v1")
# app.include_router(finetune.router, prefix="/api/v1")
```

- [ ] **Step 3: Commit**

```bash
git add app/core/dependencies.py app/main.py
git commit -m "feat: DI dependencies and FastAPI app entry with lifespan"
```

---

## Task 7: Text Pydantic Models

**Files:**

- Create: `app/models/text_models.py`

- [ ] **Step 1: Implement `app/models/text_models.py`**

```python
from pydantic import BaseModel


class SourceOffset(BaseModel):
    start: int
    end: int


class TripleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str
    source_offset: SourceOffset


class TextExtractRequest(BaseModel):
    file_path: str
    file_name: str
    model: str | None = None
    prompt_template: str | None = None


class TextExtractResponse(BaseModel):
    items: list[TripleItem]
```

- [ ] **Step 2: Quick schema validation**

```bash
conda run -n label python -c "
from app.models.text_models import TextExtractRequest, TextExtractResponse, TripleItem, SourceOffset
req = TextExtractRequest(file_path='text/1.txt', file_name='1.txt')
item = TripleItem(subject='A', predicate='B', object='C', source_snippet='ABC', source_offset=SourceOffset(start=0, end=3))
resp = TextExtractResponse(items=[item])
print(resp.model_dump())
"
```

Expected: prints the full dict, no errors

- [ ] **Step 3: Commit**

```bash
git add app/models/text_models.py
git commit -m "feat: text Pydantic models"
```

---

## Task 8: Text Service

**Files:**

- Create: `app/services/text_service.py`
- Create: `tests/test_text_service.py`

- [ ] **Step 1: Write the failing tests**

`tests/test_text_service.py`:

```python
import pytest

from app.services.text_service import extract_triples, _extract_text_from_bytes
from app.core.exceptions import UnsupportedFileTypeError, LLMResponseParseError, StorageDownloadError

TRIPLE_JSON = '[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压为110kV","source_offset":{"start":0,"end":10}}]'


@pytest.mark.asyncio
async def test_extract_triples_txt(mock_llm, mock_storage):
    # bytes literals cannot hold non-ASCII text, so encode explicitly
    mock_storage.download_bytes.return_value = "变压器额定电压为110kV".encode("utf-8")
    mock_llm.chat.return_value = TRIPLE_JSON

    result = await extract_triples(
        file_path="text/1.txt",
        file_name="test.txt",
        model="glm-4-flash",
        prompt_template="提取三元组:",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1
    assert result[0].subject == "变压器"
    assert result[0].predicate == "额定电压"
    assert result[0].object == "110kV"
    assert result[0].source_offset.start == 0


@pytest.mark.asyncio
async def test_extract_triples_markdown_wrapped_json(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = b"some text"
    mock_llm.chat.return_value = f"```json\n{TRIPLE_JSON}\n```"

    result = await extract_triples(
        file_path="text/1.txt",
        file_name="test.txt",
        model="glm-4-flash",
        prompt_template="",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1


@pytest.mark.asyncio
async def test_extract_triples_storage_error(mock_llm, mock_storage):
    mock_storage.download_bytes.side_effect = Exception("connection refused")

    with pytest.raises(StorageDownloadError):
        await extract_triples(
            file_path="text/1.txt",
            file_name="test.txt",
            model="glm-4-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


@pytest.mark.asyncio
async def test_extract_triples_llm_parse_error(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = b"some text"
    mock_llm.chat.return_value = "这不是JSON"

    with pytest.raises(LLMResponseParseError):
        await extract_triples(
            file_path="text/1.txt",
            file_name="test.txt",
            model="glm-4-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


def test_unsupported_file_type_raises():
    with pytest.raises(UnsupportedFileTypeError):
        _extract_text_from_bytes(b"content", "doc.xlsx")


def test_parse_txt_bytes():
    result = _extract_text_from_bytes("你好世界".encode("utf-8"), "file.txt")
    assert result == "你好世界"
```
|
|||
|
|
|
|||
|
|
- [ ] **Step 2: 运行,确认失败**
|
|||
|
|
|
|||
|
|
```bash
|
|||
|
|
conda run -n label pytest tests/test_text_service.py -v
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
Expected: `ImportError`
|
|||
|
|
|
|||
|
|
- [ ] **Step 3: 实现 `app/services/text_service.py`**
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
import logging
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
from app.clients.llm.base import LLMClient
|
|||
|
|
from app.clients.storage.base import StorageClient
|
|||
|
|
from app.core.exceptions import LLMCallError, LLMResponseParseError, StorageDownloadError, UnsupportedFileTypeError
|
|||
|
|
from app.core.json_utils import parse_json_response
|
|||
|
|
from app.models.text_models import SourceOffset, TripleItem
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
DEFAULT_PROMPT = """请从以下文本中提取知识三元组。
|
|||
|
|
对每个三元组提供:
|
|||
|
|
- subject:主语实体
|
|||
|
|
- predicate:谓语关系
|
|||
|
|
- object:宾语实体
|
|||
|
|
- source_snippet:原文中的证据片段(直接引用原文)
|
|||
|
|
- source_offset:证据片段字符偏移 {"start": N, "end": M}
|
|||
|
|
|
|||
|
|
以 JSON 数组格式返回,例如:
|
|||
|
|
[{"subject":"...","predicate":"...","object":"...","source_snippet":"...","source_offset":{"start":0,"end":50}}]
|
|||
|
|
|
|||
|
|
文本内容:
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_txt(data: bytes) -> str:
|
|||
|
|
return data.decode("utf-8")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_pdf(data: bytes) -> str:
|
|||
|
|
import io
|
|||
|
|
import pdfplumber
|
|||
|
|
with pdfplumber.open(io.BytesIO(data)) as pdf:
|
|||
|
|
return "\n".join(page.extract_text() or "" for page in pdf.pages)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_docx(data: bytes) -> str:
|
|||
|
|
import io
|
|||
|
|
import docx
|
|||
|
|
doc = docx.Document(io.BytesIO(data))
|
|||
|
|
return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
|
|||
|
|
|
|||
|
|
|
|||
|
|
_PARSERS = {
|
|||
|
|
".txt": _parse_txt,
|
|||
|
|
".pdf": _parse_pdf,
|
|||
|
|
".docx": _parse_docx,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_text_from_bytes(data: bytes, filename: str) -> str:
|
|||
|
|
ext = Path(filename).suffix.lower()
|
|||
|
|
parser = _PARSERS.get(ext)
|
|||
|
|
if parser is None:
|
|||
|
|
raise UnsupportedFileTypeError(ext)
|
|||
|
|
return parser(data)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def extract_triples(
|
|||
|
|
file_path: str,
|
|||
|
|
file_name: str,
|
|||
|
|
model: str,
|
|||
|
|
prompt_template: str,
|
|||
|
|
llm: LLMClient,
|
|||
|
|
storage: StorageClient,
|
|||
|
|
bucket: str = "source-data",
|
|||
|
|
) -> list[TripleItem]:
|
|||
|
|
try:
|
|||
|
|
data = await storage.download_bytes(bucket, file_path)
|
|||
|
|
except Exception as e:
|
|||
|
|
raise StorageDownloadError(f"下载文件失败 {file_path}: {e}") from e
|
|||
|
|
|
|||
|
|
text = _extract_text_from_bytes(data, file_name)
|
|||
|
|
prompt = prompt_template or DEFAULT_PROMPT
|
|||
|
|
|
|||
|
|
messages = [
|
|||
|
|
{"role": "system", "content": "你是专业的知识图谱构建助手,擅长从文本中提取结构化知识三元组。"},
|
|||
|
|
{"role": "user", "content": prompt + text},
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
raw = await llm.chat(messages, model)
|
|||
|
|
except Exception as e:
|
|||
|
|
raise LLMCallError(f"GLM 调用失败: {e}") from e
|
|||
|
|
|
|||
|
|
logger.info(f"text_extract file={file_path} model={model}")
|
|||
|
|
|
|||
|
|
items_raw = parse_json_response(raw)
|
|||
|
|
|
|||
|
|
result = []
|
|||
|
|
for item in items_raw:
|
|||
|
|
try:
|
|||
|
|
offset = item.get("source_offset", {})
|
|||
|
|
result.append(TripleItem(
|
|||
|
|
subject=item["subject"],
|
|||
|
|
predicate=item["predicate"],
|
|||
|
|
object=item["object"],
|
|||
|
|
source_snippet=item.get("source_snippet", ""),
|
|||
|
|
source_offset=SourceOffset(
|
|||
|
|
start=offset.get("start", 0),
|
|||
|
|
end=offset.get("end", 0),
|
|||
|
|
),
|
|||
|
|
))
|
|||
|
|
except (KeyError, TypeError) as e:
|
|||
|
|
logger.warning(f"跳过不完整三元组: {item}, error: {e}")
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
```
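
The markdown-wrapped-JSON test above relies on `parse_json_response` from `app/core/json_utils` (implemented in an earlier task) tolerating a ```json fence around the array. As a minimal sketch of the behavior those tests assume — names here are hypothetical, and the real helper raises `LLMResponseParseError` rather than `ValueError`:

```python
import json
import re


def parse_json_response_sketch(raw: str):
    """Strip an optional ```json fence, then json.loads the remainder."""
    text = raw.strip()
    # Match a whole-string fence like ```json\n...\n``` (language tag optional)
    m = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
    if m:
        text = m.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise ValueError(f"LLM response is not valid JSON: {e}") from e


print(parse_json_response_sketch('```json\n[{"subject": "变压器"}]\n```'))  # → [{'subject': '变压器'}]
```

Bare JSON passes through the same path unchanged, so the service never needs to know whether the model wrapped its answer.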

- [ ] **Step 4: Run and confirm pass**

```bash
conda run -n label pytest tests/test_text_service.py -v
```

Expected: `6 passed`

- [ ] **Step 5: Commit**

```bash
git add app/services/text_service.py tests/test_text_service.py
git commit -m "feat: text service with txt/pdf/docx parsing and triple extraction"
```

---

## Task 9: Text Router

**Files:**
- Create: `app/routers/text.py`
- Create: `tests/test_text_router.py`

- [ ] **Step 1: Write the failing test**

`tests/test_text_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_text_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=b"变压器额定电压110kV")
    mock_llm.chat = AsyncMock(return_value='[{"subject":"变压器","predicate":"额定电压","object":"110kV","source_snippet":"额定电压110kV","source_offset":{"start":3,"end":10}}]')

    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.txt",
        "file_name": "规范.txt",
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["subject"] == "变压器"


def test_text_extract_unsupported_file(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=b"content")
    resp = client.post("/api/v1/text/extract", json={
        "file_path": "text/202404/1.xlsx",
        "file_name": "file.xlsx",
    })
    assert resp.status_code == 400
    assert resp.json()["code"] == "UNSUPPORTED_FILE_TYPE"
```

- [ ] **Step 2: Implement `app/routers/text.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.text_models import TextExtractRequest, TextExtractResponse
from app.services import text_service

router = APIRouter(tags=["Text"])


@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]
    prompt = req.prompt_template or text_service.DEFAULT_PROMPT

    items = await text_service.extract_triples(
        file_path=req.file_path,
        file_name=req.file_name,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return TextExtractResponse(items=items)
```

- [ ] **Step 3: Register the router in `app/main.py`**

Uncomment the following two lines:

```python
from app.routers import text
app.include_router(text.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_text_router.py -v
```

Expected: `2 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/text.py tests/test_text_router.py app/main.py
git commit -m "feat: text router POST /api/v1/text/extract"
```

---

## Task 10: Image Models + Service

**Files:**
- Create: `app/models/image_models.py`
- Create: `app/services/image_service.py`
- Create: `tests/test_image_service.py`

- [ ] **Step 1: Implement `app/models/image_models.py`**

```python
from pydantic import BaseModel


class BBox(BaseModel):
    x: int
    y: int
    w: int
    h: int


class QuadrupleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    bbox: BBox
    cropped_image_path: str


class ImageExtractRequest(BaseModel):
    file_path: str
    task_id: int
    model: str | None = None
    prompt_template: str | None = None


class ImageExtractResponse(BaseModel):
    items: list[QuadrupleItem]
```

- [ ] **Step 2: Write the failing test**

`tests/test_image_service.py`:

```python
import pytest
import numpy as np
import cv2
from app.services.image_service import extract_quadruples, _crop_image
from app.models.image_models import BBox
from app.core.exceptions import LLMResponseParseError, StorageDownloadError

QUAD_JSON = '[{"subject":"电缆接头","predicate":"位于","object":"配电箱左侧","qualifier":"2024年","bbox":{"x":10,"y":20,"w":50,"h":40}}]'


def _make_test_image_bytes(width=200, height=200) -> bytes:
    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[:] = (100, 150, 200)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


def test_crop_image():
    img_bytes = _make_test_image_bytes(200, 200)
    bbox = BBox(x=10, y=20, w=50, h=40)
    result = _crop_image(img_bytes, bbox)
    assert isinstance(result, bytes)
    arr = np.frombuffer(result, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    assert img.shape[0] == 40  # height
    assert img.shape[1] == 50  # width


@pytest.mark.asyncio
async def test_extract_quadruples_success(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = QUAD_JSON
    mock_storage.upload_bytes.return_value = None

    result = await extract_quadruples(
        file_path="image/202404/1.jpg",
        task_id=789,
        model="glm-4v-flash",
        prompt_template="提取四元组",
        llm=mock_llm,
        storage=mock_storage,
    )
    assert len(result) == 1
    assert result[0].subject == "电缆接头"
    assert result[0].cropped_image_path == "crops/789/0.jpg"
    mock_storage.upload_bytes.assert_called_once()


@pytest.mark.asyncio
async def test_extract_quadruples_storage_error(mock_llm, mock_storage):
    mock_storage.download_bytes.side_effect = Exception("timeout")
    with pytest.raises(StorageDownloadError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )


@pytest.mark.asyncio
async def test_extract_quadruples_parse_error(mock_llm, mock_storage):
    mock_storage.download_bytes.return_value = _make_test_image_bytes()
    mock_llm.chat_vision.return_value = "不是JSON"
    with pytest.raises(LLMResponseParseError):
        await extract_quadruples(
            file_path="image/1.jpg",
            task_id=1,
            model="glm-4v-flash",
            prompt_template="",
            llm=mock_llm,
            storage=mock_storage,
        )
```

- [ ] **Step 3: Run and confirm failure**

```bash
conda run -n label pytest tests/test_image_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/image_service.py`**

```python
import base64
import logging
from pathlib import Path

import cv2
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, StorageDownloadError
from app.core.json_utils import parse_json_response
from app.models.image_models import BBox, QuadrupleItem

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """请分析这张图片,提取知识四元组。
对每个四元组提供:
- subject:主体实体
- predicate:关系/属性
- object:客体实体
- qualifier:修饰信息(时间、条件、场景,无则填空字符串)
- bbox:边界框 {"x": N, "y": N, "w": N, "h": N}(像素坐标,相对原图)

以 JSON 数组格式返回:
[{"subject":"...","predicate":"...","object":"...","qualifier":"...","bbox":{"x":0,"y":0,"w":100,"h":100}}]
"""


def _crop_image(image_bytes: bytes, bbox: BBox) -> bytes:
    arr = np.frombuffer(image_bytes, dtype=np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    h, w = img.shape[:2]
    x = max(0, bbox.x)
    y = max(0, bbox.y)
    x2 = min(w, bbox.x + bbox.w)
    y2 = min(h, bbox.y + bbox.h)
    cropped = img[y:y2, x:x2]
    _, buf = cv2.imencode(".jpg", cropped, [cv2.IMWRITE_JPEG_QUALITY, 90])
    return buf.tobytes()


async def extract_quadruples(
    file_path: str,
    task_id: int,
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    source_bucket: str = "source-data",
) -> list[QuadrupleItem]:
    try:
        data = await storage.download_bytes(source_bucket, file_path)
    except Exception as e:
        raise StorageDownloadError(f"下载图片失败 {file_path}: {e}") from e

    ext = Path(file_path).suffix.lstrip(".") or "jpeg"
    b64 = base64.b64encode(data).decode()

    messages = [
        {"role": "system", "content": "你是专业的视觉分析助手,擅长从图像中提取结构化知识四元组。"},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{b64}"}},
            {"type": "text", "text": prompt_template or DEFAULT_PROMPT},
        ]},
    ]

    try:
        raw = await llm.chat_vision(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM-4V 调用失败: {e}") from e

    logger.info(f"image_extract file={file_path} model={model}")
    items_raw = parse_json_response(raw)

    result = []
    for i, item in enumerate(items_raw):
        try:
            bbox = BBox(**item["bbox"])
            cropped = _crop_image(data, bbox)
            crop_path = f"crops/{task_id}/{i}.jpg"
            await storage.upload_bytes(source_bucket, crop_path, cropped, "image/jpeg")
            result.append(QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier", ""),
                bbox=bbox,
                cropped_image_path=crop_path,
            ))
        except Exception as e:
            logger.warning(f"跳过不完整四元组 index={i}: {e}")

    return result
```
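
The clamping in `_crop_image` means a model-produced bbox that spills past the image edge degrades to the visible overlap instead of raising. The index math can be checked with plain numpy (no cv2 needed; `clamp_crop` is an illustrative stand-in, not part of the service):

```python
import numpy as np


def clamp_crop(img: np.ndarray, x: int, y: int, w: int, h: int) -> np.ndarray:
    """Same index math as _crop_image: clamp the bbox to the image bounds."""
    ih, iw = img.shape[:2]
    x1, y1 = max(0, x), max(0, y)
    x2, y2 = min(iw, x + w), min(ih, y + h)
    return img[y1:y2, x1:x2]


img = np.zeros((100, 100, 3), dtype=np.uint8)
# bbox extends 30px past the right/bottom edge: only the 20x20 overlap survives
print(clamp_crop(img, 80, 80, 50, 50).shape)  # → (20, 20, 3)
```

A fully out-of-bounds bbox yields a zero-area slice, so the resulting JPEG encode would fail; that failure is swallowed by the per-item `except` in `extract_quadruples` and the quadruple is skipped.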

- [ ] **Step 5: Run and confirm pass**

```bash
conda run -n label pytest tests/test_image_service.py -v
```

Expected: `4 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/image_models.py app/services/image_service.py tests/test_image_service.py
git commit -m "feat: image models, service with bbox crop and quadruple extraction"
```

---

## Task 11: Image Router

**Files:**
- Create: `app/routers/image.py`
- Create: `tests/test_image_router.py`

- [ ] **Step 1: Write the failing test**

`tests/test_image_router.py`:

```python
import numpy as np
import cv2
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


def _make_image_bytes() -> bytes:
    img = np.zeros((100, 100, 3), dtype=np.uint8)
    _, buf = cv2.imencode(".jpg", img)
    return buf.tobytes()


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_image_extract_success(client, mock_llm, mock_storage):
    mock_storage.download_bytes = AsyncMock(return_value=_make_image_bytes())
    mock_storage.upload_bytes = AsyncMock(return_value=None)
    mock_llm.chat_vision = AsyncMock(return_value='[{"subject":"A","predicate":"B","object":"C","qualifier":"","bbox":{"x":0,"y":0,"w":10,"h":10}}]')

    resp = client.post("/api/v1/image/extract", json={
        "file_path": "image/202404/1.jpg",
        "task_id": 42,
    })
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["items"]) == 1
    assert data["items"][0]["cropped_image_path"] == "crops/42/0.jpg"
```

- [ ] **Step 2: Implement `app/routers/image.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or image_service.DEFAULT_PROMPT

    items = await image_service.extract_quadruples(
        file_path=req.file_path,
        task_id=req.task_id,
        model=model,
        prompt_template=prompt,
        llm=llm,
        storage=storage,
        source_bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageExtractResponse(items=items)
```

- [ ] **Step 3: Register the router in `app/main.py`**

```python
from app.routers import text, image
app.include_router(image.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_image_router.py -v
```

Expected: `1 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/image.py tests/test_image_router.py app/main.py
git commit -m "feat: image router POST /api/v1/image/extract"
```

---

## Task 12: Video Models + Service

**Files:**
- Create: `app/models/video_models.py`
- Create: `app/services/video_service.py`
- Create: `tests/test_video_service.py`

- [ ] **Step 1: Implement `app/models/video_models.py`**

```python
from pydantic import BaseModel


class ExtractFramesRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    mode: str = "interval"  # interval | keyframe
    frame_interval: int = 30


class ExtractFramesResponse(BaseModel):
    message: str
    job_id: int


class FrameInfo(BaseModel):
    frame_index: int
    time_sec: float
    frame_path: str


class VideoToTextRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    start_sec: float = 0.0
    end_sec: float
    model: str | None = None
    prompt_template: str | None = None


class VideoToTextResponse(BaseModel):
    message: str
    job_id: int


class VideoJobCallback(BaseModel):
    job_id: int
    status: str  # SUCCESS | FAILED
    frames: list[FrameInfo] | None = None
    output_path: str | None = None
    error_message: str | None = None
```
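
The callback body the Java backend receives is `VideoJobCallback` serialized via `model_dump()`. As a hand-written sketch of that shape (illustrative values, not captured output):

```python
# JSON posted back to the Java backend when a frame-extraction job succeeds;
# field names match VideoJobCallback, values are illustrative.
success_payload = {
    "job_id": 42,
    "status": "SUCCESS",
    "frames": [
        {"frame_index": 0, "time_sec": 0.0, "frame_path": "frames/10/0.jpg"},
        {"frame_index": 30, "time_sec": 1.2, "frame_path": "frames/10/1.jpg"},
    ],
    "output_path": None,
    "error_message": None,
}

# On any failure the same five fields are sent, with only the error populated.
failure_payload = {
    "job_id": 99,
    "status": "FAILED",
    "frames": None,
    "output_path": None,
    "error_message": "storage error",
}

print(sorted(success_payload))
```

Keeping both paths on the same schema means the Java side can deserialize every callback into one DTO and branch on `status`.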

- [ ] **Step 2: Write the failing test**

`tests/test_video_service.py`:

```python
import numpy as np
import pytest
from unittest.mock import AsyncMock, patch
from app.services.video_service import _is_scene_change, extract_frames_background


def test_is_scene_change_different_frames():
    prev = np.zeros((100, 100), dtype=np.uint8)
    curr = np.full((100, 100), 200, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is True


def test_is_scene_change_similar_frames():
    prev = np.full((100, 100), 100, dtype=np.uint8)
    curr = np.full((100, 100), 101, dtype=np.uint8)
    assert _is_scene_change(prev, curr, threshold=30.0) is False


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_success(mock_storage):
    import cv2
    import tempfile, os

    # Create a small but valid test video (5 frames, 10x10)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        tmp_path = f.name

    out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"mp4v"), 10, (10, 10))
    for _ in range(5):
        out.write(np.zeros((10, 10, 3), dtype=np.uint8))
    out.release()

    with open(tmp_path, "rb") as f:
        video_bytes = f.read()
    os.unlink(tmp_path)

    mock_storage.download_bytes.return_value = video_bytes
    mock_storage.upload_bytes = AsyncMock(return_value=None)

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=42,
            mode="interval",
            frame_interval=1,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

        mock_client.post.assert_called_once()
        call_kwargs = mock_client.post.call_args
        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
        assert payload["job_id"] == 42
        assert payload["status"] == "SUCCESS"


@pytest.mark.asyncio
async def test_extract_frames_background_calls_callback_on_failure(mock_storage):
    mock_storage.download_bytes.side_effect = Exception("storage error")

    with patch("app.services.video_service.httpx") as mock_httpx:
        mock_client = AsyncMock()
        mock_httpx.AsyncClient.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_httpx.AsyncClient.return_value.__aexit__ = AsyncMock(return_value=False)
        mock_client.post = AsyncMock()

        await extract_frames_background(
            file_path="video/1.mp4",
            source_id=10,
            job_id=99,
            mode="interval",
            frame_interval=30,
            storage=mock_storage,
            callback_url="http://backend/callback",
        )

        mock_client.post.assert_called_once()
        call_kwargs = mock_client.post.call_args
        payload = call_kwargs.kwargs.get("json") or (call_kwargs.args[1] if len(call_kwargs.args) > 1 else {})
        assert payload["status"] == "FAILED"
        assert payload["job_id"] == 99
```
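
The two `_is_scene_change` tests above can be checked by hand: the detector compares the mean absolute difference of two grayscale frames against a threshold. A numpy-only sketch of the same arithmetic (`cv2.absdiff` on uint8 arrays matches this widened absolute difference; `mean_abs_diff` is an illustrative stand-in):

```python
import numpy as np


def mean_abs_diff(prev: np.ndarray, curr: np.ndarray) -> float:
    # Widen to int16 before subtracting so uint8 arithmetic cannot wrap around
    return float(np.abs(prev.astype(np.int16) - curr.astype(np.int16)).mean())


prev = np.zeros((100, 100), dtype=np.uint8)
curr = np.full((100, 100), 200, dtype=np.uint8)
print(mean_abs_diff(prev, curr))  # → 200.0, well above the 30.0 threshold

near = np.full((100, 100), 101, dtype=np.uint8)
print(mean_abs_diff(np.full((100, 100), 100, dtype=np.uint8), near))  # → 1.0, no scene change
```

The default threshold of 30.0 is a tuning knob: lower it and keyframe mode captures more frames on gradual transitions, raise it and only hard cuts survive.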

- [ ] **Step 3: Run and confirm failure**

```bash
conda run -n label pytest tests/test_video_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/video_service.py`**

```python
import base64
import logging
import tempfile
from pathlib import Path

import cv2
import httpx
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError
from app.models.video_models import FrameInfo, VideoJobCallback

logger = logging.getLogger(__name__)

DEFAULT_VIDEO_TO_TEXT_PROMPT = """请分析这段视频的帧序列,用中文详细描述:
1. 视频中出现的主要对象、设备、人物
2. 发生的主要动作、操作步骤
3. 场景的整体情况

请输出结构化的文字描述,适合作为知识图谱构建的文本素材。"""


def _is_scene_change(prev: np.ndarray, curr: np.ndarray, threshold: float = 30.0) -> bool:
    """Detect a scene change via the mean of the absolute frame difference."""
    diff = cv2.absdiff(prev, curr)
    return float(diff.mean()) > threshold


def _extract_frames(
    video_path: str, mode: str, frame_interval: int
) -> list[tuple[int, float, bytes]]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    results = []
    prev_gray = None
    idx = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        time_sec = idx / fps
        if mode == "interval":
            if idx % frame_interval == 0:
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
        else:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            if prev_gray is None or _is_scene_change(prev_gray, gray):
                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
                results.append((idx, time_sec, buf.tobytes()))
                prev_gray = gray
        idx += 1

    cap.release()
    return results


def _sample_frames_as_base64(
    video_path: str, start_sec: float, end_sec: float, count: int
) -> list[str]:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    start_frame = int(start_sec * fps)
    end_frame = int(end_sec * fps)
    total = max(1, end_frame - start_frame)
    step = max(1, total // count)
    results = []
    for i in range(count):
        frame_pos = start_frame + i * step
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_pos)
        ret, frame = cap.read()
        if ret:
            _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
            results.append(base64.b64encode(buf.tobytes()).decode())
    cap.release()
    return results


async def _send_callback(url: str, payload: VideoJobCallback) -> None:
    async with httpx.AsyncClient(timeout=10) as client:
        try:
            await client.post(url, json=payload.model_dump())
        except Exception as e:
            logger.warning(f"回调失败 url={url}: {e}")


async def extract_frames_background(
    file_path: str,
    source_id: int,
    job_id: int,
    mode: str,
    frame_interval: int,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames = _extract_frames(tmp_path, mode, frame_interval)
        frame_infos = []
        for i, (frame_idx, time_sec, frame_data) in enumerate(frames):
            frame_path = f"frames/{source_id}/{i}.jpg"
            await storage.upload_bytes(bucket, frame_path, frame_data, "image/jpeg")
            frame_infos.append(FrameInfo(
                frame_index=frame_idx,
                time_sec=round(time_sec, 3),
                frame_path=frame_path,
            ))
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="SUCCESS", frames=frame_infos
        ))
        logger.info(f"extract_frames job_id={job_id} frames={len(frame_infos)}")
    except Exception as e:
        logger.exception(f"extract_frames failed job_id={job_id}")
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
    finally:
        Path(tmp_path).unlink(missing_ok=True)


async def video_to_text_background(
    file_path: str,
    source_id: int,
    job_id: int,
    start_sec: float,
    end_sec: float,
    model: str,
    prompt_template: str,
    frame_sample_count: int,
    llm: LLMClient,
    storage: StorageClient,
    callback_url: str,
    bucket: str = "source-data",
) -> None:
    try:
        data = await storage.download_bytes(bucket, file_path)
    except Exception as e:
        await _send_callback(callback_url, VideoJobCallback(
            job_id=job_id, status="FAILED", error_message=str(e)
        ))
        return

    suffix = Path(file_path).suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(data)
        tmp_path = tmp.name

    try:
        frames_b64 = _sample_frames_as_base64(tmp_path, start_sec, end_sec, frame_sample_count)
        content: list = []
        for b64 in frames_b64:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})
|
|||
|
|
content.append({
|
|||
|
|
"type": "text",
|
|||
|
|
"text": f"以上是视频第{start_sec}秒至{end_sec}秒的均匀采样帧。\n{prompt_template}",
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
messages = [
|
|||
|
|
{"role": "system", "content": "你是专业的视频内容分析助手。"},
|
|||
|
|
{"role": "user", "content": content},
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
description = await llm.chat_vision(messages, model)
|
|||
|
|
except Exception as e:
|
|||
|
|
raise LLMCallError(f"GLM-4V 调用失败: {e}") from e
|
|||
|
|
|
|||
|
|
timestamp = int(time.time())
|
|||
|
|
output_path = f"video-text/{source_id}/{timestamp}.txt"
|
|||
|
|
await storage.upload_bytes(bucket, output_path, description.encode("utf-8"), "text/plain")
|
|||
|
|
|
|||
|
|
await _send_callback(callback_url, VideoJobCallback(
|
|||
|
|
job_id=job_id, status="SUCCESS", output_path=output_path
|
|||
|
|
))
|
|||
|
|
logger.info(f"video_to_text job_id={job_id} output={output_path}")
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.exception(f"video_to_text failed job_id={job_id}")
|
|||
|
|
await _send_callback(callback_url, VideoJobCallback(
|
|||
|
|
job_id=job_id, status="FAILED", error_message=str(e)
|
|||
|
|
))
|
|||
|
|
finally:
|
|||
|
|
Path(tmp_path).unlink(missing_ok=True)
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
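The sampling arithmetic in `_sample_frames_as_base64` is easy to get subtly wrong (zero-length spans, `count` larger than the frame range). It can be isolated as a pure helper and unit-tested without a real video; a minimal sketch, where the helper name is illustrative and not part of the plan:

```python
def sample_positions(start_sec: float, end_sec: float, fps: float, count: int) -> list[int]:
    """Mirror of the index math above: `count` evenly spaced frame positions."""
    start_frame = int(start_sec * fps)
    end_frame = int(end_sec * fps)
    total = max(1, end_frame - start_frame)  # guard against zero-length spans
    step = max(1, total // count)            # guard against count > total frames
    return [start_frame + i * step for i in range(count)]


# 4 samples from the first 4 seconds at 25 fps
print(sample_positions(0, 4, 25.0, 4))  # → [0, 25, 50, 75]
```

Note that when the span is shorter than `count` frames, `step` clamps to 1 and consecutive frames are returned, which matches the clamping behavior in the service code above.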
- [ ] **Step 5: Run the tests; confirm they pass**

```bash
conda run -n label pytest tests/test_video_service.py -v
```

Expected: `4 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/video_models.py app/services/video_service.py tests/test_video_service.py
git commit -m "feat: video models and service with frame extraction and video-to-text"
```

---

## Task 13: Video Router

**Files:**
- Create: `app/routers/video.py`
- Create: `tests/test_video_router.py`

- [ ] **Step 1: Write failing tests**

`tests/test_video_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_extract_frames_returns_202(client):
    resp = client.post("/api/v1/video/extract-frames", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 42,
        "mode": "interval",
        "frame_interval": 30,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 42
    assert "后台处理中" in resp.json()["message"]


def test_video_to_text_returns_202(client):
    resp = client.post("/api/v1/video/to-text", json={
        "file_path": "video/202404/1.mp4",
        "source_id": 10,
        "job_id": 43,
        "start_sec": 0,
        "end_sec": 60,
    })
    assert resp.status_code == 202
    assert resp.json()["job_id"] == 43
```

- [ ] **Step 2: Implement `app/routers/video.py`**

```python
from fastapi import APIRouter, BackgroundTasks, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.video_models import (
    ExtractFramesRequest,
    ExtractFramesResponse,
    VideoToTextRequest,
    VideoToTextResponse,
)
from app.services import video_service

router = APIRouter(tags=["Video"])


@router.post("/video/extract-frames", response_model=ExtractFramesResponse, status_code=202)
async def extract_frames(
    req: ExtractFramesRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    background_tasks.add_task(
        video_service.extract_frames_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        mode=req.mode,
        frame_interval=req.frame_interval,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ExtractFramesResponse(message="任务已接受,后台处理中", job_id=req.job_id)


@router.post("/video/to-text", response_model=VideoToTextResponse, status_code=202)
async def video_to_text(
    req: VideoToTextRequest,
    background_tasks: BackgroundTasks,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    model = req.model or cfg["models"]["default_vision"]
    prompt = req.prompt_template or video_service.DEFAULT_VIDEO_TO_TEXT_PROMPT
    background_tasks.add_task(
        video_service.video_to_text_background,
        file_path=req.file_path,
        source_id=req.source_id,
        job_id=req.job_id,
        start_sec=req.start_sec,
        end_sec=req.end_sec,
        model=model,
        prompt_template=prompt,
        frame_sample_count=cfg["video"]["frame_sample_count"],
        llm=llm,
        storage=storage,
        callback_url=cfg["backend"]["callback_url"],
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return VideoToTextResponse(message="任务已接受,后台处理中", job_id=req.job_id)
```

- [ ] **Step 3: Register the router in `app/main.py`**

```python
from app.routers import text, image, video

app.include_router(video.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_video_router.py -v
```

Expected: `2 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/video.py tests/test_video_router.py app/main.py
git commit -m "feat: video router POST /api/v1/video/extract-frames and /to-text"
```

---

## Task 14: QA Models + Service

**Files:**
- Create: `app/models/qa_models.py`
- Create: `app/services/qa_service.py`
- Create: `tests/test_qa_service.py`

- [ ] **Step 1: Implement `app/models/qa_models.py`**

```python
from pydantic import BaseModel


class TextTripleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str


class TextQARequest(BaseModel):
    items: list[TextTripleForQA]
    model: str | None = None
    prompt_template: str | None = None


class QAPair(BaseModel):
    question: str
    answer: str


class TextQAResponse(BaseModel):
    pairs: list[QAPair]


class ImageQuadrupleForQA(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str
    cropped_image_path: str


class ImageQARequest(BaseModel):
    items: list[ImageQuadrupleForQA]
    model: str | None = None
    prompt_template: str | None = None


class ImageQAPair(BaseModel):
    question: str
    answer: str
    image_path: str


class ImageQAResponse(BaseModel):
    pairs: list[ImageQAPair]
```

- [ ] **Step 2: Write failing tests**

`tests/test_qa_service.py`:

```python
import pytest
from app.services.qa_service import gen_text_qa, gen_image_qa, _parse_qa_pairs
from app.models.qa_models import TextTripleForQA, ImageQuadrupleForQA
from app.core.exceptions import LLMResponseParseError, LLMCallError

QA_JSON = '[{"question":"变压器额定电压是多少?","answer":"110kV"}]'


def test_parse_qa_pairs_plain_json():
    result = _parse_qa_pairs(QA_JSON)
    assert len(result) == 1
    assert result[0].question == "变压器额定电压是多少?"


def test_parse_qa_pairs_markdown_wrapped():
    result = _parse_qa_pairs(f"```json\n{QA_JSON}\n```")
    assert len(result) == 1


def test_parse_qa_pairs_invalid_raises():
    with pytest.raises(LLMResponseParseError):
        _parse_qa_pairs("这不是JSON")


@pytest.mark.asyncio
async def test_gen_text_qa(mock_llm):
    mock_llm.chat.return_value = QA_JSON
    items = [TextTripleForQA(subject="变压器", predicate="额定电压", object="110kV", source_snippet="额定电压为110kV")]

    result = await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)
    assert len(result) == 1
    assert result[0].answer == "110kV"


@pytest.mark.asyncio
async def test_gen_text_qa_llm_error(mock_llm):
    mock_llm.chat.side_effect = Exception("network error")
    items = [TextTripleForQA(subject="A", predicate="B", object="C", source_snippet="ABC")]

    with pytest.raises(LLMCallError):
        await gen_text_qa(items=items, model="glm-4-flash", prompt_template="", llm=mock_llm)


@pytest.mark.asyncio
async def test_gen_image_qa(mock_llm, mock_storage):
    mock_llm.chat_vision.return_value = '[{"question":"图中是什么?","answer":"电缆接头"}]'
    items = [ImageQuadrupleForQA(
        subject="电缆接头", predicate="位于", object="配电箱", qualifier="", cropped_image_path="crops/1/0.jpg"
    )]

    result = await gen_image_qa(items=items, model="glm-4v-flash", prompt_template="", llm=mock_llm, storage=mock_storage)
    assert len(result) == 1
    assert result[0].image_path == "crops/1/0.jpg"
    mock_storage.get_presigned_url.assert_called_once_with("source-data", "crops/1/0.jpg")
```

- [ ] **Step 3: Run the tests; confirm they fail**

```bash
conda run -n label pytest tests/test_qa_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/qa_service.py`**

```python
import json
import logging

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.exceptions import LLMCallError, LLMResponseParseError
from app.core.json_utils import parse_json_response
from app.models.qa_models import (
    ImageQAPair,
    ImageQuadrupleForQA,
    QAPair,
    TextTripleForQA,
)

logger = logging.getLogger(__name__)

DEFAULT_TEXT_QA_PROMPT = """基于以下知识三元组和原文证据片段,生成高质量问答对。
要求:
1. 问题自然、具体,不能过于宽泛
2. 答案基于原文片段,语言流畅
3. 每个三元组生成1-2个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

三元组数据:
"""

DEFAULT_IMAGE_QA_PROMPT = """基于图片内容和以下四元组信息,生成高质量图文问答对。
要求:
1. 问题需要结合图片才能回答
2. 答案基于图片中的实际内容
3. 每个四元组生成1个问答对

以 JSON 数组格式返回:[{"question":"...","answer":"..."}]

四元组信息:
"""


def _parse_qa_pairs(raw: str) -> list[QAPair]:
    items_raw = parse_json_response(raw)
    result = []
    for item in items_raw:
        try:
            result.append(QAPair(question=item["question"], answer=item["answer"]))
        except KeyError as e:
            logger.warning(f"skipping incomplete QA pair: {item}, error: {e}")
    return result


async def gen_text_qa(
    items: list[TextTripleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
) -> list[QAPair]:
    triples_text = json.dumps([i.model_dump() for i in items], ensure_ascii=False, indent=2)
    messages = [
        {"role": "system", "content": "你是专业的知识问答对生成助手。"},
        {"role": "user", "content": (prompt_template or DEFAULT_TEXT_QA_PROMPT) + triples_text},
    ]
    try:
        raw = await llm.chat(messages, model)
    except Exception as e:
        raise LLMCallError(f"GLM call failed: {e}") from e
    logger.info(f"gen_text_qa model={model} items={len(items)}")
    return _parse_qa_pairs(raw)


async def gen_image_qa(
    items: list[ImageQuadrupleForQA],
    model: str,
    prompt_template: str,
    llm: LLMClient,
    storage: StorageClient,
    bucket: str = "source-data",
) -> list[ImageQAPair]:
    result = []
    prompt = prompt_template or DEFAULT_IMAGE_QA_PROMPT
    for item in items:
        presigned_url = storage.get_presigned_url(bucket, item.cropped_image_path)
        quad_text = json.dumps(
            {k: v for k, v in item.model_dump().items() if k != "cropped_image_path"},
            ensure_ascii=False,
        )
        messages = [
            {"role": "system", "content": "你是专业的视觉问答对生成助手。"},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": presigned_url}},
                {"type": "text", "text": prompt + quad_text},
            ]},
        ]
        try:
            raw = await llm.chat_vision(messages, model)
        except Exception as e:
            raise LLMCallError(f"GLM-4V call failed: {e}") from e
        for pair in _parse_qa_pairs(raw):
            result.append(ImageQAPair(question=pair.question, answer=pair.answer, image_path=item.cropped_image_path))
    logger.info(f"gen_image_qa model={model} items={len(items)} pairs={len(result)}")
    return result
```

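`_parse_qa_pairs` delegates fence-stripping and JSON parsing to `app.core.json_utils.parse_json_response`, which is defined in an earlier task and not shown here. The behavior the tests assume (a JSON array, optionally wrapped in a markdown code fence, with anything else raising `LLMResponseParseError`) could be sketched roughly as follows; the regex and error text are illustrative, and the exception class below is a stand-in for the one in `app.core.exceptions`:

```python
import json
import re


class LLMResponseParseError(Exception):
    """Stand-in for app.core.exceptions.LLMResponseParseError (sketch only)."""


def parse_json_response(raw: str):
    # Strip an optional markdown code fence (with or without a "json" tag).
    text = raw.strip()
    fence = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL)
    if fence:
        text = fence.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise LLMResponseParseError(f"response is not valid JSON: {e}") from e
```

This keeps the service layer's contract simple: callers see either parsed Python objects or a single domain-specific exception, regardless of how the LLM decorated its output.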
- [ ] **Step 5: Run the tests; confirm they pass**

```bash
conda run -n label pytest tests/test_qa_service.py -v
```

Expected: `6 passed`

- [ ] **Step 6: Commit**

```bash
git add app/models/qa_models.py app/services/qa_service.py tests/test_qa_service.py
git commit -m "feat: QA models and service for text and image QA generation"
```

---

## Task 15: QA Router

**Files:**
- Create: `app/routers/qa.py`
- Create: `tests/test_qa_router.py`

- [ ] **Step 1: Write failing tests**

`tests/test_qa_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock
from app.main import app
from app.core.dependencies import set_clients


@pytest.fixture
def client(mock_llm, mock_storage):
    set_clients(mock_llm, mock_storage)
    return TestClient(app)


def test_gen_text_qa_success(client, mock_llm):
    mock_llm.chat = AsyncMock(return_value='[{"question":"额定电压?","answer":"110kV"}]')
    resp = client.post("/api/v1/qa/gen-text", json={
        "items": [{"subject": "变压器", "predicate": "额定电压", "object": "110kV", "source_snippet": "额定电压为110kV"}],
    })
    assert resp.status_code == 200
    assert resp.json()["pairs"][0]["question"] == "额定电压?"


def test_gen_image_qa_success(client, mock_llm, mock_storage):
    mock_llm.chat_vision = AsyncMock(return_value='[{"question":"图中是什么?","answer":"接头"}]')
    mock_storage.get_presigned_url.return_value = "https://example.com/crop.jpg"
    resp = client.post("/api/v1/qa/gen-image", json={
        "items": [{"subject": "A", "predicate": "B", "object": "C", "qualifier": "", "cropped_image_path": "crops/1/0.jpg"}],
    })
    assert resp.status_code == 200
    data = resp.json()
    assert data["pairs"][0]["image_path"] == "crops/1/0.jpg"
```

- [ ] **Step 2: Implement `app/routers/qa.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.qa_models import ImageQARequest, ImageQAResponse, TextQARequest, TextQAResponse
from app.services import qa_service

router = APIRouter(tags=["QA"])


@router.post("/qa/gen-text", response_model=TextQAResponse)
async def gen_text_qa(
    req: TextQARequest,
    llm: LLMClient = Depends(get_llm_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_text_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_text"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_TEXT_QA_PROMPT,
        llm=llm,
    )
    return TextQAResponse(pairs=pairs)


@router.post("/qa/gen-image", response_model=ImageQAResponse)
async def gen_image_qa(
    req: ImageQARequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
):
    cfg = get_config()
    pairs = await qa_service.gen_image_qa(
        items=req.items,
        model=req.model or cfg["models"]["default_vision"],
        prompt_template=req.prompt_template or qa_service.DEFAULT_IMAGE_QA_PROMPT,
        llm=llm,
        storage=storage,
        bucket=cfg["storage"]["buckets"]["source_data"],
    )
    return ImageQAResponse(pairs=pairs)
```

- [ ] **Step 3: Register the router in `app/main.py`**

```python
from app.routers import text, image, video, qa

app.include_router(qa.router, prefix="/api/v1")
```

- [ ] **Step 4: Run the tests**

```bash
conda run -n label pytest tests/test_qa_router.py -v
```

Expected: `2 passed`

- [ ] **Step 5: Commit**

```bash
git add app/routers/qa.py tests/test_qa_router.py app/main.py
git commit -m "feat: QA router POST /api/v1/qa/gen-text and /gen-image"
```

---

## Task 16: Finetune Models + Service + Router

**Files:**
- Create: `app/models/finetune_models.py`
- Create: `app/services/finetune_service.py`
- Create: `app/routers/finetune.py`
- Create: `tests/test_finetune_service.py`
- Create: `tests/test_finetune_router.py`

- [ ] **Step 1: Implement `app/models/finetune_models.py`**

```python
from pydantic import BaseModel


class FinetuneHyperparams(BaseModel):
    learning_rate: float = 1e-4
    epochs: int = 3


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: FinetuneHyperparams = FinetuneHyperparams()


class FinetuneStartResponse(BaseModel):
    job_id: str


class FinetuneStatusResponse(BaseModel):
    job_id: str
    status: str  # RUNNING | SUCCESS | FAILED
    progress: int | None = None
    error_message: str | None = None
```

- [ ] **Step 2: Write failing tests**

`tests/test_finetune_service.py`:

```python
import pytest
from unittest.mock import MagicMock
from app.services.finetune_service import start_finetune, get_finetune_status
from app.models.finetune_models import FinetuneHyperparams


@pytest.mark.asyncio
async def test_start_finetune():
    mock_job = MagicMock()
    mock_job.id = "glm-ft-abc123"
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.create.return_value = mock_job

    result = await start_finetune(
        jsonl_url="https://example.com/export.jsonl",
        base_model="glm-4-flash",
        hyperparams=FinetuneHyperparams(learning_rate=1e-4, epochs=3),
        client=mock_zhipuai,
    )
    assert result == "glm-ft-abc123"
    mock_zhipuai.fine_tuning.jobs.create.assert_called_once()


@pytest.mark.asyncio
async def test_get_finetune_status_running():
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 50
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "RUNNING"
    assert result.progress == 50
    assert result.job_id == "glm-ft-abc123"


@pytest.mark.asyncio
async def test_get_finetune_status_success():
    mock_job = MagicMock()
    mock_job.status = "succeeded"
    mock_job.progress = 100
    mock_job.error = None
    mock_zhipuai = MagicMock()
    mock_zhipuai.fine_tuning.jobs.retrieve.return_value = mock_job

    result = await get_finetune_status("glm-ft-abc123", mock_zhipuai)
    assert result.status == "SUCCESS"
```

- [ ] **Step 3: Run the tests; confirm they fail**

```bash
conda run -n label pytest tests/test_finetune_service.py -v
```

Expected: `ImportError`

- [ ] **Step 4: Implement `app/services/finetune_service.py`**

```python
import logging

from app.models.finetune_models import FinetuneHyperparams, FinetuneStatusResponse

logger = logging.getLogger(__name__)

_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


async def start_finetune(
    jsonl_url: str,
    base_model: str,
    hyperparams: FinetuneHyperparams,
    client,  # ZhipuAI SDK client instance
) -> str:
    job = client.fine_tuning.jobs.create(
        training_file=jsonl_url,
        model=base_model,
        hyperparameters={
            "learning_rate_multiplier": hyperparams.learning_rate,
            "n_epochs": hyperparams.epochs,
        },
    )
    logger.info(f"finetune_start job_id={job.id} model={base_model}")
    return job.id


async def get_finetune_status(job_id: str, client) -> FinetuneStatusResponse:
    job = client.fine_tuning.jobs.retrieve(job_id)
    status = _STATUS_MAP.get(job.status, "RUNNING")
    return FinetuneStatusResponse(
        job_id=job_id,
        status=status,
        progress=getattr(job, "progress", None),
        error_message=getattr(job, "error", None),
    )
```

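The status endpoint is designed to be polled by the Java backend until the job leaves `RUNNING`. The plan does not fix a polling cadence; a capped exponential schedule is a reasonable default, sketched below with illustrative numbers (the helper name and values are not part of the spec):

```python
def poll_schedule(base: float = 2.0, cap: float = 60.0, attempts: int = 6) -> list[float]:
    """Seconds to wait between successive GET /finetune/status/{job_id} calls."""
    return [min(cap, base * (2 ** i)) for i in range(attempts)]


print(poll_schedule())  # → [2.0, 4.0, 8.0, 16.0, 32.0, 60.0]
```

Backing off like this keeps the early phase of a job responsive in the UI while avoiding hammering the ZhipuAI API during multi-hour training runs.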
- [ ] **Step 5: Run the tests; confirm they pass**

```bash
conda run -n label pytest tests/test_finetune_service.py -v
```

Expected: `3 passed`

- [ ] **Step 6: Implement `app/routers/finetune.py`**

```python
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.dependencies import get_llm_client
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)
from app.services import finetune_service

router = APIRouter(tags=["Finetune"])


def _get_zhipuai(llm: LLMClient = Depends(get_llm_client)) -> ZhipuAIClient:
    if not isinstance(llm, ZhipuAIClient):
        raise RuntimeError("Fine-tuning is only supported with the ZhipuAI backend")
    return llm


@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    job_id = await finetune_service.start_finetune(
        jsonl_url=req.jsonl_url,
        base_model=req.base_model,
        hyperparams=req.hyperparams,
        client=llm._client,
    )
    return FinetuneStartResponse(job_id=job_id)


@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_finetune_status(
    job_id: str,
    llm: ZhipuAIClient = Depends(_get_zhipuai),
):
    return await finetune_service.get_finetune_status(job_id, llm._client)
```

- [ ] **Step 7: Write the router tests**

`tests/test_finetune_router.py`:

```python
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from app.main import app
from app.core.dependencies import set_clients
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.base import StorageClient


@pytest.fixture
def client(mock_storage):
    with patch("app.clients.llm.zhipuai_client.ZhipuAI") as MockZhipuAI:
        mock_sdk = MagicMock()
        MockZhipuAI.return_value = mock_sdk
        llm = ZhipuAIClient(api_key="test-key")
        llm._mock_sdk = mock_sdk
        set_clients(llm, mock_storage)
        yield TestClient(app), mock_sdk


def test_start_finetune(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.id = "glm-ft-xyz"
    mock_sdk.fine_tuning.jobs.create.return_value = mock_job

    resp = test_client.post("/api/v1/finetune/start", json={
        "jsonl_url": "https://example.com/export.jsonl",
        "base_model": "glm-4-flash",
        "hyperparams": {"learning_rate": 1e-4, "epochs": 3},
    })
    assert resp.status_code == 200
    assert resp.json()["job_id"] == "glm-ft-xyz"


def test_get_finetune_status(client):
    test_client, mock_sdk = client
    mock_job = MagicMock()
    mock_job.status = "running"
    mock_job.progress = 30
    mock_job.error = None
    mock_sdk.fine_tuning.jobs.retrieve.return_value = mock_job

    resp = test_client.get("/api/v1/finetune/status/glm-ft-xyz")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "RUNNING"
    assert data["progress"] == 30
```

- [ ] **Step 8: Register all routers in `app/main.py` (final state)**

```python
from app.routers import text, image, video, qa, finetune

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")
```

- [ ] **Step 9: Run the full test suite**

```bash
conda run -n label pytest tests/ -v
```

Expected: all tests pass, no failures

- [ ] **Step 10: Commit**

```bash
git add app/models/finetune_models.py app/services/finetune_service.py app/routers/finetune.py tests/test_finetune_service.py tests/test_finetune_router.py app/main.py
git commit -m "feat: finetune models, service, and router - complete all endpoints"
```

---

## Task 17: 部署文件
|
|||
|
|
|
|||
|
|
**Files:**
|
|||
|
|
- Create: `Dockerfile`
|
|||
|
|
- Create: `docker-compose.yml`
|
|||
|
|
|
|||
|
|
- [ ] **Step 1: 创建 `Dockerfile`**
|
|||
|
|
|
|||
|
|
```dockerfile
|
|||
|
|
FROM python:3.12-slim
|
|||
|
|
|
|||
|
|
WORKDIR /app
|
|||
|
|
|
|||
|
|
# OpenCV 系统依赖
|
|||
|
|
RUN apt-get update && apt-get install -y \
|
|||
|
|
libgl1 \
|
|||
|
|
libglib2.0-0 \
|
|||
|
|
&& rm -rf /var/lib/apt/lists/*
|
|||
|
|
|
|||
|
|
COPY requirements.txt .
|
|||
|
|
RUN pip install --no-cache-dir -r requirements.txt
|
|||
|
|
|
|||
|
|
COPY app/ ./app/
|
|||
|
|
COPY config.yaml .
|
|||
|
|
COPY .env .
|
|||
|
|
|
|||
|
|
EXPOSE 8000
|
|||
|
|
|
|||
|
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|||
|
|
```
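One caveat worth flagging: `COPY .env .` bakes credentials into the image layers, where they stay recoverable from the image itself. Since the compose file in Step 2 already injects the same variables via `env_file`, a safer variant (a sketch, not a required change) simply omits that line:

```dockerfile
# Safer variant: keep secrets out of the image and rely on
# docker-compose's env_file (or -e flags) at container start.
COPY app/ ./app/
COPY config.yaml .
# (no COPY .env — variables are injected at runtime)
```
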
- [ ] **Step 2: Create `docker-compose.yml`**

```yaml
version: "3.9"

services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    depends_on:
      - rustfs
    networks:
      - label-net
    restart: unless-stopped

  rustfs:
    image: minio/minio:latest
    command: server /data --console-address ":9001"
    ports:
      - "9000:9000"
      - "9001:9001"
    environment:
      MINIO_ROOT_USER: minioadmin
      MINIO_ROOT_PASSWORD: minioadmin
    volumes:
      - rustfs-data:/data
    networks:
      - label-net

volumes:
  rustfs-data:

networks:
  label-net:
    driver: bridge
```

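Optionally, a healthcheck can be added to the `ai-service` definition. The fragment below is a sketch that assumes the service answers on FastAPI's auto-generated `/docs` route (Task 6); swap in a dedicated `/health` endpoint if one exists. It probes with Python's stdlib rather than `curl`, which is not present in `python:3.12-slim`:

```yaml
  ai-service:
    # ... (settings from above)
    healthcheck:
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/docs')"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s
```
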
- [ ] **Step 3: Verify the Docker build**

```bash
docker build -t label-ai-service:dev .
```

Expected: image builds successfully, no errors

- [ ] **Step 4: Run the full test suite for final confirmation**

```bash
conda run -n label pytest tests/ -v --tb=short
```

Expected: all tests pass

- [ ] **Step 5: Commit**

```bash
git add Dockerfile docker-compose.yml
git commit -m "feat: Dockerfile and docker-compose for containerized deployment"
```

---

## Self-Review Results

**Spec coverage:**
- ✅ Text triple extraction (TXT/PDF/DOCX) — Task 8-9
- ✅ Image quadruple extraction + bbox cropping — Task 10-11
- ✅ Video frame extraction (interval/keyframe) — Task 12-13
- ✅ Video-to-text (BackgroundTask) — Task 12-13
- ✅ Text QA pair generation — Task 14-15
- ✅ Image QA pair generation — Task 14-15
- ✅ Fine-tuning job submission and status query — Task 16
- ✅ LLMClient / StorageClient ABC adapter layer — Task 4-5
- ✅ Layered configuration via config.yaml + .env — Task 2
- ✅ Structured logging + request logging — Task 3
- ✅ Global exception handling — Task 3
- ✅ Swagger docs (auto-generated by FastAPI) — Task 6
- ✅ Dockerfile + docker-compose — Task 17
- ✅ pytest coverage of every service and router — per task

**Type consistency:** `TripleItem.source_offset` is defined in Task 7 and used in Task 8; `VideoJobCallback` is defined in Task 12 and used by the Task 12 service — consistent.

**Placeholders:** no TBD / TODO; every step contains complete code.