feat: Phase 1+2 — project setup and core infrastructure

- requirements.txt, config.yaml, .env, Dockerfile, docker-compose.yml
- app/core: config (YAML+env override), logging (JSON structured),
  exceptions (typed hierarchy), json_utils (Markdown fence stripping)
- app/clients: LLMClient ABC + ZhipuAIClient (run_in_executor),
  StorageClient ABC + RustFSClient (boto3 head_object for size check)
- app/main.py: FastAPI app with health endpoint and router registration
- app/core/dependencies.py: lru_cache singleton factories
- tests/conftest.py: mock_llm, mock_storage, test_app, client fixtures
- pytest.ini: asyncio_mode=auto
- 11 unit tests passing
This commit is contained in:
wh
2026-04-10 15:22:45 +08:00
parent 4162d9f4e6
commit e1eb5e47b1
54 changed files with 716 additions and 0 deletions

10
.env Normal file
View File

@@ -0,0 +1,10 @@
# Required — fill in before running
ZHIPUAI_API_KEY=your-zhipuai-api-key-here
STORAGE_ACCESS_KEY=your-storage-access-key
STORAGE_SECRET_KEY=your-storage-secret-key
STORAGE_ENDPOINT=http://rustfs:9000
# Optional overrides
BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
LOG_LEVEL=INFO
# MAX_VIDEO_SIZE_MB=200

32
.gitignore vendored Normal file
View File

@@ -0,0 +1,32 @@
# ==========================================
# 1. Maven/Java build artifacts (ignore the whole directory)
# ==========================================
target/
*.class
*.jar
*.war
*.ear
# ==========================================
# 2. IDE configuration files
# ==========================================
.idea/
.vscode/
*.iml
*.ipr
*.iws
# ==========================================
# 3. Project-specific tool directories
# ==========================================
# Ignore all config and scripts generated by the Specify tool
.specify/
# Ignore Claude local settings and skill files
.claude/
# ==========================================
# 4. Operating-system files
# ==========================================
.DS_Store
Thumbs.db

18
Dockerfile Normal file
View File

@@ -0,0 +1,18 @@
# Runtime image for the AI service.
FROM python:3.12-slim

WORKDIR /app

# libgl1 + libglib2.0-0: native libraries needed by opencv-python-headless;
# curl: used by the docker-compose healthcheck against /health.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so source-only edits don't invalidate the pip layer.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

0
app/__init__.py Normal file
View File

Binary file not shown.

0
app/clients/__init__.py Normal file
View File

Binary file not shown.

View File

Binary file not shown.

Binary file not shown.

11
app/clients/llm/base.py Normal file
View File

