Compare commits
29 commits: 54d7a8bf55...main

| Author | SHA1 | Date |
|---|---|---|
|  | 3a60d8cb33 |  |
|  | d34f703523 |  |
|  | 7b07f764a7 |  |
|  | 3a3841313f |  |
|  | bfbd7f3bb9 |  |
|  | 15adab5897 |  |
|  | 58802d40fe |  |
|  | e753609249 |  |
|  | 7adc88bab7 |  |
|  | 68567b98b3 |  |
|  | 1e327ea92f |  |
|  | a0ec71d877 |  |
|  | 65842b5e8b |  |
|  | ec279262a7 |  |
|  | 0880e1018c |  |
|  | 603382d1fa |  |
|  | 00f092e728 |  |
|  | 4211e587ee |  |
|  | 0274bb470a |  |
|  | 2876c179ac |  |
|  | dd8da386f4 |  |
|  | e1eb5e47b1 |  |
|  | 4162d9f4e6 |  |
|  | 092f9dbfc5 |  |
|  | e0d080ceea |  |
|  | 10887da4ab |  |
|  | 3892c6e60f |  |
|  | f9f84937db |  |
|  | f5a43a4bbc |  |

13  .env  Normal file
@@ -0,0 +1,13 @@
# Required - fill in before running
ZHIPUAI_API_KEY=your-zhipuai-api-key-here
DASHSCOPE_API_KEY=your-dashscope-api-key-here
STORAGE_ACCESS_KEY=admin
STORAGE_SECRET_KEY=your_strong_password
STORAGE_ENDPOINT=http://39.107.112.174:9000

# Optional overrides
DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
DASHSCOPE_FINE_TUNE_BASE_URL=https://dashscope.aliyuncs.com/api/v1
BACKEND_CALLBACK_URL=http://localhost:18082/api/ai/callback
LOG_LEVEL=INFO
MAX_VIDEO_SIZE_MB=500

37  .gitignore (vendored)  Normal file
@@ -0,0 +1,37 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*$py.class
*.egg-info/
dist/
build/

# Environment
.env
*.env

# Testing
.pytest_cache/
.coverage
htmlcov/

# Temp files (video processing)
tmp/
*.tmp

# IDE
.vscode/
.idea/
.specify/
.claude/
docs/
specs/
tests/
CLAUDE.md
pytest.ini
# OS
.DS_Store
Thumbs.db

22  Dockerfile  Normal file
@@ -0,0 +1,22 @@
FROM registry.bjzgzp.com:4433/library/python3.12:base

WORKDIR /app

ARG APT_MIRROR=mirrors.tuna.tsinghua.edu.cn
ARG PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple

RUN sed -i "s|deb.debian.org|${APT_MIRROR}|g; s|security.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list.d/debian.sources \
    && apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -i "${PIP_INDEX_URL}" -r requirements.txt

COPY . .

EXPOSE 18000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "18000"]

430  README.md
@@ -0,0 +1,430 @@
# label_ai_service

> 2026-04-16 update: the default LLM adapter has been switched to Alibaba Cloud Bailian/Qwen (the DashScope OpenAI-compatible API). The default text model is `qwen3.6-plus` and the default vision model is `qwen-vl-plus`; the old `ZhipuAIClient` code is kept in the repository but is no longer wired in by the default dependency injection.

`label_ai_service` is the AI compute service of the knowledge-graph annotation platform. Built on FastAPI and deployable on its own, it provides inference and preprocessing capabilities. It does not access the database directly: it performs structured extraction through the ZhipuAI GLM model family (Qwen via DashScope since the update above), reads raw files and writes processed results through RustFS, and notifies the upstream backend of asynchronous video task results via HTTP callbacks.

The service currently covers six core capabilities:

- Text triple extraction, supporting `TXT`, `PDF`, and `DOCX`
- Image quadruple extraction, with automatic cropping of the `bbox` region
- Video frame extraction, in fixed-interval or approximate-keyframe mode
- Video-to-text, writing a description of a video segment to a text file
- QA-pair generation from text or image evidence
- Submitting fine-tune jobs to ZhipuAI and querying their status

## Use cases

The service is designed to run as the AI-capability sidecar of `label-backend`, but it can also run standalone to exercise the file-parsing, image-understanding, video-preprocessing, and QA-generation flows.

A typical call chain:

1. The Java backend uploads raw text, images, or video to RustFS.
2. The Java backend calls the `label_ai_service` REST API.
3. The AI service reads the file from RustFS and calls the GLM models for extraction or generation.
4. Results come back as JSON, or are written to RustFS followed by a callback to the upstream.

## Feature overview

| Capability | Endpoint | Description |
|---|---|---|
| Health check | `GET /health` | Container liveness probe and integration self-check |
| Text triple extraction | `POST /api/v1/text/extract` | Extracts `subject / predicate / object / source_snippet / source_offset` from a document |
| Image quadruple extraction | `POST /api/v1/image/extract` | Extracts `subject / predicate / object / qualifier / bbox` from an image and returns cropped-image paths |
| Video frame extraction | `POST /api/v1/video/extract-frames` | Extracts frames asynchronously; results are delivered via callback |
| Video to text | `POST /api/v1/video/to-text` | Asynchronously samples representative frames, generates a Chinese description, and uploads it to object storage |
| Text QA generation | `POST /api/v1/qa/gen-text` | Generates QA pairs from triples and source-text evidence |
| Image QA generation | `POST /api/v1/qa/gen-image` | Generates QA pairs from cropped images and quadruples |
| Fine-tune submission | `POST /api/v1/finetune/start` | Submits a fine-tune job to ZhipuAI |
| Fine-tune status | `GET /api/v1/finetune/status/{job_id}` | Queries fine-tune job status and progress |

## Tech stack

- Python 3.12
- FastAPI
- Pydantic v2
- ZhipuAI Python SDK
- boto3
- OpenCV
- pdfplumber
- python-docx
- httpx
- pytest / pytest-asyncio

## Architecture notes

### External dependencies

- ZhipuAI
  - Text and multimodal inference
  - Fine-tune submission and status queries
- RustFS, or any S3-compatible object store
  - Raw file reads
  - Writes of cropped images, video frames, and video description text
- Upstream callback endpoint
  - Receives results when video tasks complete

### Processing boundaries

- The service neither handles file uploads nor keeps a task-state store.
- Text and image endpoints return synchronously.
- Video endpoints return `202 Accepted` immediately; the actual result is delivered via callback.
- The service performs no authentication by default; access control is expected from an upstream gateway or backend.

## Project layout

```text
label_ai_service/
├── app/
│   ├── main.py
│   ├── clients/
│   │   ├── llm/
│   │   └── storage/
│   ├── core/
│   ├── models/
│   ├── routers/
│   └── services/
├── docs/
│   └── superpowers/
├── specs/
├── tests/
├── config.yaml
├── .env
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
└── README.md
```

Directory responsibilities:

- `app/main.py`
  - FastAPI application entry point; registers middleware, exception handlers, and all routers
- `app/clients`
  - Adapter layer for third-party dependencies, currently ZhipuAI and RustFS
- `app/services`
  - Core business logic: file parsing, prompt assembly, result conversion, and async task handling
- `app/routers`
  - HTTP endpoint layer
- `app/models`
  - Request and response models
- `app/core`
  - Shared modules: configuration, logging, middleware, exceptions
- `tests`
  - Tests for routers, services, config, and clients

## Configuration

Configuration is layered across `config.yaml` and `.env`:

- `config.yaml`
  - Holds stable, committable structured configuration
- `.env`
  - Holds secrets and per-environment values

Environment variables override same-named settings in `config.yaml`.
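
The precedence rule can be sketched in isolation (an illustration of the layering, not the service's actual loader, which lives in `app/core/config.py`):

```python
import os

# config.yaml defaults (subset)
cfg = {"server": {"log_level": "INFO"}, "video": {"max_file_size_mb": 500}}

# A same-named environment variable wins over the YAML value
os.environ["LOG_LEVEL"] = "DEBUG"
value = os.environ.get("LOG_LEVEL")
if value is not None:
    cfg["server"]["log_level"] = value

print(cfg["server"]["log_level"])  # DEBUG
```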

### config.yaml

The project's current defaults:

```yaml
server:
  port: 18000
  log_level: INFO

dashscope:
  api_key: ""
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  fine_tune_base_url: "https://dashscope.aliyuncs.com/api/v1"

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}

video:
  frame_sample_count: 8
  max_file_size_mb: 500
  keyframe_diff_threshold: 30.0

models:
  default_text: "qwen3.6-plus"
  default_vision: "qwen-vl-plus"
```

### .env

At minimum, configure these variables:

| Variable | Required | Description |
|---|---|---|
| `DASHSCOPE_API_KEY` | yes | DashScope API key |
| `DASHSCOPE_BASE_URL` | no | DashScope OpenAI-compatible base URL |
| `DASHSCOPE_FINE_TUNE_BASE_URL` | no | DashScope fine-tune API base URL |
| `STORAGE_ACCESS_KEY` | yes | RustFS/S3 access key |
| `STORAGE_SECRET_KEY` | yes | RustFS/S3 secret key |
| `STORAGE_ENDPOINT` | yes | RustFS/S3 endpoint, e.g. `http://rustfs:9000` |
| `BACKEND_CALLBACK_URL` | no | Callback URL for async video tasks |
| `LOG_LEVEL` | no | Log level, default `INFO` |
| `MAX_VIDEO_SIZE_MB` | no | Overrides the video size limit |

Example `.env`:

```ini
DASHSCOPE_API_KEY=your-dashscope-api-key-here
DASHSCOPE_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
DASHSCOPE_FINE_TUNE_BASE_URL=https://dashscope.aliyuncs.com/api/v1
STORAGE_ACCESS_KEY=your-storage-access-key
STORAGE_SECRET_KEY=your-storage-secret-key
STORAGE_ENDPOINT=http://rustfs:9000
BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
LOG_LEVEL=INFO
```

## Running locally

### Option 1: run directly

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```

On Windows PowerShell:

```powershell
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```

Once started, visit:

- Swagger UI: `http://localhost:8000/docs`
- Health check: `http://localhost:8000/health`

### Option 2: Docker Compose

The bundled Compose file starts:

- `ai-service`
- `rustfs`

To start:

```bash
docker compose up --build
```

If you are testing the async video flow end to end, make sure `BACKEND_CALLBACK_URL` points to a reachable backend. Otherwise the task itself still runs, but the callback fails and an error is logged.

## API usage examples

### 1. Health check

```bash
curl http://localhost:8000/health
```

Returns:

```json
{"status":"ok"}
```

### 2. Text triple extraction

```bash
curl -X POST http://localhost:8000/api/v1/text/extract \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "text/202404/123.txt",
    "file_name": "设备规范.txt"
  }'
```

### 3. Image quadruple extraction

```bash
curl -X POST http://localhost:8000/api/v1/image/extract \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "image/202404/456.jpg",
    "task_id": 789
  }'
```

### 4. Video frame extraction

```bash
curl -X POST http://localhost:8000/api/v1/video/extract-frames \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "video/202404/001.mp4",
    "source_id": 10,
    "job_id": 42,
    "mode": "interval",
    "frame_interval": 30
  }'
```

### 5. Video to text

```bash
curl -X POST http://localhost:8000/api/v1/video/to-text \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "video/202404/001.mp4",
    "source_id": 10,
    "job_id": 43,
    "start_sec": 0,
    "end_sec": 60
  }'
```

### 6. Text QA generation

```bash
curl -X POST http://localhost:8000/api/v1/qa/gen-text \
  -H "Content-Type: application/json" \
  -d '{
    "items": [
      {
        "subject": "变压器",
        "predicate": "额定电压",
        "object": "110kV",
        "source_snippet": "该变压器额定电压为110kV"
      }
    ]
  }'
```

### 7. Image QA generation

```bash
curl -X POST http://localhost:8000/api/v1/qa/gen-image \
  -H "Content-Type: application/json" \
  -d '{
    "items": [
      {
        "subject": "电缆接头",
        "predicate": "位于",
        "object": "配电箱左侧",
        "cropped_image_path": "crops/1/0.jpg"
      }
    ]
  }'
```

### 8. Fine-tune submission

```bash
curl -X POST http://localhost:8000/api/v1/finetune/start \
  -H "Content-Type: application/json" \
  -d '{
    "jsonl_url": "https://example.com/train.jsonl",
    "base_model": "qwen3-14b",
    "hyperparams": {
      "epochs": 3,
      "learning_rate": 0.0001
    }
  }'
```
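
Under the hood, the service downloads the JSONL, uploads it to DashScope, and then submits the job; judging from `app/clients/llm/qwen_client.py` further down this diff, the forwarded request body is assembled roughly like this (the `file_id` value here is hypothetical):

```python
def build_finetune_payload(base_model: str, file_id: str, hyperparams: dict) -> dict:
    # Mirrors QwenClient.submit_finetune: hyper_parameters is attached only when provided
    payload = {"model": base_model, "training_file_ids": [file_id]}
    if hyperparams:
        payload["hyper_parameters"] = hyperparams
    return payload

print(build_finetune_payload("qwen3-14b", "file-abc123", {"epochs": 3, "learning_rate": 0.0001}))
```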
## Output conventions

The service writes these derived artifacts:

| Type | Path pattern | Description |
|---|---|---|
| Cropped images | `crops/{task_id}/{index}.jpg` | Local evidence crops from image extraction |
| Video frames | `frames/{source_id}/{index}.jpg` | Frame-extraction results |
| Video descriptions | `video-text/{source_id}/{timestamp}.txt` | Video-to-text results |

Notes:

- These objects are written to `storage.buckets.source_data` by default
- Upload paths for raw files are decided by the upstream system
- The service never generates raw-file paths itself; it only consumes the `file_path` passed in requests
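
The path patterns are plain format strings; a minimal sketch (helper names are illustrative, not the service's API):

```python
def crop_path(task_id: int, index: int) -> str:
    # crops/{task_id}/{index}.jpg
    return f"crops/{task_id}/{index}.jpg"

def frame_path(source_id: int, index: int) -> str:
    # frames/{source_id}/{index}.jpg
    return f"frames/{source_id}/{index}.jpg"

def video_text_path(source_id: int, timestamp: int) -> str:
    # video-text/{source_id}/{timestamp}.txt
    return f"video-text/{source_id}/{timestamp}.txt"

print(crop_path(789, 0))  # crops/789/0.jpg
```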
## Logging and error handling

### Logging

Logs are emitted as JSON, ready for direct ingestion by container log platforms. Request logs carry:

- `method`
- `path`
- `status`
- `duration_ms`

LLM calls and background tasks also emit structured fields, which helps diagnose API timeouts, failed callbacks, and model parsing errors.
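
The shape of a request log line can be illustrated with a pared-down stdlib formatter (a sketch, not the service's own `_JsonFormatter`):

```python
import json
import logging

class MiniJsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {"level": record.levelname, "message": record.getMessage()}
        # Extra fields passed via `extra=` land directly on the record's __dict__
        for key in ("method", "path", "status", "duration_ms"):
            if key in record.__dict__:
                payload[key] = record.__dict__[key]
        return json.dumps(payload)

logger = logging.getLogger("request-demo")
handler = logging.StreamHandler()
handler.setFormatter(MiniJsonFormatter())
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Emits one JSON line with method/path/status/duration_ms fields
logger.info("request", extra={"method": "GET", "path": "/health", "status": 200, "duration_ms": 1.2})
```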
### Error codes

Unified error response format:

```json
{"code":"ERROR_CODE","message":"human-readable description"}
```

Current error codes:

| Code | HTTP status | Meaning |
|---|---|---|
| `UNSUPPORTED_FILE_TYPE` | 400 | Unsupported file format for text extraction |
| `VIDEO_TOO_LARGE` | 400 | Video exceeds the size limit |
| `STORAGE_ERROR` | 502 | Object storage access failed |
| `LLM_PARSE_ERROR` | 502 | Model output could not be parsed as the expected JSON |
| `LLM_CALL_ERROR` | 503 | Model call or fine-tune API call failed |
| `INTERNAL_ERROR` | 500 | Uncaught exception |
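
Callers can branch on `code` rather than on the HTTP status alone; a minimal client-side sketch (the retryable/non-retryable split is this example's assumption, not a guarantee of the service):

```python
import json

def is_retryable(body: str) -> bool:
    # Treat transient upstream failures as retryable; validation errors are not
    code = json.loads(body)["code"]
    return code in {"LLM_CALL_ERROR", "STORAGE_ERROR"}

print(is_retryable('{"code": "LLM_CALL_ERROR", "message": "timeout"}'))  # True
```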
## Testing

After installing dependencies, run:

```bash
python -m pytest
```

The tests cover these main modules:

- The health-check endpoint
- Text, image, video, QA, and fine-tune routers
- Basic success and failure paths of each service
- Config loading and client adapters

## Design documents

More detailed design material lives in the repository and pairs well with this README:

- `docs/superpowers/specs/2026-04-10-ai-service-design.md`
- `docs/superpowers/plans/2026-04-10-ai-service-impl.md`
- `specs/001-ai-service-requirements/`

If you are new to this service, a suggested reading order:

1. This README, to understand responsibilities, endpoints, and how to run it
2. The design documents, for the architecture and design decisions
3. `app/services` and `tests`, for the implementation details

## Known constraints

- Text extraction currently supports only `txt`, `pdf`, and `docx`
- Video endpoints depend on the object store exposing readable file sizes
- Video task state is not persisted in this service; the upstream system manages it
- Image QA inlines pictures as base64 rather than relying on externally reachable presigned URLs
- If the callback URL in `.env` is unreachable, video tasks log an error but do not retry automatically

## Development guidelines

- When adding an endpoint, also add the Pydantic models, router tests, and README/API docs
- To swap model vendors, extend `app/clients/llm` first
- To swap storage implementations, extend `app/clients/storage` first
- Any change to output path rules must be reflected in the README and design docs

0  app/__init__.py  Normal file
0  app/clients/__init__.py  Normal file
0  app/clients/llm/__init__.py  Normal file

19  app/clients/llm/base.py  Normal file
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, model: str, messages: list[dict]) -> str:
        """Send a text chat request and return the response content string."""

    @abstractmethod
    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Send a multimodal (vision) chat request and return the response content string."""

    @abstractmethod
    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        """Submit a fine-tune job and return the job_id."""

    @abstractmethod
    async def get_finetune_status(self, job_id: str) -> dict:
        """Return a dict with keys: job_id, status (raw SDK string), progress (int|None), error_message (str|None)."""

132  app/clients/llm/qwen_client.py  Normal file
@@ -0,0 +1,132 @@
from typing import Any

import httpx

from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger

logger = get_logger(__name__)


class QwenClient(LLMClient):
    def __init__(
        self,
        api_key: str,
        base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
        fine_tune_base_url: str | None = None,
        transport: httpx.BaseTransport | None = None,
    ) -> None:
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._fine_tune_base_url = (
            fine_tune_base_url.rstrip("/")
            if fine_tune_base_url
            else self._base_url.replace("/compatible-mode/v1", "/api/v1")
        )
        self._transport = transport

    async def chat(self, model: str, messages: list[dict]) -> str:
        return await self._chat(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        return await self._chat(model, messages)

    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        try:
            file_bytes = await self._download_training_file(jsonl_url)
            file_id = await self._upload_training_file(file_bytes)
            payload = {
                "model": base_model,
                "training_file_ids": [file_id],
            }
            if hyperparams:
                payload["hyper_parameters"] = hyperparams
            data = await self._post_json(self._fine_tune_base_url, "/fine-tunes", payload)
            output = data.get("output", {})
            job_id = output.get("job_id") or data.get("job_id")
            if not job_id:
                raise LLMCallError("千问微调任务提交失败: 缺少 job_id")
            return job_id
        except LLMCallError:
            raise
        except Exception as exc:
            raise LLMCallError(f"千问微调任务提交失败: {exc}") from exc

    async def get_finetune_status(self, job_id: str) -> dict:
        try:
            data = await self._get_json(self._fine_tune_base_url, f"/fine-tunes/{job_id}")
            output = data.get("output", {})
            return {
                "job_id": output.get("job_id") or job_id,
                "status": output.get("status", "").lower(),
                "progress": output.get("progress"),
                "error_message": output.get("message"),
            }
        except LLMCallError:
            raise
        except Exception as exc:
            raise LLMCallError(f"查询千问微调任务失败: {exc}") from exc

    async def _chat(self, model: str, messages: list[dict]) -> str:
        try:
            data = await self._post_json(
                self._base_url,
                "/chat/completions",
                {"model": model, "messages": messages},
            )
            content = data["choices"][0]["message"]["content"]
            if isinstance(content, list):
                return "".join(
                    part.get("text", "") if isinstance(part, dict) else str(part)
                    for part in content
                )
            logger.info("llm_call", extra={"model": model, "response_len": len(content)})
            return content
        except LLMCallError:
            raise
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"千问大模型调用失败: {exc}") from exc

    async def _download_training_file(self, jsonl_url: str) -> bytes:
        async with self._build_client() as client:
            response = await client.get(jsonl_url)
            response.raise_for_status()
            return response.content

    async def _upload_training_file(self, file_bytes: bytes) -> str:
        async with self._build_client(base_url=self._base_url) as client:
            response = await client.post(
                "/files",
                data={"purpose": "fine-tune"},
                files={"file": ("training.jsonl", file_bytes, "application/jsonl")},
            )
            response.raise_for_status()
            data = response.json()
            file_id = data.get("id")
            if not file_id:
                raise LLMCallError("千问训练文件上传失败: 缺少 file id")
            return file_id

    async def _post_json(self, base_url: str, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        async with self._build_client(base_url=base_url) as client:
            response = await client.post(path, json=payload)
            response.raise_for_status()
            return response.json()

    async def _get_json(self, base_url: str, path: str) -> dict[str, Any]:
        async with self._build_client(base_url=base_url) as client:
            response = await client.get(path)
            response.raise_for_status()
            return response.json()

    def _build_client(self, base_url: str | None = None) -> httpx.AsyncClient:
        return httpx.AsyncClient(
            base_url=base_url or self._base_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
            },
            transport=self._transport,
            timeout=60,
        )

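A note on `QwenClient.__init__` above: when `fine_tune_base_url` is not supplied, it is derived from `base_url` by a plain string substitution, which can be checked in isolation:

```python
def derive_fine_tune_base_url(base_url: str) -> str:
    # Mirrors QwenClient.__init__: swap the OpenAI-compatible prefix for the native API prefix
    return base_url.rstrip("/").replace("/compatible-mode/v1", "/api/v1")

print(derive_fine_tune_base_url("https://dashscope.aliyuncs.com/compatible-mode/v1"))
# https://dashscope.aliyuncs.com/api/v1
```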
68  app/clients/llm/zhipuai_client.py  Normal file
@@ -0,0 +1,68 @@
import asyncio

from zhipuai import ZhipuAI

from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger

logger = get_logger(__name__)


class ZhipuAIClient(LLMClient):
    def __init__(self, api_key: str) -> None:
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, model: str, messages: list[dict]) -> str:
        return await self._call(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        return await self._call(model, messages)

    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None,
                lambda: self._client.fine_tuning.jobs.create(
                    training_file=jsonl_url,
                    model=base_model,
                    hyperparameters=hyperparams,
                ),
            )
            return resp.id
        except Exception as exc:
            raise LLMCallError(f"微调任务提交失败: {exc}") from exc

    async def get_finetune_status(self, job_id: str) -> dict:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None,
                lambda: self._client.fine_tuning.jobs.retrieve(job_id),
            )
            return {
                "job_id": resp.id,
                "status": resp.status,
                "progress": int(resp.progress) if getattr(resp, "progress", None) is not None else None,
                "error_message": getattr(resp, "error_message", None),
            }
        except Exception as exc:
            raise LLMCallError(f"查询微调任务失败: {exc}") from exc

    async def _call(self, model: str, messages: list[dict]) -> str:
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self._client.chat.completions.create(
                    model=model,
                    messages=messages,
                ),
            )
            content = response.choices[0].message.content
            logger.info("llm_call", extra={"model": model, "response_len": len(content)})
            return content
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"大模型调用失败: {exc}") from exc

0  app/clients/storage/__init__.py  Normal file

21  app/clients/storage/base.py  Normal file
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod


class StorageClient(ABC):
    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object and return its raw bytes."""

    @abstractmethod
    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        """Upload raw bytes to the given bucket/path."""

    @abstractmethod
    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Return a presigned GET URL valid for `expires` seconds."""

    @abstractmethod
    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return the object size in bytes without downloading it."""

71  app/clients/storage/rustfs_client.py  Normal file
@@ -0,0 +1,71 @@
import asyncio
import io

import boto3
from botocore.exceptions import ClientError

from app.clients.storage.base import StorageClient
from app.core.exceptions import StorageError
from app.core.logging import get_logger

logger = get_logger(__name__)


class RustFSClient(StorageClient):
    def __init__(self, endpoint: str, access_key: str, secret_key: str) -> None:
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
            )
            return resp["Body"].read()
        except ClientError as exc:
            logger.error("storage_download_error", extra={"bucket": bucket, "path": path, "error": str(exc)})
            raise StorageError(f"存储下载失败 [{bucket}/{path}]: {exc}") from exc

    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(
                None,
                lambda: self._s3.put_object(
                    Bucket=bucket, Key=path, Body=io.BytesIO(data), ContentType=content_type
                ),
            )
        except ClientError as exc:
            raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc

    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        loop = asyncio.get_running_loop()
        try:
            url = await loop.run_in_executor(
                None,
                lambda: self._s3.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": bucket, "Key": path},
                    ExpiresIn=expires,
                ),
            )
            return url
        except ClientError as exc:
            raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc

    async def get_object_size(self, bucket: str, path: str) -> int:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
            )
            return resp["ContentLength"]
        except ClientError as exc:
            raise StorageError(f"获取文件大小失败 [{bucket}/{path}]: {exc}") from exc

0  app/core/__init__.py  Normal file

49  app/core/config.py  Normal file
@@ -0,0 +1,49 @@
import os
from functools import lru_cache
from pathlib import Path
from typing import Any

import yaml
from dotenv import load_dotenv

load_dotenv()

# Maps environment variable names to nested YAML key paths
_ENV_OVERRIDES: dict[str, list[str]] = {
    "DASHSCOPE_API_KEY": ["dashscope", "api_key"],
    "DASHSCOPE_BASE_URL": ["dashscope", "base_url"],
    "DASHSCOPE_FINE_TUNE_BASE_URL": ["dashscope", "fine_tune_base_url"],
    "ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY": ["storage", "access_key"],
    "STORAGE_SECRET_KEY": ["storage", "secret_key"],
    "STORAGE_ENDPOINT": ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL": ["backend", "callback_url"],
    "LOG_LEVEL": ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
}

_CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"


def _set_nested(cfg: dict, keys: list[str], value: Any) -> None:
    for key in keys[:-1]:
        cfg = cfg.setdefault(key, {})
    # Coerce numeric env vars
    try:
        value = int(value)
    except (TypeError, ValueError):
        pass
    cfg[keys[-1]] = value


@lru_cache(maxsize=1)
def get_config() -> dict:
    with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
        cfg: dict = yaml.safe_load(f)

    for env_var, key_path in _ENV_OVERRIDES.items():
        value = os.environ.get(env_var)
        if value is not None:
            _set_nested(cfg, key_path, value)

    return cfg

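`_set_nested` above attempts integer coercion on every override, so `MAX_VIDEO_SIZE_MB=500` arrives as an `int` while `LOG_LEVEL=INFO` stays a string; a standalone copy for illustration:

```python
from typing import Any

def set_nested(cfg: dict, keys: list[str], value: Any) -> None:
    # Walk to the parent dict, creating intermediate levels as needed
    for key in keys[:-1]:
        cfg = cfg.setdefault(key, {})
    try:
        value = int(value)  # numeric env vars become ints; everything else stays a string
    except (TypeError, ValueError):
        pass
    cfg[keys[-1]] = value

cfg: dict = {}
set_nested(cfg, ["video", "max_file_size_mb"], "500")
set_nested(cfg, ["server", "log_level"], "INFO")
print(cfg)  # {'video': {'max_file_size_mb': 500}, 'server': {'log_level': 'INFO'}}
```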
28  app/core/dependencies.py  Normal file
@@ -0,0 +1,28 @@
from functools import lru_cache

from app.clients.llm.base import LLMClient
from app.clients.llm.qwen_client import QwenClient
from app.clients.storage.base import StorageClient
from app.clients.storage.rustfs_client import RustFSClient
from app.core.config import get_config


@lru_cache(maxsize=1)
def get_llm_client() -> LLMClient:
    cfg = get_config()
    dashscope_cfg = cfg["dashscope"]
    return QwenClient(
        api_key=dashscope_cfg["api_key"],
        base_url=dashscope_cfg.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
        fine_tune_base_url=dashscope_cfg.get("fine_tune_base_url"),
    )


@lru_cache(maxsize=1)
def get_storage_client() -> StorageClient:
    cfg = get_config()
    return RustFSClient(
        endpoint=cfg["storage"]["endpoint"],
        access_key=cfg["storage"]["access_key"],
        secret_key=cfg["storage"]["secret_key"],
    )

50  app/core/exceptions.py  Normal file
@@ -0,0 +1,50 @@
from fastapi import Request
from fastapi.responses import JSONResponse


class AIServiceError(Exception):
    status_code: int = 500
    code: str = "INTERNAL_ERROR"

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)


class UnsupportedFileTypeError(AIServiceError):
    status_code = 400
    code = "UNSUPPORTED_FILE_TYPE"


class VideoTooLargeError(AIServiceError):
    status_code = 400
    code = "VIDEO_TOO_LARGE"


class StorageError(AIServiceError):
    status_code = 502
    code = "STORAGE_ERROR"


class LLMParseError(AIServiceError):
    status_code = 502
    code = "LLM_PARSE_ERROR"


class LLMCallError(AIServiceError):
    status_code = 503
    code = "LLM_CALL_ERROR"


async def ai_service_exception_handler(request: Request, exc: AIServiceError) -> JSONResponse:
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.code, "message": exc.message},
    )


async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    return JSONResponse(
        status_code=500,
        content={"code": "INTERNAL_ERROR", "message": str(exc)},
    )

19  app/core/json_utils.py  Normal file
@@ -0,0 +1,19 @@
import json
import re
from typing import Any

from app.core.exceptions import LLMParseError


def extract_json(text: str) -> Any:
    """Parse JSON from LLM response, stripping Markdown code fences if present."""
    text = text.strip()

    # Strip ```json ... ``` or ``` ... ``` fences
    fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
    if fence_match:
        text = fence_match.group(1).strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise LLMParseError(f"大模型返回非合法 JSON: {e}") from e

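A quick behaviour check of the fence-stripping logic (a standalone copy of `extract_json`, raising the stdlib decode error instead of the service's `LLMParseError`):

```python
import json
import re

def extract_json(text: str):
    # Strip ```json ... ``` or ``` ... ``` fences before parsing
    text = text.strip()
    fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
    if fence_match:
        text = fence_match.group(1).strip()
    return json.loads(text)

print(extract_json('```json\n{"subject": "变压器", "object": "110kV"}\n```'))
```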
62  app/core/logging.py  Normal file
@@ -0,0 +1,62 @@
import json
import logging
import time
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response


def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(_JsonFormatter())
        logger.addHandler(handler)
        logger.propagate = False
    return logger


class _JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        # Merge any extra fields passed via `extra=`
        for key, value in record.__dict__.items():
            if key not in (
                "name", "msg", "args", "levelname", "levelno", "pathname",
                "filename", "module", "exc_info", "exc_text", "stack_info",
                "lineno", "funcName", "created", "msecs", "relativeCreated",
                "thread", "threadName", "processName", "process", "message",
                "taskName",
            ):
                payload[key] = value
        return json.dumps(payload, ensure_ascii=False)


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, logger: logging.Logger | None = None) -> None:
        super().__init__(app)
        self._logger = logger or get_logger("request")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        start = time.perf_counter()
        response = await call_next(request)
        duration_ms = round((time.perf_counter() - start) * 1000, 1)
        self._logger.info(
            "request",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status": response.status_code,
                "duration_ms": duration_ms,
            },
        )
        return response

51
app/main.py
Normal file
@@ -0,0 +1,51 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.exceptions import (
    AIServiceError,
    ai_service_exception_handler,
    unhandled_exception_handler,
)
from app.core.logging import RequestLoggingMiddleware, get_logger

logger = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Note: `extra` must not contain a "message" key (logging raises KeyError)
    logger.info("AI service starting")
    yield
    logger.info("AI service stopping")


app = FastAPI(
    title="Label AI Service",
    description="知识图谱标注平台 AI 计算服务",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(RequestLoggingMiddleware)
app.add_exception_handler(AIServiceError, ai_service_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)


@app.get("/health", tags=["Health"])
async def health():
    return {"status": "ok"}


# Routers registered after implementation (imported lazily to avoid circular deps)
from app.routers import text, image, video, qa, finetune  # noqa: E402

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
0
app/models/__init__.py
Normal file
18
app/models/finetune_models.py
Normal file
@@ -0,0 +1,18 @@
from pydantic import BaseModel


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: dict | None = None


class FinetuneStartResponse(BaseModel):
    job_id: str


class FinetuneStatusResponse(BaseModel):
    job_id: str
    status: str
    progress: int | None = None
    error_message: str | None = None
28
app/models/image_models.py
Normal file
@@ -0,0 +1,28 @@
from pydantic import BaseModel


class BBox(BaseModel):
    x: int
    y: int
    w: int
    h: int


class QuadrupleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str | None = None
    bbox: BBox
    cropped_image_path: str


class ImageExtractRequest(BaseModel):
    file_path: str
    task_id: int
    model: str | None = None
    prompt_template: str | None = None


class ImageExtractResponse(BaseModel):
    items: list[QuadrupleItem]
47
app/models/qa_models.py
Normal file
@@ -0,0 +1,47 @@
from pydantic import BaseModel


class TextQAItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str


class GenTextQARequest(BaseModel):
    items: list[TextQAItem]
    model: str | None = None
    prompt_template: str | None = None


class QAPair(BaseModel):
    question: str
    answer: str


class ImageQAItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str | None = None
    cropped_image_path: str


class GenImageQARequest(BaseModel):
    items: list[ImageQAItem]
    model: str | None = None
    prompt_template: str | None = None


class ImageQAPair(BaseModel):
    question: str
    answer: str
    image_path: str


class TextQAResponse(BaseModel):
    pairs: list[QAPair]


class ImageQAResponse(BaseModel):
    pairs: list[ImageQAPair]
25
app/models/text_models.py
Normal file
@@ -0,0 +1,25 @@
from pydantic import BaseModel


class SourceOffset(BaseModel):
    start: int
    end: int


class TripleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str
    source_offset: SourceOffset


class TextExtractRequest(BaseModel):
    file_path: str
    file_name: str
    model: str | None = None
    prompt_template: str | None = None


class TextExtractResponse(BaseModel):
    items: list[TripleItem]
38
app/models/video_models.py
Normal file
@@ -0,0 +1,38 @@
from pydantic import BaseModel


class ExtractFramesRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    mode: str = "interval"
    frame_interval: int = 30


class VideoToTextRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    start_sec: float
    end_sec: float
    model: str | None = None
    prompt_template: str | None = None


class FrameInfo(BaseModel):
    frame_index: int
    time_sec: float
    frame_path: str


class VideoJobCallback(BaseModel):
    job_id: int
    status: str
    frames: list[FrameInfo] | None = None
    output_path: str | None = None
    error_message: str | None = None


class VideoAcceptedResponse(BaseModel):
    message: str
    job_id: int
0
app/routers/__init__.py
Normal file
28
app/routers/finetune.py
Normal file
@@ -0,0 +1,28 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.core.dependencies import get_llm_client
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)
from app.services import finetune_service

router = APIRouter(tags=["Finetune"])


@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStartResponse:
    return await finetune_service.submit_finetune(req, llm)


@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_status(
    job_id: str,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStatusResponse:
    return await finetune_service.get_finetune_status(job_id, llm)
18
app/routers/image.py
Normal file
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageExtractResponse:
    return await image_service.extract_quads(req, llm, storage)
31
app/routers/qa.py
Normal file
@@ -0,0 +1,31 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.qa_models import (
    GenImageQARequest,
    GenTextQARequest,
    ImageQAResponse,
    TextQAResponse,
)
from app.services import qa_service

router = APIRouter(tags=["QA"])


@router.post("/qa/gen-text", response_model=TextQAResponse)
async def gen_text_qa(
    req: GenTextQARequest,
    llm: LLMClient = Depends(get_llm_client),
) -> TextQAResponse:
    return await qa_service.gen_text_qa(req, llm)


@router.post("/qa/gen-image", response_model=ImageQAResponse)
async def gen_image_qa(
    req: GenImageQARequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageQAResponse:
    return await qa_service.gen_image_qa(req, llm, storage)
18
app/routers/text.py
Normal file
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.text_models import TextExtractRequest, TextExtractResponse
from app.services import text_service

router = APIRouter(tags=["Text"])


@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> TextExtractResponse:
    return await text_service.extract_triples(req, llm, storage)
69
app/routers/video.py
Normal file
@@ -0,0 +1,69 @@
from fastapi import APIRouter, BackgroundTasks, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.core.exceptions import VideoTooLargeError
from app.models.video_models import (
    ExtractFramesRequest,
    VideoAcceptedResponse,
    VideoToTextRequest,
)
from app.services import video_service

router = APIRouter(tags=["Video"])


async def _check_video_size(storage: StorageClient, bucket: str, file_path: str, max_mb: int) -> None:
    size_bytes = await storage.get_object_size(bucket, file_path)
    if size_bytes > max_mb * 1024 * 1024:
        raise VideoTooLargeError(
            f"视频文件大小超出限制(最大 {max_mb}MB,当前 {size_bytes // 1024 // 1024}MB)"
        )


@router.post("/video/extract-frames", response_model=VideoAcceptedResponse, status_code=202)
async def extract_frames(
    req: ExtractFramesRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
) -> VideoAcceptedResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    max_mb = cfg["video"]["max_file_size_mb"]
    callback_url = cfg.get("backend", {}).get("callback_url", "")

    await _check_video_size(storage, bucket, req.file_path, max_mb)

    background_tasks.add_task(
        video_service.extract_frames_task,
        req,
        storage,
        callback_url,
    )
    return VideoAcceptedResponse(message="任务已接受,后台处理中", job_id=req.job_id)


@router.post("/video/to-text", response_model=VideoAcceptedResponse, status_code=202)
async def video_to_text(
    req: VideoToTextRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
    llm: LLMClient = Depends(get_llm_client),
) -> VideoAcceptedResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    max_mb = cfg["video"]["max_file_size_mb"]
    callback_url = cfg.get("backend", {}).get("callback_url", "")

    await _check_video_size(storage, bucket, req.file_path, max_mb)

    background_tasks.add_task(
        video_service.video_to_text_task,
        req,
        llm,
        storage,
        callback_url,
    )
    return VideoAcceptedResponse(message="任务已接受,后台处理中", job_id=req.job_id)
0
app/services/__init__.py
Normal file
35
app/services/finetune_service.py
Normal file
@@ -0,0 +1,35 @@
from app.clients.llm.base import LLMClient
from app.core.logging import get_logger
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)

logger = get_logger(__name__)

_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


async def submit_finetune(req: FinetuneStartRequest, llm: LLMClient) -> FinetuneStartResponse:
    """Submit a fine-tune job via the LLMClient interface and return the job ID."""
    job_id = await llm.submit_finetune(req.jsonl_url, req.base_model, req.hyperparams or {})
    logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
    return FinetuneStartResponse(job_id=job_id)


async def get_finetune_status(job_id: str, llm: LLMClient) -> FinetuneStatusResponse:
    """Retrieve fine-tune job status via the LLMClient interface."""
    raw = await llm.get_finetune_status(job_id)
    status = _STATUS_MAP.get(raw["status"], "RUNNING")
    logger.info("finetune_status", extra={"job_id": job_id, "status": status})
    return FinetuneStatusResponse(
        job_id=raw["job_id"],
        status=status,
        progress=raw["progress"],
        error_message=raw["error_message"],
    )
89
app/services/image_service.py
Normal file
@@ -0,0 +1,89 @@
import base64

import cv2
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.image_models import (
    BBox,
    ImageExtractRequest,
    ImageExtractResponse,
    QuadrupleItem,
)

logger = get_logger(__name__)

_DEFAULT_PROMPT = (
    "请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
    "subject(主体实体)、predicate(关系/属性)、object(客体实体)、"
    "qualifier(修饰信息,可为 null)、bbox({x, y, w, h} 像素坐标)。"
)


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)

    # Decode with OpenCV for cropping; encode as base64 for LLM
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    img_h, img_w = img.shape[:2]

    b64 = base64.b64encode(image_bytes).decode()
    image_data_url = f"data:image/jpeg;base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    items_raw = extract_json(raw)
    items: list[QuadrupleItem] = []

    for idx, item in enumerate(items_raw):
        b = item["bbox"]
        # Clamp bbox to image dimensions
        x = max(0, min(int(b["x"]), img_w - 1))
        y = max(0, min(int(b["y"]), img_h - 1))
        w = min(int(b["w"]), img_w - x)
        h = min(int(b["h"]), img_h - y)

        crop = img[y : y + h, x : x + w]
        _, crop_buf = cv2.imencode(".jpg", crop)
        crop_bytes = crop_buf.tobytes()

        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(bucket, crop_path, crop_bytes, "image/jpeg")

        items.append(
            QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier"),
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )

    return ImageExtractResponse(items=items)
118
app/services/qa_service.py
Normal file
@@ -0,0 +1,118 @@
import base64

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import LLMParseError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.qa_models import (
    GenImageQARequest,
    GenTextQARequest,
    ImageQAPair,
    ImageQAResponse,
    QAPair,
    TextQAResponse,
)

logger = get_logger(__name__)

_DEFAULT_TEXT_PROMPT = (
    "请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
    "三元组列表:\n{triples_text}"
)

_DEFAULT_IMAGE_PROMPT = (
    "请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
)


async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]

    # Format all triples + source snippets into a single batch prompt
    triple_lines: list[str] = []
    for item in req.items:
        triple_lines.append(
            f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
        )
    triples_text = "\n".join(triple_lines)

    prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
    if "{triples_text}" in prompt_template:
        prompt = prompt_template.format(triples_text=triples_text)
    else:
        prompt = prompt_template + "\n\n" + triples_text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    items_raw = extract_json(raw)
    logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})

    if not isinstance(items_raw, list):
        raise LLMParseError("大模型返回的问答对格式不正确")
    try:
        pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
    except (KeyError, TypeError) as e:
        raise LLMParseError("大模型返回的问答对格式不正确") from e
    return TextQAResponse(pairs=pairs)


async def gen_image_qa(
    req: GenImageQARequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageQAResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT

    pairs: list[ImageQAPair] = []

    for item in req.items:
        # Download cropped image bytes from storage
        image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)

        # Base64 encode inline for multimodal message
        b64 = base64.b64encode(image_bytes).decode()
        image_data_url = f"data:image/jpeg;base64,{b64}"

        # Build quad info text
        quad_text = f"{item.subject} — {item.predicate} — {item.object}"
        if item.qualifier:
            quad_text += f" ({item.qualifier})"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": f"{prompt}\n\n{quad_text}"},
                ],
            }
        ]

        raw = await llm.chat_vision(model, messages)

        items_raw = extract_json(raw)
        logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})

        if not isinstance(items_raw, list):
            raise LLMParseError("大模型返回的问答对格式不正确")
        try:
            for qa in items_raw:
                pairs.append(
                    ImageQAPair(
                        question=qa["question"],
                        answer=qa["answer"],
                        image_path=item.cropped_image_path,
                    )
                )
        except (KeyError, TypeError) as e:
            raise LLMParseError("大模型返回的问答对格式不正确") from e

    return ImageQAResponse(pairs=pairs)
109
app/services/text_service.py
Normal file
@@ -0,0 +1,109 @@
import io

import pdfplumber
import docx

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import StorageError, UnsupportedFileTypeError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.text_models import (
    SourceOffset,
    TextExtractRequest,
    TextExtractResponse,
    TripleItem,
)

logger = get_logger(__name__)

_SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".docx"}

_DEFAULT_PROMPT = (
    "请从以下文本中提取知识三元组,以 JSON 数组格式返回,每条包含字段:"
    "subject(主语)、predicate(谓语)、object(宾语)、"
    "source_snippet(原文证据片段)、source_offset({{start, end}} 字符偏移)。\n\n"
    "文本内容:\n{text}"
)


def _file_extension(file_name: str) -> str:
    idx = file_name.rfind(".")
    return file_name[idx:].lower() if idx != -1 else ""


def _parse_txt(data: bytes) -> str:
    return data.decode("utf-8", errors="replace")


def _parse_pdf(data: bytes) -> str:
    with pdfplumber.open(io.BytesIO(data)) as pdf:
        pages = [page.extract_text() or "" for page in pdf.pages]
    return "\n".join(pages)


def _parse_docx(data: bytes) -> str:
    doc = docx.Document(io.BytesIO(data))
    return "\n".join(p.text for p in doc.paragraphs)


async def extract_triples(
    req: TextExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> TextExtractResponse:
    ext = _file_extension(req.file_name)
    if ext not in _SUPPORTED_EXTENSIONS:
        raise UnsupportedFileTypeError(f"不支持的文件格式: {ext}")

    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_text"]

    try:
        data = await storage.download_bytes(bucket, req.file_path)
        logger.info("文件下载成功", extra={
            "file_name": req.file_name,
            "size_bytes": len(data)
        })
    except Exception as e:
        logger.error("文件下载失败", extra={
            "file_name": req.file_name,
            "file_path": req.file_path,
            "bucket": bucket,
            "error_type": type(e).__name__,
            "error_message": str(e)
        }, exc_info=True)
        raise StorageError(f"下载文件失败: {str(e)}") from e

    if ext == ".txt":
        text = _parse_txt(data)
    elif ext == ".pdf":
        text = _parse_pdf(data)
    else:
        text = _parse_docx(data)

    prompt_template = req.prompt_template or _DEFAULT_PROMPT
    prompt = prompt_template.format(text=text) if "{text}" in prompt_template else prompt_template + "\n\n" + text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    logger.info("text_extract", extra={"file": req.file_name, "model": model})

    items_raw = extract_json(raw)
    items = [
        TripleItem(
            subject=item["subject"],
            predicate=item["predicate"],
            object=item["object"],
            source_snippet=item["source_snippet"],
            source_offset=SourceOffset(
                start=item["source_offset"]["start"],
                end=item["source_offset"]["end"],
            ),
        )
        for item in items_raw
    ]
    return TextExtractResponse(items=items)
189
app/services/video_service.py
Normal file
189
app/services/video_service.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import cv2
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.clients.storage.base import StorageClient
|
||||
from app.core.config import get_config
|
||||
from app.core.logging import get_logger
|
||||
from app.models.video_models import ExtractFramesRequest, FrameInfo, VideoToTextRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def _post_callback(url: str, payload: dict) -> None:
|
||||
async with httpx.AsyncClient(timeout=10) as http:
|
||||
try:
|
||||
await http.post(url, json=payload)
|
||||
except Exception as exc:
|
||||
logger.error("callback_failed", extra={"url": url, "error": str(exc)})
|
||||
|
||||
|
||||
async def extract_frames_task(
|
||||
req: ExtractFramesRequest,
|
||||
storage: StorageClient,
|
||||
callback_url: str,
|
||||
) -> None:
|
||||
cfg = get_config()
|
||||
bucket = cfg["storage"]["buckets"]["source_data"]
|
||||
threshold = cfg["video"].get("keyframe_diff_threshold", 30.0)
|
||||
|
||||
tmp = None
|
||||
try:
|
||||
video_bytes = await storage.download_bytes(bucket, req.file_path)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
|
||||
f.write(video_bytes)
|
||||
tmp = f.name
|
||||
|
||||
cap = cv2.VideoCapture(tmp)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||
frames_info: list[FrameInfo] = []
|
||||
upload_index = 0
|
||||
prev_gray = None
|
||||
frame_idx = 0
|
||||
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
extract = False
|
||||
if req.mode == "interval":
|
||||
extract = (frame_idx % req.frame_interval == 0)
|
||||
else: # keyframe
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
|
||||
if prev_gray is None:
|
||||
extract = True
|
||||
else:
|
||||
diff = np.mean(np.abs(gray - prev_gray))
|
||||
extract = diff > threshold
|
||||
prev_gray = gray
|
||||
|
||||
if extract:
|
||||
time_sec = round(frame_idx / fps, 3)
|
||||
_, buf = cv2.imencode(".jpg", frame)
|
||||
frame_path = f"frames/{req.source_id}/{upload_index}.jpg"
|
||||
await storage.upload_bytes(bucket, frame_path, buf.tobytes(), "image/jpeg")
|
||||
frames_info.append(FrameInfo(
|
||||
frame_index=frame_idx,
|
||||
time_sec=time_sec,
|
||||
frame_path=frame_path,
|
||||
))
|
||||
upload_index += 1
|
||||
|
||||
frame_idx += 1
|
||||
|
||||
cap.release()
|
||||
|
||||
logger.info("extract_frames_done", extra={
|
||||
"job_id": req.job_id,
|
||||
"frames": len(frames_info),
|
||||
})
|
||||
await _post_callback(callback_url, {
|
||||
"job_id": req.job_id,
|
||||
"status": "SUCCESS",
|
||||
"frames": [f.model_dump() for f in frames_info],
|
||||
"output_path": None,
|
||||
"error_message": None,
|
||||
})
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("extract_frames_failed", extra={"job_id": req.job_id, "error": str(exc)})
|
||||
await _post_callback(callback_url, {
|
||||
"job_id": req.job_id,
|
||||
"status": "FAILED",
|
||||
"frames": None,
|
||||
"output_path": None,
|
||||
"error_message": str(exc),
|
||||
})
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
async def video_to_text_task(
|
||||
    req: VideoToTextRequest,
    llm: LLMClient,
    storage: StorageClient,
    callback_url: str,
) -> None:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    sample_count = cfg["video"].get("frame_sample_count", 8)
    model = req.model or cfg["models"]["default_vision"]

    tmp = None
    cap = None
    try:
        video_bytes = await storage.download_bytes(bucket, req.file_path)

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            f.write(video_bytes)
            tmp = f.name

        cap = cv2.VideoCapture(tmp)
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        start_frame = int(req.start_sec * fps)
        end_frame = int(req.end_sec * fps)
        total = max(end_frame - start_frame, 1)

        # Uniform sampling
        indices = [
            start_frame + int(i * total / sample_count)
            for i in range(sample_count)
        ]
        indices = list(dict.fromkeys(indices))  # deduplicate

        content: list[dict] = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue
            _, buf = cv2.imencode(".jpg", frame)
            b64 = base64.b64encode(buf.tobytes()).decode()
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})

        # Default prompt (zh): "Describe this video's content in detail in Chinese
        # and produce a structured text description."
        prompt = req.prompt_template or "请用中文详细描述这段视频的内容,生成结构化文字描述。"
        content.append({"type": "text", "text": prompt})

        messages = [{"role": "user", "content": content}]
        description = await llm.chat_vision(model, messages)

        # Upload description text
        timestamp = int(time.time())
        output_path = f"video-text/{req.source_id}/{timestamp}.txt"
        await storage.upload_bytes(
            bucket, output_path, description.encode("utf-8"), "text/plain"
        )

        logger.info("video_to_text_done", extra={"job_id": req.job_id, "output_path": output_path})
        await _post_callback(callback_url, {
            "job_id": req.job_id,
            "status": "SUCCESS",
            "frames": None,
            "output_path": output_path,
            "error_message": None,
        })

    except Exception as exc:
        logger.error("video_to_text_failed", extra={"job_id": req.job_id, "error": str(exc)})
        await _post_callback(callback_url, {
            "job_id": req.job_id,
            "status": "FAILED",
            "frames": None,
            "output_path": None,
            "error_message": str(exc),
        })
    finally:
        # Release the capture even when an exception interrupts the frame loop.
        if cap is not None:
            cap.release()
        if tmp and os.path.exists(tmp):
            os.unlink(tmp)
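The uniform-sampling step in the handler above can be sketched as a standalone helper (the `sample_indices` name is illustrative, not from the service code):

```python
def sample_indices(start_frame: int, end_frame: int, sample_count: int) -> list[int]:
    """Pick sample_count evenly spaced frame indices, deduplicated in order."""
    total = max(end_frame - start_frame, 1)
    indices = [start_frame + int(i * total / sample_count) for i in range(sample_count)]
    # Very short clips can map several samples onto the same frame index;
    # dict.fromkeys removes the duplicates while preserving order.
    return list(dict.fromkeys(indices))

# A 2-second clip at 25 fps (frames 0..50), 8 samples:
print(sample_indices(0, 50, 8))  # → [0, 6, 12, 18, 25, 31, 37, 43]
```

The dedupe matters: for a clip only 2 frames long, all 8 samples collapse to `[0, 1]`, which is why the handler iterates over `indices` rather than assuming `sample_count` frames.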
24
config.yaml
Normal file
@@ -0,0 +1,24 @@
server:
  port: 18000
  log_level: INFO

dashscope:
  api_key: ""  # override with DASHSCOPE_API_KEY in .env or environment
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  fine_tune_base_url: "https://dashscope.aliyuncs.com/api/v1"

storage:
  buckets:
    source_data: "label-source-data"
    finetune_export: "finetune-export"

backend: {}  # callback_url injected via BACKEND_CALLBACK_URL env var

video:
  frame_sample_count: 8          # uniform frames sampled for video-to-text
  max_file_size_mb: 500          # video size limit (override with MAX_VIDEO_SIZE_MB)
  keyframe_diff_threshold: 30.0  # grayscale mean-diff threshold for keyframe detection

models:
  default_text: "qwen-plus"
  default_vision: "qwen-vl-plus"
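The `keyframe_diff_threshold` comment describes keyframe detection by grayscale mean difference. A minimal sketch of that check, assuming the service compares consecutive frames this way (the `is_keyframe` helper is hypothetical; the actual keyframe code is not part of this diff):

```python
import numpy as np

def is_keyframe(prev_gray: np.ndarray, cur_gray: np.ndarray, threshold: float = 30.0) -> bool:
    """Flag a frame as a keyframe when the mean absolute grayscale difference
    to the previous frame exceeds the configured threshold."""
    # Widen to int16 before subtracting so uint8 values cannot wrap around.
    diff = np.abs(cur_gray.astype(np.int16) - prev_gray.astype(np.int16))
    return float(diff.mean()) > threshold

# A hard cut (black -> white) far exceeds the default threshold of 30.0:
black = np.zeros((4, 4), dtype=np.uint8)
white = np.full((4, 4), 255, dtype=np.uint8)
print(is_keyframe(black, white))  # → True
print(is_keyframe(black, black))  # → False
```

The int16 widening is the one subtle point: subtracting uint8 arrays directly would wrap modulo 256 and silently understate the difference.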
12
docker-compose.python.yml
Normal file
@@ -0,0 +1,12 @@
version: "3.9"

services:
  python-service:
    image: label-ai-service:latest
    build:
      context: .
      dockerfile: Dockerfile
    container_name: label-ai-service
    ports:
      - "18000:18000"
    restart: unless-stopped
37
docker-compose.yml
Normal file
@@ -0,0 +1,37 @@
version: "3.9"

services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    depends_on:
      rustfs:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s

  rustfs:
    image: rustfs/rustfs:latest
    ports:
      - "9000:9000"
    environment:
      RUSTFS_ACCESS_KEY: ${STORAGE_ACCESS_KEY}
      RUSTFS_SECRET_KEY: ${STORAGE_SECRET_KEY}
    volumes:
      - rustfs_data:/data
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
      interval: 10s
      timeout: 3s
      retries: 5
      start_period: 5s

volumes:
  rustfs_data:
61
requirements.txt
Normal file
@@ -0,0 +1,61 @@
# ============================================
# Core Framework
# ============================================
fastapi==0.135.3
uvicorn[standard]==0.44.0
pydantic==2.12.5

# ============================================
# AI & LLM
# ============================================
zhipuai==2.1.5.20250825

# ============================================
# Storage (S3 Compatible)
# ============================================
boto3==1.42.87

# ============================================
# Configuration
# ============================================
python-dotenv==1.2.2
PyYAML==6.0.2

# ============================================
# Document Processing
# ============================================
pdfplumber==0.11.9
python-docx==1.2.0
lxml==6.0.3

# ============================================
# Image & Video Processing
# ============================================
opencv-python-headless==4.13.0.92
pillow==12.2.0
numpy==2.4.4

# ============================================
# Authentication & Security
# ============================================
PyJWT==2.8.0
cryptography==46.0.7

# ============================================
# HTTP Client (for test clients)
# ============================================
httpx==0.28.1

# ============================================
# Testing (Development Only)
# ============================================
pytest==9.0.3
pytest-asyncio==1.3.0

# ============================================
# Type Checking & Async Support
# ============================================
typing_extensions==4.14.1
sniffio==1.3.1
27
start.sh
Normal file
@@ -0,0 +1,27 @@
#!/usr/bin/env bash
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"

COMPOSE_CMD="docker compose"
if ! docker compose version >/dev/null 2>&1; then
  if command -v docker-compose >/dev/null 2>&1; then
    COMPOSE_CMD="docker-compose"
  else
    echo "Error: docker compose and docker-compose are both unavailable." >&2
    exit 1
  fi
fi

echo "==> Pulling latest code..."
git pull

echo "==> Building image..."
docker build -t label-ai-service:latest -f Dockerfile .

echo "==> Starting service..."
$COMPOSE_CMD -f docker-compose.python.yml up -d

echo "==> Service started. Check logs with:"
echo "    $COMPOSE_CMD -f docker-compose.python.yml logs -f python-service"