feat: Phase 1+2 — project setup and core infrastructure
- requirements.txt, config.yaml, .env, Dockerfile, docker-compose.yml - app/core: config (YAML+env override), logging (JSON structured), exceptions (typed hierarchy), json_utils (Markdown fence stripping) - app/clients: LLMClient ABC + ZhipuAIClient (run_in_executor), StorageClient ABC + RustFSClient (boto3 head_object for size check) - app/main.py: FastAPI app with health endpoint and router registration - app/core/dependencies.py: lru_cache singleton factories - tests/conftest.py: mock_llm, mock_storage, test_app, client fixtures - pytest.ini: asyncio_mode=auto - 11 unit tests passing
This commit is contained in:
10
.env
Normal file
10
.env
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Required — fill in before running
|
||||||
|
ZHIPUAI_API_KEY=your-zhipuai-api-key-here
|
||||||
|
STORAGE_ACCESS_KEY=your-storage-access-key
|
||||||
|
STORAGE_SECRET_KEY=your-storage-secret-key
|
||||||
|
STORAGE_ENDPOINT=http://rustfs:9000
|
||||||
|
|
||||||
|
# Optional overrides
|
||||||
|
BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
# MAX_VIDEO_SIZE_MB=200
|
||||||
32
.gitignore
vendored
Normal file
32
.gitignore
vendored
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# ==========================================
|
||||||
|
# 1. Maven/Java 构建产物 (一键忽略整个目录)
|
||||||
|
# ==========================================
|
||||||
|
target/
|
||||||
|
*.class
|
||||||
|
*.jar
|
||||||
|
*.war
|
||||||
|
*.ear
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 2. IDE 配置文件
|
||||||
|
# ==========================================
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.iml
|
||||||
|
*.ipr
|
||||||
|
*.iws
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 3. 项目特定工具目录 (根据你的文件列表)
|
||||||
|
# ==========================================
|
||||||
|
# 忽略 Specifiy 工具生成的所有配置和脚本
|
||||||
|
.specify/
|
||||||
|
|
||||||
|
# 忽略 Claude 本地设置和技能文件
|
||||||
|
.claude/
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 4. 操作系统文件
|
||||||
|
# ==========================================
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
18
Dockerfile
Normal file
18
Dockerfile
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
libgl1 \
|
||||||
|
libglib2.0-0 \
|
||||||
|
curl \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
BIN
app/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
0
app/clients/__init__.py
Normal file
0
app/clients/__init__.py
Normal file
BIN
app/clients/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/clients/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
0
app/clients/llm/__init__.py
Normal file
0
app/clients/llm/__init__.py
Normal file
BIN
app/clients/llm/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/clients/llm/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/clients/llm/__pycache__/base.cpython-312.pyc
Normal file
BIN
app/clients/llm/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/clients/llm/__pycache__/zhipuai_client.cpython-312.pyc
Normal file
BIN
app/clients/llm/__pycache__/zhipuai_client.cpython-312.pyc
Normal file
Binary file not shown.
11
app/clients/llm/base.py
Normal file
11
app/clients/llm/base.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def chat(self, model: str, messages: list[dict]) -> str:
|
||||||
|
"""Send a text chat request and return the response content string."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
||||||
|
"""Send a multimodal (vision) chat request and return the response content string."""
|
||||||
37
app/clients/llm/zhipuai_client.py
Normal file
37
app/clients/llm/zhipuai_client.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from zhipuai import ZhipuAI
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.core.exceptions import LLMCallError
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ZhipuAIClient(LLMClient):
|
||||||
|
def __init__(self, api_key: str) -> None:
|
||||||
|
self._client = ZhipuAI(api_key=api_key)
|
||||||
|
|
||||||
|
async def chat(self, model: str, messages: list[dict]) -> str:
|
||||||
|
return await self._call(model, messages)
|
||||||
|
|
||||||
|
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
||||||
|
return await self._call(model, messages)
|
||||||
|
|
||||||
|
async def _call(self, model: str, messages: list[dict]) -> str:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
logger.info("llm_call", extra={"model": model, "response_len": len(content)})
|
||||||
|
return content
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
|
||||||
|
raise LLMCallError(f"大模型调用失败: {exc}") from exc
|
||||||
0
app/clients/storage/__init__.py
Normal file
0
app/clients/storage/__init__.py
Normal file
BIN
app/clients/storage/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/clients/storage/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/clients/storage/__pycache__/base.cpython-312.pyc
Normal file
BIN
app/clients/storage/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/clients/storage/__pycache__/rustfs_client.cpython-312.pyc
Normal file
BIN
app/clients/storage/__pycache__/rustfs_client.cpython-312.pyc
Normal file
Binary file not shown.
21
app/clients/storage/base.py
Normal file
21
app/clients/storage/base.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class StorageClient(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def download_bytes(self, bucket: str, path: str) -> bytes:
|
||||||
|
"""Download an object and return its raw bytes."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def upload_bytes(
|
||||||
|
self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
|
||||||
|
) -> None:
|
||||||
|
"""Upload raw bytes to the given bucket/path."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
|
||||||
|
"""Return a presigned GET URL valid for `expires` seconds."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_object_size(self, bucket: str, path: str) -> int:
|
||||||
|
"""Return the object size in bytes without downloading it."""
|
||||||
70
app/clients/storage/rustfs_client.py
Normal file
70
app/clients/storage/rustfs_client.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.exceptions import StorageError
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RustFSClient(StorageClient):
|
||||||
|
def __init__(self, endpoint: str, access_key: str, secret_key: str) -> None:
|
||||||
|
self._s3 = boto3.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=endpoint,
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download_bytes(self, bucket: str, path: str) -> bytes:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
resp = await loop.run_in_executor(
|
||||||
|
None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
|
||||||
|
)
|
||||||
|
return resp["Body"].read()
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(f"存储下载失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
|
|
||||||
|
async def upload_bytes(
|
||||||
|
self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
|
||||||
|
) -> None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._s3.put_object(
|
||||||
|
Bucket=bucket, Key=path, Body=io.BytesIO(data), ContentType=content_type
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
|
|
||||||
|
async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
url = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._s3.generate_presigned_url(
|
||||||
|
"get_object",
|
||||||
|
Params={"Bucket": bucket, "Key": path},
|
||||||
|
ExpiresIn=expires,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
|
|
||||||
|
async def get_object_size(self, bucket: str, path: str) -> int:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
try:
|
||||||
|
resp = await loop.run_in_executor(
|
||||||
|
None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
|
||||||
|
)
|
||||||
|
return resp["ContentLength"]
|
||||||
|
except ClientError as exc:
|
||||||
|
raise StorageError(f"获取文件大小失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
BIN
app/core/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/config.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/dependencies.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/dependencies.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/exceptions.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/exceptions.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/json_utils.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/json_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/logging.cpython-312.pyc
Normal file
BIN
app/core/__pycache__/logging.cpython-312.pyc
Normal file
Binary file not shown.
46
app/core/config.py
Normal file
46
app/core/config.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Maps environment variable names to nested YAML key paths
|
||||||
|
_ENV_OVERRIDES: dict[str, list[str]] = {
|
||||||
|
"ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
|
||||||
|
"STORAGE_ACCESS_KEY": ["storage", "access_key"],
|
||||||
|
"STORAGE_SECRET_KEY": ["storage", "secret_key"],
|
||||||
|
"STORAGE_ENDPOINT": ["storage", "endpoint"],
|
||||||
|
"BACKEND_CALLBACK_URL": ["backend", "callback_url"],
|
||||||
|
"LOG_LEVEL": ["server", "log_level"],
|
||||||
|
"MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _set_nested(cfg: dict, keys: list[str], value: Any) -> None:
|
||||||
|
for key in keys[:-1]:
|
||||||
|
cfg = cfg.setdefault(key, {})
|
||||||
|
# Coerce numeric env vars
|
||||||
|
try:
|
||||||
|
value = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
cfg[keys[-1]] = value
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_config() -> dict:
|
||||||
|
with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||||
|
cfg: dict = yaml.safe_load(f)
|
||||||
|
|
||||||
|
for env_var, key_path in _ENV_OVERRIDES.items():
|
||||||
|
value = os.environ.get(env_var)
|
||||||
|
if value is not None:
|
||||||
|
_set_nested(cfg, key_path, value)
|
||||||
|
|
||||||
|
return cfg
|
||||||
23
app/core/dependencies.py
Normal file
23
app/core/dependencies.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.llm.zhipuai_client import ZhipuAIClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.clients.storage.rustfs_client import RustFSClient
|
||||||
|
from app.core.config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_llm_client() -> LLMClient:
|
||||||
|
cfg = get_config()
|
||||||
|
return ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"])
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_storage_client() -> StorageClient:
|
||||||
|
cfg = get_config()
|
||||||
|
return RustFSClient(
|
||||||
|
endpoint=cfg["storage"]["endpoint"],
|
||||||
|
access_key=cfg["storage"]["access_key"],
|
||||||
|
secret_key=cfg["storage"]["secret_key"],
|
||||||
|
)
|
||||||
50
app/core/exceptions.py
Normal file
50
app/core/exceptions.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class AIServiceError(Exception):
|
||||||
|
status_code: int = 500
|
||||||
|
code: str = "INTERNAL_ERROR"
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedFileTypeError(AIServiceError):
|
||||||
|
status_code = 400
|
||||||
|
code = "UNSUPPORTED_FILE_TYPE"
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTooLargeError(AIServiceError):
|
||||||
|
status_code = 400
|
||||||
|
code = "VIDEO_TOO_LARGE"
|
||||||
|
|
||||||
|
|
||||||
|
class StorageError(AIServiceError):
|
||||||
|
status_code = 502
|
||||||
|
code = "STORAGE_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMParseError(AIServiceError):
|
||||||
|
status_code = 502
|
||||||
|
code = "LLM_PARSE_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCallError(AIServiceError):
|
||||||
|
status_code = 503
|
||||||
|
code = "LLM_CALL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
async def ai_service_exception_handler(request: Request, exc: AIServiceError) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={"code": exc.code, "message": exc.message},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"code": "INTERNAL_ERROR", "message": str(exc)},
|
||||||
|
)
|
||||||
19
app/core/json_utils.py
Normal file
19
app/core/json_utils.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from app.core.exceptions import LLMParseError
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json(text: str) -> any:
|
||||||
|
"""Parse JSON from LLM response, stripping Markdown code fences if present."""
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
# Strip ```json ... ``` or ``` ... ``` fences
|
||||||
|
fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
|
||||||
|
if fence_match:
|
||||||
|
text = fence_match.group(1).strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise LLMParseError(f"大模型返回非合法 JSON: {e}") from e
|
||||||
62
app/core/logging.py
Normal file
62
app/core/logging.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: str) -> logging.Logger:
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
if not logger.handlers:
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(_JsonFormatter())
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
return logger
|
||||||
|
|
||||||
|
|
||||||
|
class _JsonFormatter(logging.Formatter):
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
payload = {
|
||||||
|
"time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
|
||||||
|
"level": record.levelname,
|
||||||
|
"logger": record.name,
|
||||||
|
"message": record.getMessage(),
|
||||||
|
}
|
||||||
|
if record.exc_info:
|
||||||
|
payload["exc_info"] = self.formatException(record.exc_info)
|
||||||
|
# Merge any extra fields passed via `extra=`
|
||||||
|
for key, value in record.__dict__.items():
|
||||||
|
if key not in (
|
||||||
|
"name", "msg", "args", "levelname", "levelno", "pathname",
|
||||||
|
"filename", "module", "exc_info", "exc_text", "stack_info",
|
||||||
|
"lineno", "funcName", "created", "msecs", "relativeCreated",
|
||||||
|
"thread", "threadName", "processName", "process", "message",
|
||||||
|
"taskName",
|
||||||
|
):
|
||||||
|
payload[key] = value
|
||||||
|
return json.dumps(payload, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||||
|
def __init__(self, app, logger: logging.Logger | None = None) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
self._logger = logger or get_logger("request")
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
start = time.perf_counter()
|
||||||
|
response = await call_next(request)
|
||||||
|
duration_ms = round((time.perf_counter() - start) * 1000, 1)
|
||||||
|
self._logger.info(
|
||||||
|
"request",
|
||||||
|
extra={
|
||||||
|
"method": request.method,
|
||||||
|
"path": request.url.path,
|
||||||
|
"status": response.status_code,
|
||||||
|
"duration_ms": duration_ms,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response
|
||||||
46
app/main.py
Normal file
46
app/main.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.core.exceptions import (
|
||||||
|
AIServiceError,
|
||||||
|
ai_service_exception_handler,
|
||||||
|
unhandled_exception_handler,
|
||||||
|
)
|
||||||
|
from app.core.logging import RequestLoggingMiddleware, get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
logger.info("startup", extra={"message": "AI service starting"})
|
||||||
|
yield
|
||||||
|
logger.info("shutdown", extra={"message": "AI service stopping"})
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Label AI Service",
|
||||||
|
description="知识图谱标注平台 AI 计算服务",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(RequestLoggingMiddleware)
|
||||||
|
app.add_exception_handler(AIServiceError, ai_service_exception_handler)
|
||||||
|
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health", tags=["Health"])
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# Routers registered after implementation (imported lazily to avoid circular deps)
|
||||||
|
from app.routers import text, image, video, qa, finetune # noqa: E402
|
||||||
|
|
||||||
|
app.include_router(text.router, prefix="/api/v1")
|
||||||
|
app.include_router(image.router, prefix="/api/v1")
|
||||||
|
app.include_router(video.router, prefix="/api/v1")
|
||||||
|
app.include_router(qa.router, prefix="/api/v1")
|
||||||
|
app.include_router(finetune.router, prefix="/api/v1")
|
||||||
0
app/models/__init__.py
Normal file
0
app/models/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
0
app/routers/__init__.py
Normal file
3
app/routers/finetune.py
Normal file
3
app/routers/finetune.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Finetune"])
|
||||||
3
app/routers/image.py
Normal file
3
app/routers/image.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Image"])
|
||||||
3
app/routers/qa.py
Normal file
3
app/routers/qa.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(tags=["QA"])
|
||||||
3
app/routers/text.py
Normal file
3
app/routers/text.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Text"])
|
||||||
3
app/routers/video.py
Normal file
3
app/routers/video.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Video"])
|
||||||
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
19
config.yaml
Normal file
19
config.yaml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
server:
|
||||||
|
port: 8000
|
||||||
|
log_level: INFO
|
||||||
|
|
||||||
|
storage:
|
||||||
|
buckets:
|
||||||
|
source_data: "source-data"
|
||||||
|
finetune_export: "finetune-export"
|
||||||
|
|
||||||
|
backend: {} # callback_url injected via BACKEND_CALLBACK_URL env var
|
||||||
|
|
||||||
|
video:
|
||||||
|
frame_sample_count: 8 # uniform frames sampled for video-to-text
|
||||||
|
max_file_size_mb: 200 # video size limit (override with MAX_VIDEO_SIZE_MB)
|
||||||
|
keyframe_diff_threshold: 30.0 # grayscale mean-diff threshold for keyframe detection
|
||||||
|
|
||||||
|
models:
|
||||||
|
default_text: "glm-4-flash"
|
||||||
|
default_vision: "glm-4v-flash"
|
||||||
37
docker-compose.yml
Normal file
37
docker-compose.yml
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
version: "3.9"
|
||||||
|
|
||||||
|
services:
|
||||||
|
ai-service:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
depends_on:
|
||||||
|
rustfs:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
|
||||||
|
rustfs:
|
||||||
|
image: rustfs/rustfs:latest
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
environment:
|
||||||
|
RUSTFS_ACCESS_KEY: ${STORAGE_ACCESS_KEY}
|
||||||
|
RUSTFS_SECRET_KEY: ${STORAGE_SECRET_KEY}
|
||||||
|
volumes:
|
||||||
|
- rustfs_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
start_period: 5s
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
rustfs_data:
|
||||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
asyncio_mode = auto
|
||||||
|
testpaths = tests
|
||||||
16
requirements.txt
Normal file
16
requirements.txt
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
fastapi>=0.111.0
|
||||||
|
uvicorn[standard]>=0.29.0
|
||||||
|
pydantic>=2.7.0
|
||||||
|
zhipuai>=2.1.0
|
||||||
|
boto3>=1.34.0
|
||||||
|
pdfplumber>=0.11.0
|
||||||
|
python-docx>=1.1.0
|
||||||
|
opencv-python-headless>=4.9.0
|
||||||
|
numpy>=1.26.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pyyaml>=6.0.0
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.23.0
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
BIN
tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_config.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_config.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
BIN
tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.3.pyc
Normal file
BIN
tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.3.pyc
Normal file
Binary file not shown.
Binary file not shown.
39
tests/conftest.py
Normal file
39
tests/conftest.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.dependencies import get_llm_client, get_storage_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm() -> LLMClient:
|
||||||
|
client = MagicMock(spec=LLMClient)
|
||||||
|
client.chat = AsyncMock(return_value='[]')
|
||||||
|
client.chat_vision = AsyncMock(return_value='[]')
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_storage() -> StorageClient:
|
||||||
|
client = MagicMock(spec=StorageClient)
|
||||||
|
client.download_bytes = AsyncMock(return_value=b"")
|
||||||
|
client.upload_bytes = AsyncMock(return_value=None)
|
||||||
|
client.get_presigned_url = AsyncMock(return_value="http://example.com/presigned")
|
||||||
|
client.get_object_size = AsyncMock(return_value=10 * 1024 * 1024) # 10 MB default
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_app(mock_llm, mock_storage):
|
||||||
|
from app.main import app
|
||||||
|
app.dependency_overrides[get_llm_client] = lambda: mock_llm
|
||||||
|
app.dependency_overrides[get_storage_client] = lambda: mock_storage
|
||||||
|
yield app
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(test_app):
|
||||||
|
return TestClient(test_app)
|
||||||
40
tests/test_config.py
Normal file
40
tests/test_config.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_defaults_load(monkeypatch):
|
||||||
|
# Clear lru_cache so each test gets a fresh load
|
||||||
|
from app.core import config as cfg_module
|
||||||
|
cfg_module.get_config.cache_clear()
|
||||||
|
|
||||||
|
# Remove env overrides that might bleed from shell environment
|
||||||
|
for var in ["MAX_VIDEO_SIZE_MB", "LOG_LEVEL", "STORAGE_ENDPOINT"]:
|
||||||
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
|
||||||
|
cfg = cfg_module.get_config()
|
||||||
|
|
||||||
|
assert cfg["server"]["port"] == 8000
|
||||||
|
assert cfg["video"]["max_file_size_mb"] == 200
|
||||||
|
assert cfg["models"]["default_text"] == "glm-4-flash"
|
||||||
|
assert cfg["models"]["default_vision"] == "glm-4v-flash"
|
||||||
|
assert cfg["storage"]["buckets"]["source_data"] == "source-data"
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_video_size_env_override(monkeypatch):
|
||||||
|
from app.core import config as cfg_module
|
||||||
|
cfg_module.get_config.cache_clear()
|
||||||
|
|
||||||
|
monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "500")
|
||||||
|
cfg = cfg_module.get_config()
|
||||||
|
|
||||||
|
assert cfg["video"]["max_file_size_mb"] == 500
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_level_env_override(monkeypatch):
|
||||||
|
from app.core import config as cfg_module
|
||||||
|
cfg_module.get_config.cache_clear()
|
||||||
|
|
||||||
|
monkeypatch.setenv("LOG_LEVEL", "DEBUG")
|
||||||
|
cfg = cfg_module.get_config()
|
||||||
|
|
||||||
|
assert cfg["server"]["log_level"] == "DEBUG"
|
||||||
40
tests/test_llm_client.py
Normal file
40
tests/test_llm_client.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.clients.llm.zhipuai_client import ZhipuAIClient
|
||||||
|
from app.core.exceptions import LLMCallError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_sdk_response():
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.choices[0].message.content = '{"result": "ok"}'
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
with patch("app.clients.llm.zhipuai_client.ZhipuAI"):
|
||||||
|
c = ZhipuAIClient(api_key="test-key")
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_returns_content(client, mock_sdk_response):
|
||||||
|
client._client.chat.completions.create.return_value = mock_sdk_response
|
||||||
|
result = await client.chat("glm-4-flash", [{"role": "user", "content": "hello"}])
|
||||||
|
assert result == '{"result": "ok"}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_vision_returns_content(client, mock_sdk_response):
|
||||||
|
client._client.chat.completions.create.return_value = mock_sdk_response
|
||||||
|
result = await client.chat_vision("glm-4v-flash", [{"role": "user", "content": []}])
|
||||||
|
assert result == '{"result": "ok"}'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_call_error_on_sdk_exception(client):
|
||||||
|
client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")
|
||||||
|
with pytest.raises(LLMCallError, match="大模型调用失败"):
|
||||||
|
await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
|
||||||
62
tests/test_storage_client.py
Normal file
62
tests/test_storage_client.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.clients.storage.rustfs_client import RustFSClient
|
||||||
|
from app.core.exceptions import StorageError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3:
|
||||||
|
c = RustFSClient(
|
||||||
|
endpoint="http://rustfs:9000",
|
||||||
|
access_key="key",
|
||||||
|
secret_key="secret",
|
||||||
|
)
|
||||||
|
c._s3 = MagicMock()
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_bytes_returns_bytes(client):
|
||||||
|
client._s3.get_object.return_value = {"Body": MagicMock(read=lambda: b"hello")}
|
||||||
|
result = await client.download_bytes("source-data", "text/test.txt")
|
||||||
|
assert result == b"hello"
|
||||||
|
client._s3.get_object.assert_called_once_with(Bucket="source-data", Key="text/test.txt")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_bytes_raises_storage_error(client):
|
||||||
|
client._s3.get_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "GetObject"
|
||||||
|
)
|
||||||
|
with pytest.raises(StorageError, match="存储下载失败"):
|
||||||
|
await client.download_bytes("source-data", "missing.txt")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_object_size_returns_content_length(client):
|
||||||
|
client._s3.head_object.return_value = {"ContentLength": 1024}
|
||||||
|
size = await client.get_object_size("source-data", "video/test.mp4")
|
||||||
|
assert size == 1024
|
||||||
|
client._s3.head_object.assert_called_once_with(Bucket="source-data", Key="video/test.mp4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_object_size_raises_storage_error(client):
|
||||||
|
client._s3.head_object.side_effect = ClientError(
|
||||||
|
{"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "HeadObject"
|
||||||
|
)
|
||||||
|
with pytest.raises(StorageError, match="获取文件大小失败"):
|
||||||
|
await client.get_object_size("source-data", "video/missing.mp4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_bytes_calls_put_object(client):
|
||||||
|
client._s3.put_object.return_value = {}
|
||||||
|
await client.upload_bytes("source-data", "frames/1/0.jpg", b"jpeg-data", "image/jpeg")
|
||||||
|
client._s3.put_object.assert_called_once()
|
||||||
|
call_kwargs = client._s3.put_object.call_args
|
||||||
|
assert call_kwargs.kwargs["Bucket"] == "source-data"
|
||||||
|
assert call_kwargs.kwargs["Key"] == "frames/1/0.jpg"
|
||||||
Reference in New Issue
Block a user