@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class LLMClient(ABC):
    """Abstract interface for chat-completion LLM providers.

    Implementations are async and return the raw response content string
    produced by the model.
    """

    @abstractmethod
    async def chat(self, model: str, messages: list[dict]) -> str:
        """Send a text chat request and return the response content string."""

    @abstractmethod
    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Send a multimodal (vision) chat request and return the response content string."""

View File

@@ -0,0 +1,37 @@
import asyncio
from zhipuai import ZhipuAI
from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger
logger = get_logger(__name__)
class ZhipuAIClient(LLMClient):
    """LLMClient backed by the synchronous ZhipuAI SDK.

    The SDK blocks, so every call is dispatched to the default thread-pool
    executor to keep the event loop free.
    """

    def __init__(self, api_key: str) -> None:
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, model: str, messages: list[dict]) -> str:
        """Text chat; returns the first choice's content string."""
        return await self._call(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Multimodal chat; same call path — vision models accept image parts."""
        return await self._call(model, messages)

    async def _call(self, model: str, messages: list[dict]) -> str:
        """Run the blocking SDK call in an executor and unwrap the content.

        Raises:
            LLMCallError: wrapping any SDK/transport failure.
        """
        # Bug fix: asyncio.get_event_loop() is deprecated inside coroutines
        # since Python 3.10; we are always on a running loop here.
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self._client.chat.completions.create(
                    model=model,
                    messages=messages,
                ),
            )
            content = response.choices[0].message.content
            logger.info("llm_call", extra={"model": model, "response_len": len(content)})
            return content
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"大模型调用失败: {exc}") from exc

View File

Binary file not shown.

View File

@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
class StorageClient(ABC):
    """Abstract async interface for object-storage backends (S3-compatible)."""

    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object and return its raw bytes."""

    @abstractmethod
    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        """Upload raw bytes to the given bucket/path."""

    @abstractmethod
    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Return a presigned GET URL valid for `expires` seconds."""

    @abstractmethod
    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return the object size in bytes without downloading it."""

View File

@@ -0,0 +1,70 @@
import asyncio
import io
import boto3
from botocore.exceptions import ClientError
from app.clients.storage.base import StorageClient
from app.core.exceptions import StorageError
from app.core.logging import get_logger
logger = get_logger(__name__)
class RustFSClient(StorageClient):
    """S3-compatible StorageClient for RustFS, built on boto3.

    boto3 is synchronous, so every operation runs in the default thread-pool
    executor to keep the event loop responsive.
    """

    def __init__(self, endpoint: str, access_key: str, secret_key: str) -> None:
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    @staticmethod
    async def _run(func):
        """Run a blocking boto3 callable in the default executor.

        Uses get_running_loop(): get_event_loop() is deprecated inside
        coroutines since Python 3.10.
        """
        return await asyncio.get_running_loop().run_in_executor(None, func)

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object and return its raw bytes.

        Raises:
            StorageError: if the object is missing or the backend errors.
        """
        try:
            # Body.read() also performs blocking socket I/O, so it runs inside
            # the executor too instead of stalling the event loop.
            return await self._run(
                lambda: self._s3.get_object(Bucket=bucket, Key=path)["Body"].read()
            )
        except ClientError as exc:
            raise StorageError(f"存储下载失败 [{bucket}/{path}]: {exc}") from exc

    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        """Upload raw bytes to bucket/path with the given Content-Type.

        Raises:
            StorageError: on any backend failure.
        """
        try:
            await self._run(
                lambda: self._s3.put_object(
                    Bucket=bucket, Key=path, Body=io.BytesIO(data), ContentType=content_type
                )
            )
        except ClientError as exc:
            raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc

    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Return a presigned GET URL valid for `expires` seconds.

        Raises:
            StorageError: if URL generation fails.
        """
        try:
            return await self._run(
                lambda: self._s3.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": bucket, "Key": path},
                    ExpiresIn=expires,
                )
            )
        except ClientError as exc:
            raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc

    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return object size in bytes via a HEAD request (no body download).

        Raises:
            StorageError: if the object is missing or the backend errors.
        """
        try:
            resp = await self._run(lambda: self._s3.head_object(Bucket=bucket, Key=path))
            return resp["ContentLength"]
        except ClientError as exc:
            raise StorageError(f"获取文件大小失败 [{bucket}/{path}]: {exc}") from exc

0
app/core/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

46
app/core/config.py Normal file
View File

@@ -0,0 +1,46 @@
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
import yaml
from dotenv import load_dotenv

# Populate os.environ from .env before any override lookups below.
load_dotenv()

# Maps environment variable names to nested YAML key paths
_ENV_OVERRIDES: dict[str, list[str]] = {
    "ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY": ["storage", "access_key"],
    "STORAGE_SECRET_KEY": ["storage", "secret_key"],
    "STORAGE_ENDPOINT": ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL": ["backend", "callback_url"],
    "LOG_LEVEL": ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
}

# Repo-root config.yaml: app/core/config.py -> app/core -> app -> project root.
_CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"
def _set_nested(cfg: dict, keys: list[str], value: Any) -> None:
    """Set `value` at the nested key path `keys`, creating intermediate dicts.

    Only strings that are canonical base-10 integers (e.g. "500", "-3") are
    coerced to int, so numeric env overrides such as MAX_VIDEO_SIZE_MB compare
    correctly. Bug fix: the previous blanket `int(value)` also mangled
    numeric-looking *secrets* (an access key like "0123456789" became the int
    123456789); zero-padded or non-canonical strings now stay strings.
    """
    node = cfg
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    if isinstance(value, str):
        digits = value[1:] if value.startswith("-") else value
        # Canonical-form check rejects "007", " 5", "+5" while keeping "500".
        if digits.isdigit() and str(int(value)) == value:
            value = int(value)
    node[keys[-1]] = value
@lru_cache(maxsize=1)
def get_config() -> dict:
    """Load config.yaml once per process and apply env-var overrides.

    Overrides come from `_ENV_OVERRIDES`; an unset variable leaves the YAML
    value (or absence) untouched.
    """
    cfg: dict = yaml.safe_load(_CONFIG_PATH.read_text(encoding="utf-8"))
    for env_var, key_path in _ENV_OVERRIDES.items():
        if (value := os.environ.get(env_var)) is not None:
            _set_nested(cfg, key_path, value)
    return cfg

23
app/core/dependencies.py Normal file
View File

@@ -0,0 +1,23 @@
from functools import lru_cache
from app.clients.llm.base import LLMClient
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.base import StorageClient
from app.clients.storage.rustfs_client import RustFSClient
from app.core.config import get_config
@lru_cache(maxsize=1)
def get_llm_client() -> LLMClient:
    """Process-wide singleton LLM client, built from config on first use."""
    api_key = get_config()["zhipuai"]["api_key"]
    return ZhipuAIClient(api_key=api_key)
@lru_cache(maxsize=1)
def get_storage_client() -> StorageClient:
    """Process-wide singleton storage client, built from config on first use."""
    storage_cfg = get_config()["storage"]
    return RustFSClient(
        endpoint=storage_cfg["endpoint"],
        access_key=storage_cfg["access_key"],
        secret_key=storage_cfg["secret_key"],
    )

50
app/core/exceptions.py Normal file
View File

@@ -0,0 +1,50 @@
from fastapi import Request
from fastapi.responses import JSONResponse
class AIServiceError(Exception):
    """Root of the service's error hierarchy.

    Each subclass pins an HTTP status and a stable machine-readable code;
    the exception handler turns these into a {code, message} JSON response.
    """

    # Defaults for unclassified failures; subclasses override both.
    status_code: int = 500
    code: str = "INTERNAL_ERROR"

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message
class UnsupportedFileTypeError(AIServiceError):
    """Request referenced a file type the service cannot process (HTTP 400)."""
    status_code = 400
    code = "UNSUPPORTED_FILE_TYPE"
class VideoTooLargeError(AIServiceError):
    """Video exceeds the configured size limit (HTTP 400)."""
    status_code = 400
    code = "VIDEO_TOO_LARGE"
class StorageError(AIServiceError):
    """Object-storage backend failed or object missing (HTTP 502)."""
    status_code = 502
    code = "STORAGE_ERROR"
class LLMParseError(AIServiceError):
    """LLM responded, but the payload was not parseable JSON (HTTP 502)."""
    status_code = 502
    code = "LLM_PARSE_ERROR"
class LLMCallError(AIServiceError):
    """LLM call itself failed (network, quota, SDK error) (HTTP 503)."""
    status_code = 503
    code = "LLM_CALL_ERROR"
async def ai_service_exception_handler(request: Request, exc: AIServiceError) -> JSONResponse:
    """Translate a typed AIServiceError into its mapped HTTP status and body."""
    payload = {"code": exc.code, "message": exc.message}
    return JSONResponse(status_code=exc.status_code, content=payload)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all: surface any unexpected exception as a generic 500 payload."""
    payload = {"code": "INTERNAL_ERROR", "message": str(exc)}
    return JSONResponse(status_code=500, content=payload)

