diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..20292a0
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,10 @@
+# Required — copy this file to .env and fill in before running
+ZHIPUAI_API_KEY=your-zhipuai-api-key-here
+STORAGE_ACCESS_KEY=your-storage-access-key
+STORAGE_SECRET_KEY=your-storage-secret-key
+STORAGE_ENDPOINT=http://rustfs:9000
+
+# Optional overrides
+BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
+LOG_LEVEL=INFO
+# MAX_VIDEO_SIZE_MB=200
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0f47a07
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,39 @@
+# ==========================================
+# 1. Maven/Java 构建产物 (一键忽略整个目录)
+# ==========================================
+target/
+*.class
+*.jar
+*.war
+*.ear
+
+# ==========================================
+# 2. IDE 配置文件
+# ==========================================
+.idea/
+.vscode/
+*.iml
+*.ipr
+*.iws
+
+# ==========================================
+# 3. 项目特定工具目录 (根据你的文件列表)
+# ==========================================
+# 忽略 Specifiy 工具生成的所有配置和脚本
+.specify/
+
+# 忽略 Claude 本地设置和技能文件
+.claude/
+
+# ==========================================
+# 4. 操作系统文件
+# ==========================================
+.DS_Store
+Thumbs.db
+
+# ==========================================
+# 5. Python artifacts & local secrets
+# ==========================================
+__pycache__/
+*.py[cod]
+.env
+ +EXPOSE 8000 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/__pycache__/__init__.cpython-312.pyc b/app/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..96a8c63 Binary files /dev/null and b/app/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/clients/__init__.py b/app/clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/clients/__pycache__/__init__.cpython-312.pyc b/app/clients/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..747081c Binary files /dev/null and b/app/clients/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/clients/llm/__init__.py b/app/clients/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/clients/llm/__pycache__/__init__.cpython-312.pyc b/app/clients/llm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..5557668 Binary files /dev/null and b/app/clients/llm/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/clients/llm/__pycache__/base.cpython-312.pyc b/app/clients/llm/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000..f09bdaa Binary files /dev/null and b/app/clients/llm/__pycache__/base.cpython-312.pyc differ diff --git a/app/clients/llm/__pycache__/zhipuai_client.cpython-312.pyc b/app/clients/llm/__pycache__/zhipuai_client.cpython-312.pyc new file mode 100644 index 0000000..f1146fd Binary files /dev/null and b/app/clients/llm/__pycache__/zhipuai_client.cpython-312.pyc differ diff --git a/app/clients/llm/base.py b/app/clients/llm/base.py new file mode 100644 index 0000000..33ab1d8 --- /dev/null +++ b/app/clients/llm/base.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class LLMClient(ABC): + @abstractmethod + async def chat(self, model: str, messages: list[dict]) -> str: + """Send a text chat request and return the response 
content string.""" + + @abstractmethod + async def chat_vision(self, model: str, messages: list[dict]) -> str: + """Send a multimodal (vision) chat request and return the response content string.""" diff --git a/app/clients/llm/zhipuai_client.py b/app/clients/llm/zhipuai_client.py new file mode 100644 index 0000000..a92322d --- /dev/null +++ b/app/clients/llm/zhipuai_client.py @@ -0,0 +1,37 @@ +import asyncio + +from zhipuai import ZhipuAI + +from app.clients.llm.base import LLMClient +from app.core.exceptions import LLMCallError +from app.core.logging import get_logger + +logger = get_logger(__name__) + + +class ZhipuAIClient(LLMClient): + def __init__(self, api_key: str) -> None: + self._client = ZhipuAI(api_key=api_key) + + async def chat(self, model: str, messages: list[dict]) -> str: + return await self._call(model, messages) + + async def chat_vision(self, model: str, messages: list[dict]) -> str: + return await self._call(model, messages) + + async def _call(self, model: str, messages: list[dict]) -> str: + loop = asyncio.get_event_loop() + try: + response = await loop.run_in_executor( + None, + lambda: self._client.chat.completions.create( + model=model, + messages=messages, + ), + ) + content = response.choices[0].message.content + logger.info("llm_call", extra={"model": model, "response_len": len(content)}) + return content + except Exception as exc: + logger.error("llm_call_error", extra={"model": model, "error": str(exc)}) + raise LLMCallError(f"大模型调用失败: {exc}") from exc diff --git a/app/clients/storage/__init__.py b/app/clients/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/clients/storage/__pycache__/__init__.cpython-312.pyc b/app/clients/storage/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..56c457d Binary files /dev/null and b/app/clients/storage/__pycache__/__init__.cpython-312.pyc differ diff --git a/app/clients/storage/__pycache__/base.cpython-312.pyc 
b/app/clients/storage/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000..ce7f4b2 Binary files /dev/null and b/app/clients/storage/__pycache__/base.cpython-312.pyc differ diff --git a/app/clients/storage/__pycache__/rustfs_client.cpython-312.pyc b/app/clients/storage/__pycache__/rustfs_client.cpython-312.pyc new file mode 100644 index 0000000..f34f312 Binary files /dev/null and b/app/clients/storage/__pycache__/rustfs_client.cpython-312.pyc differ diff --git a/app/clients/storage/base.py b/app/clients/storage/base.py new file mode 100644 index 0000000..89535a4 --- /dev/null +++ b/app/clients/storage/base.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + + +class StorageClient(ABC): + @abstractmethod + async def download_bytes(self, bucket: str, path: str) -> bytes: + """Download an object and return its raw bytes.""" + + @abstractmethod + async def upload_bytes( + self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream" + ) -> None: + """Upload raw bytes to the given bucket/path.""" + + @abstractmethod + async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str: + """Return a presigned GET URL valid for `expires` seconds.""" + + @abstractmethod + async def get_object_size(self, bucket: str, path: str) -> int: + """Return the object size in bytes without downloading it.""" diff --git a/app/clients/storage/rustfs_client.py b/app/clients/storage/rustfs_client.py new file mode 100644 index 0000000..8ef105a --- /dev/null +++ b/app/clients/storage/rustfs_client.py @@ -0,0 +1,70 @@ +import asyncio +import io + +import boto3 +from botocore.exceptions import ClientError + +from app.clients.storage.base import StorageClient +from app.core.exceptions import StorageError +from app.core.logging import get_logger + +logger = get_logger(__name__) + + +class RustFSClient(StorageClient): + def __init__(self, endpoint: str, access_key: str, secret_key: str) -> None: + self._s3 = boto3.client( + 
"s3", + endpoint_url=endpoint, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + + async def download_bytes(self, bucket: str, path: str) -> bytes: + loop = asyncio.get_event_loop() + try: + resp = await loop.run_in_executor( + None, lambda: self._s3.get_object(Bucket=bucket, Key=path) + ) + return resp["Body"].read() + except ClientError as exc: + raise StorageError(f"存储下载失败 [{bucket}/{path}]: {exc}") from exc + + async def upload_bytes( + self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream" + ) -> None: + loop = asyncio.get_event_loop() + try: + await loop.run_in_executor( + None, + lambda: self._s3.put_object( + Bucket=bucket, Key=path, Body=io.BytesIO(data), ContentType=content_type + ), + ) + except ClientError as exc: + raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc + + async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str: + loop = asyncio.get_event_loop() + try: + url = await loop.run_in_executor( + None, + lambda: self._s3.generate_presigned_url( + "get_object", + Params={"Bucket": bucket, "Key": path}, + ExpiresIn=expires, + ), + ) + return url + except ClientError as exc: + raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc + + async def get_object_size(self, bucket: str, path: str) -> int: + loop = asyncio.get_event_loop() + try: + resp = await loop.run_in_executor( + None, lambda: self._s3.head_object(Bucket=bucket, Key=path) + ) + return resp["ContentLength"] + except ClientError as exc: + raise StorageError(f"获取文件大小失败 [{bucket}/{path}]: {exc}") from exc diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/core/__pycache__/__init__.cpython-312.pyc b/app/core/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..aa06ba8 Binary files /dev/null and b/app/core/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/app/core/__pycache__/config.cpython-312.pyc b/app/core/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000..9ed37a1 Binary files /dev/null and b/app/core/__pycache__/config.cpython-312.pyc differ diff --git a/app/core/__pycache__/dependencies.cpython-312.pyc b/app/core/__pycache__/dependencies.cpython-312.pyc new file mode 100644 index 0000000..f8a413c Binary files /dev/null and b/app/core/__pycache__/dependencies.cpython-312.pyc differ diff --git a/app/core/__pycache__/exceptions.cpython-312.pyc b/app/core/__pycache__/exceptions.cpython-312.pyc new file mode 100644 index 0000000..48fd955 Binary files /dev/null and b/app/core/__pycache__/exceptions.cpython-312.pyc differ diff --git a/app/core/__pycache__/json_utils.cpython-312.pyc b/app/core/__pycache__/json_utils.cpython-312.pyc new file mode 100644 index 0000000..7bd9150 Binary files /dev/null and b/app/core/__pycache__/json_utils.cpython-312.pyc differ diff --git a/app/core/__pycache__/logging.cpython-312.pyc b/app/core/__pycache__/logging.cpython-312.pyc new file mode 100644 index 0000000..a702a7a Binary files /dev/null and b/app/core/__pycache__/logging.cpython-312.pyc differ diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..ac40a1d --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,46 @@ +import os +from functools import lru_cache +from pathlib import Path +from typing import Any + +import yaml +from dotenv import load_dotenv + +load_dotenv() + +# Maps environment variable names to nested YAML key paths +_ENV_OVERRIDES: dict[str, list[str]] = { + "ZHIPUAI_API_KEY": ["zhipuai", "api_key"], + "STORAGE_ACCESS_KEY": ["storage", "access_key"], + "STORAGE_SECRET_KEY": ["storage", "secret_key"], + "STORAGE_ENDPOINT": ["storage", "endpoint"], + "BACKEND_CALLBACK_URL": ["backend", "callback_url"], + "LOG_LEVEL": ["server", "log_level"], + "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"], +} + +_CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml" 
+def _set_nested(cfg: dict, keys: list[str], value: Any) -> None:
+    for key in keys[:-1]:
+        cfg = cfg.setdefault(key, {})
+    if keys[-1] == "max_file_size_mb":  # coerce only numeric settings; keys/secrets stay strings
+        try:
+            value = int(value)
+        except (TypeError, ValueError):
+            pass
+    cfg[keys[-1]] = value
+ code = "VIDEO_TOO_LARGE" + + +class StorageError(AIServiceError): + status_code = 502 + code = "STORAGE_ERROR" + + +class LLMParseError(AIServiceError): + status_code = 502 + code = "LLM_PARSE_ERROR" + + +class LLMCallError(AIServiceError): + status_code = 503 + code = "LLM_CALL_ERROR" + + +async def ai_service_exception_handler(request: Request, exc: AIServiceError) -> JSONResponse: + return JSONResponse( + status_code=exc.status_code, + content={"code": exc.code, "message": exc.message}, + ) + + +async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=500, + content={"code": "INTERNAL_ERROR", "message": str(exc)}, + ) diff --git a/app/core/json_utils.py b/app/core/json_utils.py new file mode 100644 index 0000000..494b5fb --- /dev/null +++ b/app/core/json_utils.py @@ -0,0 +1,19 @@ +import json +import re + +from app.core.exceptions import LLMParseError + + +def extract_json(text: str) -> any: + """Parse JSON from LLM response, stripping Markdown code fences if present.""" + text = text.strip() + + # Strip ```json ... ``` or ``` ... 
``` fences + fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text) + if fence_match: + text = fence_match.group(1).strip() + + try: + return json.loads(text) + except json.JSONDecodeError as e: + raise LLMParseError(f"大模型返回非合法 JSON: {e}") from e diff --git a/app/core/logging.py b/app/core/logging.py new file mode 100644 index 0000000..1fd8b9d --- /dev/null +++ b/app/core/logging.py @@ -0,0 +1,62 @@ +import json +import logging +import time +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + + +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(_JsonFormatter()) + logger.addHandler(handler) + logger.propagate = False + return logger + + +class _JsonFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + payload = { + "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + if record.exc_info: + payload["exc_info"] = self.formatException(record.exc_info) + # Merge any extra fields passed via `extra=` + for key, value in record.__dict__.items(): + if key not in ( + "name", "msg", "args", "levelname", "levelno", "pathname", + "filename", "module", "exc_info", "exc_text", "stack_info", + "lineno", "funcName", "created", "msecs", "relativeCreated", + "thread", "threadName", "processName", "process", "message", + "taskName", + ): + payload[key] = value + return json.dumps(payload, ensure_ascii=False) + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + def __init__(self, app, logger: logging.Logger | None = None) -> None: + super().__init__(app) + self._logger = logger or get_logger("request") + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + start = 
time.perf_counter() + response = await call_next(request) + duration_ms = round((time.perf_counter() - start) * 1000, 1) + self._logger.info( + "request", + extra={ + "method": request.method, + "path": request.url.path, + "status": response.status_code, + "duration_ms": duration_ms, + }, + ) + return response diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..1ee4df0 --- /dev/null +++ b/app/main.py @@ -0,0 +1,46 @@ +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from app.core.exceptions import ( + AIServiceError, + ai_service_exception_handler, + unhandled_exception_handler, +) +from app.core.logging import RequestLoggingMiddleware, get_logger + +logger = get_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + logger.info("startup", extra={"message": "AI service starting"}) + yield + logger.info("shutdown", extra={"message": "AI service stopping"}) + + +app = FastAPI( + title="Label AI Service", + description="知识图谱标注平台 AI 计算服务", + version="1.0.0", + lifespan=lifespan, +) + +app.add_middleware(RequestLoggingMiddleware) +app.add_exception_handler(AIServiceError, ai_service_exception_handler) +app.add_exception_handler(Exception, unhandled_exception_handler) + + +@app.get("/health", tags=["Health"]) +async def health(): + return {"status": "ok"} + + +# Routers registered after implementation (imported lazily to avoid circular deps) +from app.routers import text, image, video, qa, finetune # noqa: E402 + +app.include_router(text.router, prefix="/api/v1") +app.include_router(image.router, prefix="/api/v1") +app.include_router(video.router, prefix="/api/v1") +app.include_router(qa.router, prefix="/api/v1") +app.include_router(finetune.router, prefix="/api/v1") diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/app/routers/finetune.py b/app/routers/finetune.py new file mode 100644 index 0000000..f16ec0f --- /dev/null +++ b/app/routers/finetune.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["Finetune"]) diff --git a/app/routers/image.py b/app/routers/image.py new file mode 100644 index 0000000..30aefbc --- /dev/null +++ b/app/routers/image.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["Image"]) diff --git a/app/routers/qa.py b/app/routers/qa.py new file mode 100644 index 0000000..5b22c10 --- /dev/null +++ b/app/routers/qa.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["QA"]) diff --git a/app/routers/text.py b/app/routers/text.py new file mode 100644 index 0000000..44c49f9 --- /dev/null +++ b/app/routers/text.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["Text"]) diff --git a/app/routers/video.py b/app/routers/video.py new file mode 100644 index 0000000..136e997 --- /dev/null +++ b/app/routers/video.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router = APIRouter(tags=["Video"]) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..2e68628 --- /dev/null +++ b/config.yaml @@ -0,0 +1,19 @@ +server: + port: 8000 + log_level: INFO + +storage: + buckets: + source_data: "source-data" + finetune_export: "finetune-export" + +backend: {} # callback_url injected via BACKEND_CALLBACK_URL env var + +video: + frame_sample_count: 8 # uniform frames sampled for video-to-text + max_file_size_mb: 200 # video size limit (override with MAX_VIDEO_SIZE_MB) + keyframe_diff_threshold: 30.0 # grayscale mean-diff threshold for keyframe detection + +models: + default_text: "glm-4-flash" + default_vision: "glm-4v-flash" diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..fa5fa65 --- /dev/null +++ 
b/docker-compose.yml @@ -0,0 +1,37 @@ +version: "3.9" + +services: + ai-service: + build: . + ports: + - "8000:8000" + env_file: + - .env + depends_on: + rustfs: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 5s + retries: 3 + start_period: 10s + + rustfs: + image: rustfs/rustfs:latest + ports: + - "9000:9000" + environment: + RUSTFS_ACCESS_KEY: ${STORAGE_ACCESS_KEY} + RUSTFS_SECRET_KEY: ${STORAGE_SECRET_KEY} + volumes: + - rustfs_data:/data + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/health"] + interval: 10s + timeout: 3s + retries: 5 + start_period: 5s + +volumes: + rustfs_data: diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..78c5011 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +asyncio_mode = auto +testpaths = tests diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9e74516 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +fastapi>=0.111.0 +uvicorn[standard]>=0.29.0 +pydantic>=2.7.0 +zhipuai>=2.1.0 +boto3>=1.34.0 +pdfplumber>=0.11.0 +python-docx>=1.1.0 +opencv-python-headless>=4.9.0 +numpy>=1.26.0 +httpx>=0.27.0 +python-dotenv>=1.0.0 +pyyaml>=6.0.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__pycache__/__init__.cpython-312.pyc b/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..3cb3690 Binary files /dev/null and b/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc new file mode 100644 index 0000000..d04d8c6 Binary files /dev/null and b/tests/__pycache__/conftest.cpython-312-pytest-9.0.3.pyc differ diff --git a/tests/__pycache__/test_config.cpython-312-pytest-9.0.3.pyc 
b/tests/__pycache__/test_config.cpython-312-pytest-9.0.3.pyc new file mode 100644 index 0000000..1652cf6 Binary files /dev/null and b/tests/__pycache__/test_config.cpython-312-pytest-9.0.3.pyc differ diff --git a/tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.3.pyc new file mode 100644 index 0000000..9f328ca Binary files /dev/null and b/tests/__pycache__/test_llm_client.cpython-312-pytest-9.0.3.pyc differ diff --git a/tests/__pycache__/test_storage_client.cpython-312-pytest-9.0.3.pyc b/tests/__pycache__/test_storage_client.cpython-312-pytest-9.0.3.pyc new file mode 100644 index 0000000..890bd18 Binary files /dev/null and b/tests/__pycache__/test_storage_client.cpython-312-pytest-9.0.3.pyc differ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ae81f4d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +from fastapi.testclient import TestClient + +from app.clients.llm.base import LLMClient +from app.clients.storage.base import StorageClient +from app.core.dependencies import get_llm_client, get_storage_client + + +@pytest.fixture +def mock_llm() -> LLMClient: + client = MagicMock(spec=LLMClient) + client.chat = AsyncMock(return_value='[]') + client.chat_vision = AsyncMock(return_value='[]') + return client + + +@pytest.fixture +def mock_storage() -> StorageClient: + client = MagicMock(spec=StorageClient) + client.download_bytes = AsyncMock(return_value=b"") + client.upload_bytes = AsyncMock(return_value=None) + client.get_presigned_url = AsyncMock(return_value="http://example.com/presigned") + client.get_object_size = AsyncMock(return_value=10 * 1024 * 1024) # 10 MB default + return client + + +@pytest.fixture +def test_app(mock_llm, mock_storage): + from app.main import app + app.dependency_overrides[get_llm_client] = lambda: mock_llm + app.dependency_overrides[get_storage_client] = 
lambda: mock_storage + yield app + app.dependency_overrides.clear() + + +@pytest.fixture +def client(test_app): + return TestClient(test_app) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..aa8f464 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,40 @@ +import os +import pytest + + +def test_yaml_defaults_load(monkeypatch): + # Clear lru_cache so each test gets a fresh load + from app.core import config as cfg_module + cfg_module.get_config.cache_clear() + + # Remove env overrides that might bleed from shell environment + for var in ["MAX_VIDEO_SIZE_MB", "LOG_LEVEL", "STORAGE_ENDPOINT"]: + monkeypatch.delenv(var, raising=False) + + cfg = cfg_module.get_config() + + assert cfg["server"]["port"] == 8000 + assert cfg["video"]["max_file_size_mb"] == 200 + assert cfg["models"]["default_text"] == "glm-4-flash" + assert cfg["models"]["default_vision"] == "glm-4v-flash" + assert cfg["storage"]["buckets"]["source_data"] == "source-data" + + +def test_max_video_size_env_override(monkeypatch): + from app.core import config as cfg_module + cfg_module.get_config.cache_clear() + + monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "500") + cfg = cfg_module.get_config() + + assert cfg["video"]["max_file_size_mb"] == 500 + + +def test_log_level_env_override(monkeypatch): + from app.core import config as cfg_module + cfg_module.get_config.cache_clear() + + monkeypatch.setenv("LOG_LEVEL", "DEBUG") + cfg = cfg_module.get_config() + + assert cfg["server"]["log_level"] == "DEBUG" diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py new file mode 100644 index 0000000..39a586b --- /dev/null +++ b/tests/test_llm_client.py @@ -0,0 +1,40 @@ +import pytest +from unittest.mock import MagicMock, patch + +from app.clients.llm.zhipuai_client import ZhipuAIClient +from app.core.exceptions import LLMCallError + + +@pytest.fixture +def mock_sdk_response(): + resp = MagicMock() + resp.choices[0].message.content = '{"result": "ok"}' + return resp + + 
+@pytest.fixture +def client(): + with patch("app.clients.llm.zhipuai_client.ZhipuAI"): + c = ZhipuAIClient(api_key="test-key") + return c + + +@pytest.mark.asyncio +async def test_chat_returns_content(client, mock_sdk_response): + client._client.chat.completions.create.return_value = mock_sdk_response + result = await client.chat("glm-4-flash", [{"role": "user", "content": "hello"}]) + assert result == '{"result": "ok"}' + + +@pytest.mark.asyncio +async def test_chat_vision_returns_content(client, mock_sdk_response): + client._client.chat.completions.create.return_value = mock_sdk_response + result = await client.chat_vision("glm-4v-flash", [{"role": "user", "content": []}]) + assert result == '{"result": "ok"}' + + +@pytest.mark.asyncio +async def test_llm_call_error_on_sdk_exception(client): + client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded") + with pytest.raises(LLMCallError, match="大模型调用失败"): + await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}]) diff --git a/tests/test_storage_client.py b/tests/test_storage_client.py new file mode 100644 index 0000000..d124563 --- /dev/null +++ b/tests/test_storage_client.py @@ -0,0 +1,62 @@ +import pytest +from unittest.mock import MagicMock, patch +from botocore.exceptions import ClientError + +from app.clients.storage.rustfs_client import RustFSClient +from app.core.exceptions import StorageError + + +@pytest.fixture +def client(): + with patch("app.clients.storage.rustfs_client.boto3") as mock_boto3: + c = RustFSClient( + endpoint="http://rustfs:9000", + access_key="key", + secret_key="secret", + ) + c._s3 = MagicMock() + return c + + +@pytest.mark.asyncio +async def test_download_bytes_returns_bytes(client): + client._s3.get_object.return_value = {"Body": MagicMock(read=lambda: b"hello")} + result = await client.download_bytes("source-data", "text/test.txt") + assert result == b"hello" + client._s3.get_object.assert_called_once_with(Bucket="source-data", 
Key="text/test.txt") + + +@pytest.mark.asyncio +async def test_download_bytes_raises_storage_error(client): + client._s3.get_object.side_effect = ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "GetObject" + ) + with pytest.raises(StorageError, match="存储下载失败"): + await client.download_bytes("source-data", "missing.txt") + + +@pytest.mark.asyncio +async def test_get_object_size_returns_content_length(client): + client._s3.head_object.return_value = {"ContentLength": 1024} + size = await client.get_object_size("source-data", "video/test.mp4") + assert size == 1024 + client._s3.head_object.assert_called_once_with(Bucket="source-data", Key="video/test.mp4") + + +@pytest.mark.asyncio +async def test_get_object_size_raises_storage_error(client): + client._s3.head_object.side_effect = ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, "HeadObject" + ) + with pytest.raises(StorageError, match="获取文件大小失败"): + await client.get_object_size("source-data", "video/missing.mp4") + + +@pytest.mark.asyncio +async def test_upload_bytes_calls_put_object(client): + client._s3.put_object.return_value = {} + await client.upload_bytes("source-data", "frames/1/0.jpg", b"jpeg-data", "image/jpeg") + client._s3.put_object.assert_called_once() + call_kwargs = client._s3.put_object.call_args + assert call_kwargs.kwargs["Bucket"] == "source-data" + assert call_kwargs.kwargs["Key"] == "frames/1/0.jpg"