19
app/core/json_utils.py Normal file
View File

@@ -0,0 +1,19 @@
import json
import re
from typing import Any

from app.core.exceptions import LLMParseError
def extract_json(text: str) -> Any:
    """Parse a JSON value from an LLM response.

    Strips a Markdown code fence (```json ... ``` or ``` ... ```) if one is
    present, then parses the remainder.

    Bug fix: the return annotation was the *builtin function* `any`
    (lowercase), not `typing.Any`.

    Raises:
        LLMParseError: if the text is not valid JSON after fence stripping.
    """
    text = text.strip()
    # Strip ```json ... ``` or ``` ... ``` fences (first fence only).
    fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
    if fence_match:
        text = fence_match.group(1).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise LLMParseError(f"大模型返回非合法 JSON: {e}") from e

62
app/core/logging.py Normal file
View File

@@ -0,0 +1,62 @@
import json
import logging
import time
from typing import Callable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
def get_logger(name: str) -> logging.Logger:
    """Return a named logger that emits one JSON object per line to stderr.

    The handler is attached once per logger name, and propagation is disabled
    so records are not duplicated through the root logger.
    """
    import os  # local import: keeps this a drop-in fix without touching file imports

    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(_JsonFormatter())
        logger.addHandler(handler)
        # Bug fix: with no explicit level the logger inherits root's default
        # WARNING and silently drops the INFO records this service relies on.
        # LOG_LEVEL mirrors the env override used by app/core/config.py.
        logger.setLevel(os.environ.get("LOG_LEVEL", "INFO").upper())
        logger.propagate = False
    return logger
class _JsonFormatter(logging.Formatter):
    """Serialize each log record as a single JSON line, preserving `extra` fields."""

    # LogRecord's built-in attributes; anything else on the record came from
    # a caller's `extra=` mapping and is merged into the JSON payload.
    _STANDARD_ATTRS = frozenset({
        "name", "msg", "args", "levelname", "levelno", "pathname",
        "filename", "module", "exc_info", "exc_text", "stack_info",
        "lineno", "funcName", "created", "msecs", "relativeCreated",
        "thread", "threadName", "processName", "process", "message",
        "taskName",
    })

    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            entry["exc_info"] = self.formatException(record.exc_info)
        extras = {
            key: value
            for key, value in record.__dict__.items()
            if key not in self._STANDARD_ATTRS
        }
        entry.update(extras)
        return json.dumps(entry, ensure_ascii=False)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Emit one structured log line per HTTP request: method, path, status, latency."""

    def __init__(self, app, logger: logging.Logger | None = None) -> None:
        super().__init__(app)
        self._logger = logger if logger is not None else get_logger("request")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        started = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = round((time.perf_counter() - started) * 1000, 1)
        fields = {
            "method": request.method,
            "path": request.url.path,
            "status": response.status_code,
            "duration_ms": elapsed_ms,
        }
        self._logger.info("request", extra=fields)
        return response

46
app/main.py Normal file
View File

@@ -0,0 +1,46 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.core.exceptions import (
AIServiceError,
ai_service_exception_handler,
unhandled_exception_handler,
)
from app.core.logging import RequestLoggingMiddleware, get_logger
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log service start/stop around the application's lifetime.

    Bug fix: the `extra` mapping passed to logging must not contain the
    reserved key "message" — Logger.makeRecord raises
    KeyError("Attempt to overwrite 'message' in LogRecord") on it, which
    crashed app startup. The human-readable text now rides under "detail".
    """
    logger.info("startup", extra={"detail": "AI service starting"})
    yield
    logger.info("shutdown", extra={"detail": "AI service stopping"})
# FastAPI application wired with structured request logging and typed error
# handling. The Chinese description is user-facing OpenAPI metadata.
app = FastAPI(
    title="Label AI Service",
    description="知识图谱标注平台 AI 计算服务",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(RequestLoggingMiddleware)
# Typed AIServiceError subclasses map to their own status/code pair;
# everything else falls through to the generic 500 handler.
app.add_exception_handler(AIServiceError, ai_service_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)


@app.get("/health", tags=["Health"])
async def health():
    """Liveness probe used by the docker-compose healthcheck."""
    return {"status": "ok"}


# Routers registered after implementation (imported lazily to avoid circular deps)
from app.routers import text, image, video, qa, finetune  # noqa: E402

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")

0
app/models/__init__.py Normal file
View File

0
app/routers/__init__.py Normal file
View File

3
app/routers/finetune.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter

# Stub router; fine-tuning export endpoints are added in a later phase.
router = APIRouter(tags=["Finetune"])

3
app/routers/image.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter

# Stub router; image annotation endpoints are added in a later phase.
router = APIRouter(tags=["Image"])

3
app/routers/qa.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter

# Stub router; QA generation endpoints are added in a later phase.
router = APIRouter(tags=["QA"])

3
app/routers/text.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter

# Stub router; text annotation endpoints are added in a later phase.
router = APIRouter(tags=["Text"])

3
app/routers/video.py Normal file
View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter

# Stub router; video annotation endpoints are added in a later phase.
router = APIRouter(tags=["Video"])

0
app/services/__init__.py Normal file
View File

19
config.yaml Normal file
View File

@@ -0,0 +1,19 @@
# Static service defaults. Secrets and per-deployment values (zhipuai.api_key,
# storage.endpoint/access_key/secret_key, backend.callback_url) are injected
# from environment variables — see _ENV_OVERRIDES in app/core/config.py.
server:
  port: 8000
  log_level: INFO  # override with LOG_LEVEL
storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"
backend: {} # callback_url injected via BACKEND_CALLBACK_URL env var
video:
  frame_sample_count: 8 # uniform frames sampled for video-to-text
  max_file_size_mb: 200 # video size limit (override with MAX_VIDEO_SIZE_MB)
  keyframe_diff_threshold: 30.0 # grayscale mean-diff threshold for keyframe detection
models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"

37
docker-compose.yml Normal file
View File

@@ -0,0 +1,37 @@
# NOTE: the top-level `version:` key is omitted — it is obsolete and ignored
# (with a warning) by Docker Compose V2.
services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    # Wait until the storage backend reports healthy before starting.
    depends_on:
      rustfs:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s

  rustfs:
    image: rustfs/rustfs:latest
    ports:
      - "9000:9000"
    environment:
      RUSTFS_ACCESS_KEY: ${STORAGE_ACCESS_KEY}
      RUSTFS_SECRET_KEY: ${STORAGE_SECRET_KEY}
    volumes:
      - rustfs_data:/data
    # NOTE(review): assumes `curl` exists inside the rustfs image — verify, or
    # switch to a CMD-SHELL/wget probe if this healthcheck never passes.
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
      interval: 10s
      timeout: 3s
      retries: 5
      start_period: 5s

volumes:
  rustfs_data:

3
pytest.ini Normal file
View File

@@ -0,0 +1,3 @@
[pytest]
# Run `async def` tests without per-test @pytest.mark.asyncio decorators.
asyncio_mode = auto
testpaths = tests

16
requirements.txt Normal file
View File

@@ -0,0 +1,16 @@
fastapi>=0.111.0
uvicorn[standard]>=0.29.0
pydantic>=2.7.0
zhipuai>=2.1.0
boto3>=1.34.0
pdfplumber>=0.11.0
python-docx>=1.1.0
opencv-python-headless>=4.9.0
numpy>=1.26.0
httpx>=0.27.0
python-dotenv>=1.0.0
pyyaml>=6.0.0
# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0

0
tests/__init__.py Normal file
View File

Binary file not shown.

39
tests/conftest.py Normal file
View File

@@ -0,0 +1,39 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from fastapi.testclient import TestClient
from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client


@pytest.fixture
def mock_llm() -> LLMClient:
    """LLMClient double whose chat methods return an empty JSON array string."""
    client = MagicMock(spec=LLMClient)
    client.chat = AsyncMock(return_value='[]')
    client.chat_vision = AsyncMock(return_value='[]')
    return client


@pytest.fixture
def mock_storage() -> StorageClient:
    """StorageClient double with benign defaults for every operation."""
    client = MagicMock(spec=StorageClient)
    client.download_bytes = AsyncMock(return_value=b"")
    client.upload_bytes = AsyncMock(return_value=None)
    client.get_presigned_url = AsyncMock(return_value="http://example.com/presigned")
    client.get_object_size = AsyncMock(return_value=10 * 1024 * 1024)  # 10 MB default
    return client


@pytest.fixture
def test_app(mock_llm, mock_storage):
    """FastAPI app with the client singletons replaced by the mocks above."""
    # Imported here so app.main only loads inside tests that need it.
    from app.main import app
    app.dependency_overrides[get_llm_client] = lambda: mock_llm
    app.dependency_overrides[get_storage_client] = lambda: mock_storage
    yield app
    # Reset so overrides never leak across tests.
    app.dependency_overrides.clear()


@pytest.fixture
def client(test_app):
    """Synchronous TestClient bound to the override-wired app."""
    return TestClient(test_app)

40
tests/test_config.py Normal file
View File

@@ -0,0 +1,40 @@
import os
import pytest


def test_yaml_defaults_load(monkeypatch):
    """config.yaml defaults survive a load with no env overrides present."""
    # Clear lru_cache so each test gets a fresh load
    from app.core import config as cfg_module
    cfg_module.get_config.cache_clear()
    # Remove env overrides that might bleed from shell environment
    for var in ["MAX_VIDEO_SIZE_MB", "LOG_LEVEL", "STORAGE_ENDPOINT"]:
        monkeypatch.delenv(var, raising=False)
    cfg = cfg_module.get_config()
    assert cfg["server"]["port"] == 8000
    assert cfg["video"]["max_file_size_mb"] == 200
    assert cfg["models"]["default_text"] == "glm-4-flash"
    assert cfg["models"]["default_vision"] == "glm-4v-flash"
    assert cfg["storage"]["buckets"]["source_data"] == "source-data"


def test_max_video_size_env_override(monkeypatch):
    """A numeric env override is coerced to int in the loaded config."""
    from app.core import config as cfg_module
    cfg_module.get_config.cache_clear()
    monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "500")
    cfg = cfg_module.get_config()
    assert cfg["video"]["max_file_size_mb"] == 500


def test_log_level_env_override(monkeypatch):
    """A string env override replaces the YAML default verbatim."""
    from app.core import config as cfg_module
    cfg_module.get_config.cache_clear()
    monkeypatch.setenv("LOG_LEVEL", "DEBUG")
    cfg = cfg_module.get_config()
    assert cfg["server"]["log_level"] == "DEBUG"

40
tests/test_llm_client.py Normal file
View File

@@ -0,0 +1,40 @@
import pytest
from unittest.mock import MagicMock, patch
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.exceptions import LLMCallError


@pytest.fixture
def mock_sdk_response():
    """SDK-shaped response whose first choice carries a JSON string payload."""
    resp = MagicMock()
    resp.choices[0].message.content = '{"result": "ok"}'
    return resp


@pytest.fixture
def client():
    """ZhipuAIClient with the real SDK constructor patched out (_client is a mock)."""
    with patch("app.clients.llm.zhipuai_client.ZhipuAI"):
        c = ZhipuAIClient(api_key="test-key")
        return c


@pytest.mark.asyncio
async def test_chat_returns_content(client, mock_sdk_response):
    """chat() unwraps choices[0].message.content."""
    client._client.chat.completions.create.return_value = mock_sdk_response
    result = await client.chat("glm-4-flash", [{"role": "user", "content": "hello"}])
    assert result == '{"result": "ok"}'


@pytest.mark.asyncio
async def test_chat_vision_returns_content(client, mock_sdk_response):
    """chat_vision() shares the same unwrap path as chat()."""
    client._client.chat.completions.create.return_value = mock_sdk_response
    result = await client.chat_vision("glm-4v-flash", [{"role": "user", "content": []}])
    assert result == '{"result": "ok"}'


@pytest.mark.asyncio
async def test_llm_call_error_on_sdk_exception(client):
    """Any SDK exception is wrapped into the typed LLMCallError."""
    client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")
    with pytest.raises(LLMCallError, match="大模型调用失败"):
        await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])

View File

@@ -0,0 +1,62 @@
import pytest
from unittest.mock import MagicMock, patch
from botocore.exceptions import ClientError
from app.clients.storage.rustfs_client import RustFSClient
from app.core.exceptions import StorageError


@pytest.fixture
def client():
    """RustFSClient built with boto3 patched out; _s3 replaced by a fresh mock."""
    with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
        c = RustFSClient(
            endpoint="http://rustfs:9000",
            access_key="key",
            secret_key="secret",
        )
        c._s3 = MagicMock()
        return c


@pytest.mark.asyncio
async def test_download_bytes_returns_bytes(client):
    """download_bytes returns the object body and forwards bucket/key verbatim."""
    client._s3.get_object.return_value = {"Body": MagicMock(read=lambda: b"hello")}
    result = await client.download_bytes("source-data", "text/test.txt")
    assert result == b"hello"
    client._s3.get_object.assert_called_once_with(Bucket="source-data", Key="text/test.txt")


@pytest.mark.asyncio
async def test_download_bytes_raises_storage_error(client):
    """A boto3 ClientError surfaces as the typed StorageError."""
    client._s3.get_object.side_effect = ClientError(
        {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "GetObject"
    )
    with pytest.raises(StorageError, match="存储下载失败"):
        await client.download_bytes("source-data", "missing.txt")


@pytest.mark.asyncio
async def test_get_object_size_returns_content_length(client):
    """get_object_size reads ContentLength from a HEAD request."""
    client._s3.head_object.return_value = {"ContentLength": 1024}
    size = await client.get_object_size("source-data", "video/test.mp4")
    assert size == 1024
    client._s3.head_object.assert_called_once_with(Bucket="source-data", Key="video/test.mp4")


@pytest.mark.asyncio
async def test_get_object_size_raises_storage_error(client):
    """HEAD failures are wrapped in StorageError as well."""
    client._s3.head_object.side_effect = ClientError(
        {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "HeadObject"
    )
    with pytest.raises(StorageError, match="获取文件大小失败"):
        await client.get_object_size("source-data", "video/missing.mp4")
@pytest.mark.asyncio
async def test_upload_bytes_calls_put_object(client):
client._s3.put_object.return_value = {}
await client.upload_bytes("source-data", "frames/1/0.jpg", b"jpeg-data", "image/jpeg")
client._s3.put_object.assert_called_once()
call_kwargs = client._s3.put_object.call_args
assert call_kwargs.kwargs["Bucket"] == "source-data"
assert call_kwargs.kwargs["Key"] == "frames/1/0.jpg"