Compare commits

21 commits: `54d7a8bf55` ... `7adc88bab7`

- 7adc88bab7
- 68567b98b3
- 1e327ea92f
- a0ec71d877
- 65842b5e8b
- ec279262a7
- 0880e1018c
- 603382d1fa
- 00f092e728
- 4211e587ee
- 0274bb470a
- 2876c179ac
- dd8da386f4
- e1eb5e47b1
- 4162d9f4e6
- 092f9dbfc5
- e0d080ceea
- 10887da4ab
- 3892c6e60f
- f9f84937db
- f5a43a4bbc
.env (new file, 10 lines)
@@ -0,0 +1,10 @@
# Required — fill in before running
ZHIPUAI_API_KEY=your-zhipuai-api-key-here
STORAGE_ACCESS_KEY=your-storage-access-key
STORAGE_SECRET_KEY=your-storage-secret-key
STORAGE_ENDPOINT=http://rustfs:9000

# Optional overrides
BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
LOG_LEVEL=INFO
# MAX_VIDEO_SIZE_MB=200
.gitignore (vendored, new file, 37 lines)
@@ -0,0 +1,37 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*$py.class
*.egg-info/
dist/
build/

# Environment
.env
*.env

# Testing
.pytest_cache/
.coverage
htmlcov/

# Temp files (video processing)
tmp/
*.tmp

# IDE
.vscode/
.idea/
.specify/
.claude/
docs/
specs/
tests/
CLAUDE.md
pytest.ini
# OS
.DS_Store
Thumbs.db
Dockerfile (new file, 18 lines)
@@ -0,0 +1,18 @@
FROM python:3.12-slim

WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    curl \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md (419 lines)
@@ -0,0 +1,419 @@
# label_ai_service

`label_ai_service` is the AI compute service of the knowledge-graph annotation platform, built on FastAPI and deployed independently to provide inference and preprocessing capabilities. It does not access the database directly: structured extraction goes through the ZhipuAI GLM model family, raw files and processing results are read and written through RustFS, and asynchronous video task results are reported to the upstream backend via HTTP callbacks.

The service currently covers six core capabilities:

- Text triple extraction, supporting `TXT`, `PDF`, and `DOCX`
- Image quadruple extraction, with automatic cropping of the `bbox` region
- Video frame extraction, in fixed-interval and approximate-keyframe modes
- Video-to-text, writing video segment descriptions to a text file
- QA-pair generation from text or image evidence
- Submitting fine-tune jobs to ZhipuAI and querying their status

## Use cases

This service is meant to run as the AI sidecar of `label-backend`, but it can also run standalone to validate the file-parsing, image-understanding, video-preprocessing, and QA-generation flows.

The typical call chain:

1. The Java backend uploads raw text, images, or video to RustFS.
2. The Java backend calls the `label_ai_service` REST API.
3. The AI service reads the file from RustFS and calls a GLM model for extraction or generation.
4. Results are returned as JSON, or written back to RustFS followed by a callback to the upstream backend.

## Feature overview
| Capability | Endpoint | Description |
|---|---|---|
| Health check | `GET /health` | Container liveness probe and integration self-check |
| Text triple extraction | `POST /api/v1/text/extract` | Extracts `subject / predicate / object / source_snippet / source_offset` from documents |
| Image quadruple extraction | `POST /api/v1/image/extract` | Extracts `subject / predicate / object / qualifier / bbox` from an image and returns cropped-image paths |
| Video frame extraction | `POST /api/v1/video/extract-frames` | Extracts frames asynchronously; results delivered via callback |
| Video to text | `POST /api/v1/video/to-text` | Asynchronously samples representative frames, generates a Chinese description, and uploads it to object storage |
| Text QA generation | `POST /api/v1/qa/gen-text` | Generates QA pairs from triples and source evidence |
| Image QA generation | `POST /api/v1/qa/gen-image` | Generates QA pairs from cropped images and quadruples |
| Fine-tune submission | `POST /api/v1/finetune/start` | Submits a fine-tune job to ZhipuAI |
| Fine-tune status | `GET /api/v1/finetune/status/{job_id}` | Queries fine-tune job status and progress |
## Tech stack

- Python 3.12
- FastAPI
- Pydantic v2
- ZhipuAI Python SDK
- boto3
- OpenCV
- pdfplumber
- python-docx
- httpx
- pytest / pytest-asyncio
## Architecture

### External dependencies

- ZhipuAI
  - Text and multimodal inference
  - Fine-tune job submission and status queries
- RustFS, or any S3-compatible object store
  - Reading raw files
  - Writing back cropped images, video frames, and video description text
- Upstream callback endpoint
  - Receives results when video tasks complete

### Processing boundaries

- The service does not handle file uploads and keeps no task-state store of its own.
- Text and image endpoints respond synchronously.
- Video endpoints respond with `202 Accepted`; actual results are delivered via callback.
- The service performs no authentication by default; access control is expected from an upstream gateway or backend.
## Project layout

```text
label_ai_service/
├── app/
│   ├── main.py
│   ├── clients/
│   │   ├── llm/
│   │   └── storage/
│   ├── core/
│   ├── models/
│   ├── routers/
│   └── services/
├── docs/
│   └── superpowers/
├── specs/
├── tests/
├── config.yaml
├── .env
├── Dockerfile
├── docker-compose.yml
├── requirements.txt
└── README.md
```

Directory responsibilities:

- `app/main.py`
  - FastAPI entry point; registers middleware, exception handlers, and all routers
- `app/clients`
  - Adapter layer for third-party dependencies, currently ZhipuAI and RustFS
- `app/services`
  - Core business logic: file parsing, prompt assembly, result transformation, and async task processing
- `app/routers`
  - HTTP endpoint layer
- `app/models`
  - Request and response models
- `app/core`
  - Shared modules: configuration, logging, middleware, exceptions
- `tests`
  - Tests for routers, services, config, and clients
## Configuration

Configuration is layered as `config.yaml + .env`:

- `config.yaml`
  - Stable, committable structured configuration
- `.env`
  - Secrets and per-environment values

Environment variables override same-named keys in `config.yaml`.

### config.yaml

The project's current defaults:
```yaml
server:
  port: 8000
  log_level: INFO

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}

video:
  frame_sample_count: 8
  max_file_size_mb: 200
  keyframe_diff_threshold: 30.0

models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
```
### .env

At minimum, configure these variables:

| Variable | Required | Description |
|---|---|---|
| `ZHIPUAI_API_KEY` | yes | ZhipuAI API key |
| `STORAGE_ACCESS_KEY` | yes | RustFS/S3 access key |
| `STORAGE_SECRET_KEY` | yes | RustFS/S3 secret key |
| `STORAGE_ENDPOINT` | yes | RustFS/S3 endpoint, e.g. `http://rustfs:9000` |
| `BACKEND_CALLBACK_URL` | no | Callback URL for async video tasks |
| `LOG_LEVEL` | no | Log level, default `INFO` |
| `MAX_VIDEO_SIZE_MB` | no | Overrides the video size limit |

Example `.env`:

```ini
ZHIPUAI_API_KEY=your-zhipuai-api-key-here
STORAGE_ACCESS_KEY=your-storage-access-key
STORAGE_SECRET_KEY=your-storage-secret-key
STORAGE_ENDPOINT=http://rustfs:9000
BACKEND_CALLBACK_URL=http://label-backend:8080/api/ai/callback
LOG_LEVEL=INFO
```
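The env-over-yaml precedence described above can be sketched as follows. This is a minimal standalone illustration of the override logic (the real implementation lives in `app/core/config.py`; the names here mirror it but the function is simplified):

```python
# Env var name -> nested key path in the YAML-derived config dict
ENV_OVERRIDES = {
    "LOG_LEVEL": ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
}

def apply_env_overrides(cfg: dict, environ: dict) -> dict:
    """Overlay environment variables onto the config dict, env winning."""
    for env_var, keys in ENV_OVERRIDES.items():
        value = environ.get(env_var)
        if value is None:
            continue
        node = cfg
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        # Numeric env vars are coerced to int, everything else stays a string
        try:
            value = int(value)
        except ValueError:
            pass
        node[keys[-1]] = value
    return cfg

cfg = {"server": {"log_level": "INFO"}, "video": {"max_file_size_mb": 200}}
apply_env_overrides(cfg, {"LOG_LEVEL": "DEBUG", "MAX_VIDEO_SIZE_MB": "500"})
# cfg["server"]["log_level"] is now "DEBUG"; max_file_size_mb is the int 500
```

Note that `MAX_VIDEO_SIZE_MB` arrives as a string from the environment and is coerced to an integer, while non-numeric values like `DEBUG` are left as strings.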
## Running locally

### Option 1: run directly

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```

On Windows PowerShell:

```powershell
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```

Once started:

- Swagger UI: `http://localhost:8000/docs`
- Health check: `http://localhost:8000/health`

### Option 2: Docker Compose

The bundled Compose file starts:

- `ai-service`
- `rustfs`

Start it with:

```bash
docker compose up --build
```

If you are integration-testing async video tasks, make sure `BACKEND_CALLBACK_URL` points to a reachable backend. Otherwise the task itself still runs, but the callback fails and an error is logged.
## API usage examples

### 1. Health check

```bash
curl http://localhost:8000/health
```

Response:

```json
{"status":"ok"}
```

### 2. Text triple extraction

```bash
curl -X POST http://localhost:8000/api/v1/text/extract \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "text/202404/123.txt",
    "file_name": "设备规范.txt"
  }'
```

### 3. Image quadruple extraction

```bash
curl -X POST http://localhost:8000/api/v1/image/extract \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "image/202404/456.jpg",
    "task_id": 789
  }'
```

### 4. Video frame extraction

```bash
curl -X POST http://localhost:8000/api/v1/video/extract-frames \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "video/202404/001.mp4",
    "source_id": 10,
    "job_id": 42,
    "mode": "interval",
    "frame_interval": 30
  }'
```

### 5. Video to text

```bash
curl -X POST http://localhost:8000/api/v1/video/to-text \
  -H "Content-Type: application/json" \
  -d '{
    "file_path": "video/202404/001.mp4",
    "source_id": 10,
    "job_id": 43,
    "start_sec": 0,
    "end_sec": 60
  }'
```

### 6. Text QA generation

```bash
curl -X POST http://localhost:8000/api/v1/qa/gen-text \
  -H "Content-Type: application/json" \
  -d '{
    "items": [
      {
        "subject": "变压器",
        "predicate": "额定电压",
        "object": "110kV",
        "source_snippet": "该变压器额定电压为110kV"
      }
    ]
  }'
```

### 7. Image QA generation

```bash
curl -X POST http://localhost:8000/api/v1/qa/gen-image \
  -H "Content-Type: application/json" \
  -d '{
    "items": [
      {
        "subject": "电缆接头",
        "predicate": "位于",
        "object": "配电箱左侧",
        "cropped_image_path": "crops/1/0.jpg"
      }
    ]
  }'
```

### 8. Fine-tune submission

```bash
curl -X POST http://localhost:8000/api/v1/finetune/start \
  -H "Content-Type: application/json" \
  -d '{
    "jsonl_url": "https://example.com/train.jsonl",
    "base_model": "glm-4-flash",
    "hyperparams": {
      "epochs": 3,
      "learning_rate": 0.0001
    }
  }'
```
## Output conventions

The service actively writes these derived artifacts:

| Type | Path pattern | Description |
|---|---|---|
| Cropped images | `crops/{task_id}/{index}.jpg` | Local evidence crops from image extraction |
| Video frames | `frames/{source_id}/{index}.jpg` | Frame-extraction output |
| Video descriptions | `video-text/{source_id}/{timestamp}.txt` | Video-to-text output |

Notes:

- These objects are written to `storage.buckets.source_data` by default
- Upload paths for raw files are decided by the upstream system
- The service never generates raw-file paths on the upstream's behalf; it only consumes the `file_path` passed in the request
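The path patterns above can be pinned down with small helpers. This is an illustrative sketch, not code from the service itself:

```python
def crop_path(task_id: int, index: int) -> str:
    """Object key for a cropped evidence image."""
    return f"crops/{task_id}/{index}.jpg"

def frame_path(source_id: int, index: int) -> str:
    """Object key for an extracted video frame."""
    return f"frames/{source_id}/{index}.jpg"

def video_text_path(source_id: int, timestamp: str) -> str:
    """Object key for a video-to-text result."""
    return f"video-text/{source_id}/{timestamp}.txt"

# e.g. crop_path(789, 0) -> "crops/789/0.jpg"
```

Keeping these patterns in one place makes it easy for an upstream consumer to locate derived objects without an extra lookup.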
## Logging and error handling

### Logging

Logs are emitted as JSON, suitable for feeding straight into a container log platform. Request logs include:

- `method`
- `path`
- `status`
- `duration_ms`

LLM calls and background tasks also emit structured fields, which helps diagnose endpoint timeouts, callback failures, and model parse errors.

### Error codes

Unified error response format:

```json
{"code":"ERROR_CODE","message":"human-readable description"}
```

Current error codes:

| Code | HTTP status | Meaning |
|---|---|---|
| `UNSUPPORTED_FILE_TYPE` | 400 | Unsupported file format for text extraction |
| `VIDEO_TOO_LARGE` | 400 | Video exceeds the size limit |
| `STORAGE_ERROR` | 502 | Object storage access failed |
| `LLM_PARSE_ERROR` | 502 | Model output could not be parsed as the expected JSON |
| `LLM_CALL_ERROR` | 503 | Model call or fine-tune API call failed |
| `INTERNAL_ERROR` | 500 | Uncaught exception |
## Testing

With dependencies installed, run:

```bash
python -m pytest
```

The tests cover:

- The health-check endpoint
- Text, image, video, QA, and fine-tune routers
- Basic success and failure paths of each service
- Config loading and client adapters
## Design docs

More detailed design material lives in the repo and complements this README:

- `docs/superpowers/specs/2026-04-10-ai-service-design.md`
- `docs/superpowers/plans/2026-04-10-ai-service-impl.md`
- `specs/001-ai-service-requirements/`

If you are new to this service, a suggested reading order:

1. This README, to understand responsibilities, endpoints, and how to run it
2. The design docs, for architecture and design decisions
3. `app/services` and `tests`, for the implementation details
## Known constraints

- Text extraction currently supports only `txt`, `pdf`, and `docx`
- Video endpoints rely on the object store exposing readable file sizes
- Video task state is not persisted in this service; the upstream system manages it
- Image QA inlines images as base64 rather than relying on externally reachable presigned URLs
- If the callback URL in `.env` is unreachable, video tasks log the error but do not retry

## Development tips

- When adding an endpoint, also add the Pydantic models, router tests, and README/API docs
- To swap model vendors, extend `app/clients/llm` first
- To swap the storage implementation, extend `app/clients/storage` first
- Any change to output path rules should be reflected in the README and design docs
app/__init__.py (new file, 0 lines)
app/clients/__init__.py (new file, 0 lines)
app/clients/llm/__init__.py (new file, 0 lines)
app/clients/llm/base.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class LLMClient(ABC):
    @abstractmethod
    async def chat(self, model: str, messages: list[dict]) -> str:
        """Send a text chat request and return the response content string."""

    @abstractmethod
    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Send a multimodal (vision) chat request and return the response content string."""

    @abstractmethod
    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        """Submit a fine-tune job and return the job_id."""

    @abstractmethod
    async def get_finetune_status(self, job_id: str) -> dict:
        """Return a dict with keys: job_id, status (raw SDK string), progress (int|None), error_message (str|None)."""
app/clients/llm/zhipuai_client.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import asyncio

from zhipuai import ZhipuAI

from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger

logger = get_logger(__name__)


class ZhipuAIClient(LLMClient):
    def __init__(self, api_key: str) -> None:
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, model: str, messages: list[dict]) -> str:
        return await self._call(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        return await self._call(model, messages)

    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None,
                lambda: self._client.fine_tuning.jobs.create(
                    training_file=jsonl_url,
                    model=base_model,
                    hyperparameters=hyperparams,
                ),
            )
            return resp.id
        except Exception as exc:
            raise LLMCallError(f"微调任务提交失败: {exc}") from exc

    async def get_finetune_status(self, job_id: str) -> dict:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None,
                lambda: self._client.fine_tuning.jobs.retrieve(job_id),
            )
            return {
                "job_id": resp.id,
                "status": resp.status,
                "progress": int(resp.progress) if getattr(resp, "progress", None) is not None else None,
                "error_message": getattr(resp, "error_message", None),
            }
        except Exception as exc:
            raise LLMCallError(f"查询微调任务失败: {exc}") from exc

    async def _call(self, model: str, messages: list[dict]) -> str:
        loop = asyncio.get_running_loop()
        try:
            response = await loop.run_in_executor(
                None,
                lambda: self._client.chat.completions.create(
                    model=model,
                    messages=messages,
                ),
            )
            content = response.choices[0].message.content
            logger.info("llm_call", extra={"model": model, "response_len": len(content)})
            return content
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"大模型调用失败: {exc}") from exc
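The client above offloads blocking ZhipuAI SDK calls to a thread via `run_in_executor` so they do not stall the event loop. A minimal standalone sketch of that pattern, with a stand-in blocking function instead of the SDK:

```python
import asyncio
import time

def blocking_call(x: int) -> int:
    """Stand-in for a synchronous SDK call (e.g. chat.completions.create)."""
    time.sleep(0.05)  # simulate network I/O
    return x * 2

async def call_async(x: int) -> int:
    # Run the blocking function in the default thread-pool executor,
    # keeping the event loop free for other coroutines.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: blocking_call(x))

result = asyncio.run(call_async(21))
# result == 42
```

The same wrapping appears in the storage client below, since boto3 is also synchronous.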
app/clients/storage/__init__.py (new file, 0 lines)
app/clients/storage/base.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod


class StorageClient(ABC):
    @abstractmethod
    async def download_bytes(self, bucket: str, path: str) -> bytes:
        """Download an object and return its raw bytes."""

    @abstractmethod
    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        """Upload raw bytes to the given bucket/path."""

    @abstractmethod
    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        """Return a presigned GET URL valid for `expires` seconds."""

    @abstractmethod
    async def get_object_size(self, bucket: str, path: str) -> int:
        """Return the object size in bytes without downloading it."""
app/clients/storage/rustfs_client.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import asyncio
import io

import boto3
from botocore.exceptions import ClientError

from app.clients.storage.base import StorageClient
from app.core.exceptions import StorageError
from app.core.logging import get_logger

logger = get_logger(__name__)


class RustFSClient(StorageClient):
    def __init__(self, endpoint: str, access_key: str, secret_key: str) -> None:
        self._s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
        )

    async def download_bytes(self, bucket: str, path: str) -> bytes:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
            )
            return resp["Body"].read()
        except ClientError as exc:
            raise StorageError(f"存储下载失败 [{bucket}/{path}]: {exc}") from exc

    async def upload_bytes(
        self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
    ) -> None:
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(
                None,
                lambda: self._s3.put_object(
                    Bucket=bucket, Key=path, Body=io.BytesIO(data), ContentType=content_type
                ),
            )
        except ClientError as exc:
            raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc

    async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
        loop = asyncio.get_running_loop()
        try:
            url = await loop.run_in_executor(
                None,
                lambda: self._s3.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": bucket, "Key": path},
                    ExpiresIn=expires,
                ),
            )
            return url
        except ClientError as exc:
            raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc

    async def get_object_size(self, bucket: str, path: str) -> int:
        loop = asyncio.get_running_loop()
        try:
            resp = await loop.run_in_executor(
                None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
            )
            return resp["ContentLength"]
        except ClientError as exc:
            raise StorageError(f"获取文件大小失败 [{bucket}/{path}]: {exc}") from exc
app/core/__init__.py (new file, 0 lines)
app/core/config.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os
from functools import lru_cache
from pathlib import Path
from typing import Any

import yaml
from dotenv import load_dotenv

load_dotenv()

# Maps environment variable names to nested YAML key paths
_ENV_OVERRIDES: dict[str, list[str]] = {
    "ZHIPUAI_API_KEY": ["zhipuai", "api_key"],
    "STORAGE_ACCESS_KEY": ["storage", "access_key"],
    "STORAGE_SECRET_KEY": ["storage", "secret_key"],
    "STORAGE_ENDPOINT": ["storage", "endpoint"],
    "BACKEND_CALLBACK_URL": ["backend", "callback_url"],
    "LOG_LEVEL": ["server", "log_level"],
    "MAX_VIDEO_SIZE_MB": ["video", "max_file_size_mb"],
}

_CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"


def _set_nested(cfg: dict, keys: list[str], value: Any) -> None:
    for key in keys[:-1]:
        cfg = cfg.setdefault(key, {})
    # Coerce numeric env vars
    try:
        value = int(value)
    except (TypeError, ValueError):
        pass
    cfg[keys[-1]] = value


@lru_cache(maxsize=1)
def get_config() -> dict:
    with open(_CONFIG_PATH, "r", encoding="utf-8") as f:
        cfg: dict = yaml.safe_load(f)

    for env_var, key_path in _ENV_OVERRIDES.items():
        value = os.environ.get(env_var)
        if value is not None:
            _set_nested(cfg, key_path, value)

    return cfg
app/core/dependencies.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from functools import lru_cache

from app.clients.llm.base import LLMClient
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.clients.storage.base import StorageClient
from app.clients.storage.rustfs_client import RustFSClient
from app.core.config import get_config


@lru_cache(maxsize=1)
def get_llm_client() -> LLMClient:
    cfg = get_config()
    return ZhipuAIClient(api_key=cfg["zhipuai"]["api_key"])


@lru_cache(maxsize=1)
def get_storage_client() -> StorageClient:
    cfg = get_config()
    return RustFSClient(
        endpoint=cfg["storage"]["endpoint"],
        access_key=cfg["storage"]["access_key"],
        secret_key=cfg["storage"]["secret_key"],
    )
app/core/exceptions.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from fastapi import Request
from fastapi.responses import JSONResponse


class AIServiceError(Exception):
    status_code: int = 500
    code: str = "INTERNAL_ERROR"

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)


class UnsupportedFileTypeError(AIServiceError):
    status_code = 400
    code = "UNSUPPORTED_FILE_TYPE"


class VideoTooLargeError(AIServiceError):
    status_code = 400
    code = "VIDEO_TOO_LARGE"


class StorageError(AIServiceError):
    status_code = 502
    code = "STORAGE_ERROR"


class LLMParseError(AIServiceError):
    status_code = 502
    code = "LLM_PARSE_ERROR"


class LLMCallError(AIServiceError):
    status_code = 503
    code = "LLM_CALL_ERROR"


async def ai_service_exception_handler(request: Request, exc: AIServiceError) -> JSONResponse:
    return JSONResponse(
        status_code=exc.status_code,
        content={"code": exc.code, "message": exc.message},
    )


async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    return JSONResponse(
        status_code=500,
        content={"code": exc.code if isinstance(exc, AIServiceError) else "INTERNAL_ERROR", "message": str(exc)},
    ) if False else JSONResponse(
        status_code=500,
        content={"code": "INTERNAL_ERROR", "message": str(exc)},
    )
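The exception hierarchy above carries its HTTP status and error code as class attributes, so the handler never needs a lookup table. A standalone sketch of that pattern, with FastAPI's `JSONResponse` replaced by a plain tuple for illustration:

```python
class ServiceError(Exception):
    """Base error: subclasses override status_code and code."""
    status_code: int = 500
    code: str = "INTERNAL_ERROR"

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)

class TooLargeError(ServiceError):
    status_code = 400
    code = "VIDEO_TOO_LARGE"

def to_response(exc: ServiceError) -> tuple[int, dict]:
    """Mirror of the JSON error envelope: (status, {"code", "message"})."""
    return exc.status_code, {"code": exc.code, "message": exc.message}

status, body = to_response(TooLargeError("video exceeds 200 MB"))
# status == 400, body["code"] == "VIDEO_TOO_LARGE"
```

Adding a new error type is then a three-line subclass; no handler changes are needed.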
app/core/json_utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import json
import re

from app.core.exceptions import LLMParseError


def extract_json(text: str) -> any:
    """Parse JSON from LLM response, stripping Markdown code fences if present."""
    text = text.strip()

    # Strip ```json ... ``` or ``` ... ``` fences
    fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
    if fence_match:
        text = fence_match.group(1).strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        raise LLMParseError(f"大模型返回非合法 JSON: {e}") from e
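A quick standalone check of the fence-stripping behavior implemented above (reimplemented here so the snippet runs without the app package; error handling is omitted):

```python
import json
import re

def extract_json(text: str):
    """Parse JSON from an LLM reply, stripping ```json fences if present."""
    text = text.strip()
    fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)\s*```", text)
    if fence_match:
        text = fence_match.group(1).strip()
    return json.loads(text)

reply = '```json\n{"subject": "变压器", "predicate": "额定电压", "object": "110kV"}\n```'
data = extract_json(reply)
# data["object"] == "110kV"; bare JSON without fences parses the same way
```

This matters because GLM models frequently wrap structured output in Markdown fences even when asked for raw JSON.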
app/core/logging.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import json
import logging
import time
from typing import Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response


def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(_JsonFormatter())
        logger.addHandler(handler)
        logger.propagate = False
    return logger


class _JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record, datefmt="%Y-%m-%dT%H:%M:%S"),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        # Merge any extra fields passed via `extra=`
        for key, value in record.__dict__.items():
            if key not in (
                "name", "msg", "args", "levelname", "levelno", "pathname",
                "filename", "module", "exc_info", "exc_text", "stack_info",
                "lineno", "funcName", "created", "msecs", "relativeCreated",
                "thread", "threadName", "processName", "process", "message",
                "taskName",
            ):
                payload[key] = value
        return json.dumps(payload, ensure_ascii=False)


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, logger: logging.Logger | None = None) -> None:
        super().__init__(app)
        self._logger = logger or get_logger("request")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        start = time.perf_counter()
        response = await call_next(request)
        duration_ms = round((time.perf_counter() - start) * 1000, 1)
        self._logger.info(
            "request",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status": response.status_code,
                "duration_ms": duration_ms,
            },
        )
        return response
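The formatter above folds `extra=` fields into the JSON payload by treating every non-standard `LogRecord` attribute as user data. A minimal standalone demonstration of that trick, trimmed to the essentials (the real formatter also handles timestamps and `exc_info`):

```python
import json
import logging

# LogRecord attributes that are bookkeeping, not user-supplied extras
_STANDARD = set(logging.LogRecord("", 0, "", 0, "", (), None).__dict__) | {"message", "taskName"}

class JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {"level": record.levelname, "message": record.getMessage()}
        # Anything not in the standard attribute set came in via `extra=`
        for key, value in record.__dict__.items():
            if key not in _STANDARD:
                payload[key] = value
        return json.dumps(payload, ensure_ascii=False)

record = logging.LogRecord("request", logging.INFO, "app.py", 1, "request", (), None)
record.__dict__.update({"method": "GET", "path": "/health", "status": 200})
line = JsonFormatter().format(record)
# json.loads(line) recovers level, message, method, path, and status
```

Deriving the exclusion set from a sentinel `LogRecord` avoids hard-coding the attribute list, which is what the committed formatter does by hand.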
app/main.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.core.exceptions import (
    AIServiceError,
    ai_service_exception_handler,
    unhandled_exception_handler,
)
from app.core.logging import RequestLoggingMiddleware, get_logger

logger = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("startup", extra={"message": "AI service starting"})
    yield
    logger.info("shutdown", extra={"message": "AI service stopping"})


app = FastAPI(
    title="Label AI Service",
    description="知识图谱标注平台 AI 计算服务",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(RequestLoggingMiddleware)
app.add_exception_handler(AIServiceError, ai_service_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)


@app.get("/health", tags=["Health"])
async def health():
    return {"status": "ok"}


# Routers registered after implementation (imported lazily to avoid circular deps)
from app.routers import text, image, video, qa, finetune  # noqa: E402

app.include_router(text.router, prefix="/api/v1")
app.include_router(image.router, prefix="/api/v1")
app.include_router(video.router, prefix="/api/v1")
app.include_router(qa.router, prefix="/api/v1")
app.include_router(finetune.router, prefix="/api/v1")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
0 app/models/__init__.py Normal file

18 app/models/finetune_models.py Normal file
@@ -0,0 +1,18 @@
from pydantic import BaseModel


class FinetuneStartRequest(BaseModel):
    jsonl_url: str
    base_model: str
    hyperparams: dict | None = None


class FinetuneStartResponse(BaseModel):
    job_id: str


class FinetuneStatusResponse(BaseModel):
    job_id: str
    status: str
    progress: int | None = None
    error_message: str | None = None
28 app/models/image_models.py Normal file
@@ -0,0 +1,28 @@
from pydantic import BaseModel


class BBox(BaseModel):
    x: int
    y: int
    w: int
    h: int


class QuadrupleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str | None = None
    bbox: BBox
    cropped_image_path: str


class ImageExtractRequest(BaseModel):
    file_path: str
    task_id: int
    model: str | None = None
    prompt_template: str | None = None


class ImageExtractResponse(BaseModel):
    items: list[QuadrupleItem]
47 app/models/qa_models.py Normal file
@@ -0,0 +1,47 @@
from pydantic import BaseModel


class TextQAItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str


class GenTextQARequest(BaseModel):
    items: list[TextQAItem]
    model: str | None = None
    prompt_template: str | None = None


class QAPair(BaseModel):
    question: str
    answer: str


class ImageQAItem(BaseModel):
    subject: str
    predicate: str
    object: str
    qualifier: str | None = None
    cropped_image_path: str


class GenImageQARequest(BaseModel):
    items: list[ImageQAItem]
    model: str | None = None
    prompt_template: str | None = None


class ImageQAPair(BaseModel):
    question: str
    answer: str
    image_path: str


class TextQAResponse(BaseModel):
    pairs: list[QAPair]


class ImageQAResponse(BaseModel):
    pairs: list[ImageQAPair]
25 app/models/text_models.py Normal file
@@ -0,0 +1,25 @@
from pydantic import BaseModel


class SourceOffset(BaseModel):
    start: int
    end: int


class TripleItem(BaseModel):
    subject: str
    predicate: str
    object: str
    source_snippet: str
    source_offset: SourceOffset


class TextExtractRequest(BaseModel):
    file_path: str
    file_name: str
    model: str | None = None
    prompt_template: str | None = None


class TextExtractResponse(BaseModel):
    items: list[TripleItem]
38 app/models/video_models.py Normal file
@@ -0,0 +1,38 @@
from pydantic import BaseModel


class ExtractFramesRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    mode: str = "interval"
    frame_interval: int = 30


class VideoToTextRequest(BaseModel):
    file_path: str
    source_id: int
    job_id: int
    start_sec: float
    end_sec: float
    model: str | None = None
    prompt_template: str | None = None


class FrameInfo(BaseModel):
    frame_index: int
    time_sec: float
    frame_path: str


class VideoJobCallback(BaseModel):
    job_id: int
    status: str
    frames: list[FrameInfo] | None = None
    output_path: str | None = None
    error_message: str | None = None


class VideoAcceptedResponse(BaseModel):
    message: str
    job_id: int
0 app/routers/__init__.py Normal file

28 app/routers/finetune.py Normal file
@@ -0,0 +1,28 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.core.dependencies import get_llm_client
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)
from app.services import finetune_service

router = APIRouter(tags=["Finetune"])


@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStartResponse:
    return await finetune_service.submit_finetune(req, llm)


@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_status(
    job_id: str,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStatusResponse:
    return await finetune_service.get_finetune_status(job_id, llm)
18 app/routers/image.py Normal file
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.image_models import ImageExtractRequest, ImageExtractResponse
from app.services import image_service

router = APIRouter(tags=["Image"])


@router.post("/image/extract", response_model=ImageExtractResponse)
async def extract_image(
    req: ImageExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageExtractResponse:
    return await image_service.extract_quads(req, llm, storage)
31 app/routers/qa.py Normal file
@@ -0,0 +1,31 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.qa_models import (
    GenImageQARequest,
    GenTextQARequest,
    ImageQAResponse,
    TextQAResponse,
)
from app.services import qa_service

router = APIRouter(tags=["QA"])


@router.post("/qa/gen-text", response_model=TextQAResponse)
async def gen_text_qa(
    req: GenTextQARequest,
    llm: LLMClient = Depends(get_llm_client),
) -> TextQAResponse:
    return await qa_service.gen_text_qa(req, llm)


@router.post("/qa/gen-image", response_model=ImageQAResponse)
async def gen_image_qa(
    req: GenImageQARequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> ImageQAResponse:
    return await qa_service.gen_image_qa(req, llm, storage)
18 app/routers/text.py Normal file
@@ -0,0 +1,18 @@
from fastapi import APIRouter, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.dependencies import get_llm_client, get_storage_client
from app.models.text_models import TextExtractRequest, TextExtractResponse
from app.services import text_service

router = APIRouter(tags=["Text"])


@router.post("/text/extract", response_model=TextExtractResponse)
async def extract_text(
    req: TextExtractRequest,
    llm: LLMClient = Depends(get_llm_client),
    storage: StorageClient = Depends(get_storage_client),
) -> TextExtractResponse:
    return await text_service.extract_triples(req, llm, storage)
69 app/routers/video.py Normal file
@@ -0,0 +1,69 @@
from fastapi import APIRouter, BackgroundTasks, Depends

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.dependencies import get_llm_client, get_storage_client
from app.core.exceptions import VideoTooLargeError
from app.models.video_models import (
    ExtractFramesRequest,
    VideoAcceptedResponse,
    VideoToTextRequest,
)
from app.services import video_service

router = APIRouter(tags=["Video"])


async def _check_video_size(storage: StorageClient, bucket: str, file_path: str, max_mb: int) -> None:
    size_bytes = await storage.get_object_size(bucket, file_path)
    if size_bytes > max_mb * 1024 * 1024:
        raise VideoTooLargeError(
            f"视频文件大小超出限制(最大 {max_mb}MB,当前 {size_bytes // 1024 // 1024}MB)"
        )


@router.post("/video/extract-frames", response_model=VideoAcceptedResponse, status_code=202)
async def extract_frames(
    req: ExtractFramesRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
) -> VideoAcceptedResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    max_mb = cfg["video"]["max_file_size_mb"]
    callback_url = cfg.get("backend", {}).get("callback_url", "")

    await _check_video_size(storage, bucket, req.file_path, max_mb)

    background_tasks.add_task(
        video_service.extract_frames_task,
        req,
        storage,
        callback_url,
    )
    return VideoAcceptedResponse(message="任务已接受,后台处理中", job_id=req.job_id)


@router.post("/video/to-text", response_model=VideoAcceptedResponse, status_code=202)
async def video_to_text(
    req: VideoToTextRequest,
    background_tasks: BackgroundTasks,
    storage: StorageClient = Depends(get_storage_client),
    llm: LLMClient = Depends(get_llm_client),
) -> VideoAcceptedResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    max_mb = cfg["video"]["max_file_size_mb"]
    callback_url = cfg.get("backend", {}).get("callback_url", "")

    await _check_video_size(storage, bucket, req.file_path, max_mb)

    background_tasks.add_task(
        video_service.video_to_text_task,
        req,
        llm,
        storage,
        callback_url,
    )
    return VideoAcceptedResponse(message="任务已接受,后台处理中", job_id=req.job_id)
0 app/services/__init__.py Normal file

35 app/services/finetune_service.py Normal file
@@ -0,0 +1,35 @@
from app.clients.llm.base import LLMClient
from app.core.logging import get_logger
from app.models.finetune_models import (
    FinetuneStartRequest,
    FinetuneStartResponse,
    FinetuneStatusResponse,
)

logger = get_logger(__name__)

_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}


async def submit_finetune(req: FinetuneStartRequest, llm: LLMClient) -> FinetuneStartResponse:
    """Submit a fine-tune job via the LLMClient interface and return the job ID."""
    job_id = await llm.submit_finetune(req.jsonl_url, req.base_model, req.hyperparams or {})
    logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
    return FinetuneStartResponse(job_id=job_id)


async def get_finetune_status(job_id: str, llm: LLMClient) -> FinetuneStatusResponse:
    """Retrieve fine-tune job status via the LLMClient interface."""
    raw = await llm.get_finetune_status(job_id)
    status = _STATUS_MAP.get(raw["status"], "RUNNING")
    logger.info("finetune_status", extra={"job_id": job_id, "status": status})
    return FinetuneStatusResponse(
        job_id=raw["job_id"],
        status=status,
        progress=raw["progress"],
        error_message=raw["error_message"],
    )
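The status normalization above defaults any provider state it does not recognize (a queued or pending job, say) to RUNNING via `dict.get`. A standalone sketch of that fallback, with a hypothetical `normalize` helper not present in the service itself:

```python
# Same mapping as finetune_service._STATUS_MAP; normalize() is a
# hypothetical helper shown only to illustrate the default behavior.
status_map = {"running": "RUNNING", "succeeded": "SUCCESS", "failed": "FAILED"}

def normalize(raw_status: str) -> str:
    # Unknown provider states fall back to RUNNING rather than erroring.
    return status_map.get(raw_status, "RUNNING")

print(normalize("succeeded"))  # SUCCESS
print(normalize("queued"))     # RUNNING
```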
90 app/services/image_service.py Normal file
@@ -0,0 +1,90 @@
import base64

import cv2
import numpy as np

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.image_models import (
    BBox,
    ImageExtractRequest,
    ImageExtractResponse,
    QuadrupleItem,
)

logger = get_logger(__name__)

# Note: this prompt is sent verbatim (it never goes through str.format),
# so the bbox braces are written singly.
_DEFAULT_PROMPT = (
    "请分析这张图片,提取其中的知识四元组,以 JSON 数组格式返回,每条包含字段:"
    "subject(主体实体)、predicate(关系/属性)、object(客体实体)、"
    "qualifier(修饰信息,可为 null)、bbox({x, y, w, h} 像素坐标)。"
)


async def extract_quads(
    req: ImageExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageExtractResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    image_bytes = await storage.download_bytes(bucket, req.file_path)

    # Decode with OpenCV for cropping; encode as base64 for the LLM
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    img_h, img_w = img.shape[:2]

    b64 = base64.b64encode(image_bytes).decode()
    image_data_url = f"data:image/jpeg;base64,{b64}"

    prompt = req.prompt_template or _DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_data_url}},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    raw = await llm.chat_vision(model, messages)
    logger.info("image_extract", extra={"file": req.file_path, "model": model})

    items_raw = extract_json(raw)
    items: list[QuadrupleItem] = []

    for idx, item in enumerate(items_raw):
        b = item["bbox"]
        # Clamp bbox to image dimensions
        x = max(0, min(int(b["x"]), img_w - 1))
        y = max(0, min(int(b["y"]), img_h - 1))
        w = min(int(b["w"]), img_w - x)
        h = min(int(b["h"]), img_h - y)

        crop = img[y : y + h, x : x + w]
        _, crop_buf = cv2.imencode(".jpg", crop)
        crop_bytes = crop_buf.tobytes()

        crop_path = f"crops/{req.task_id}/{idx}.jpg"
        await storage.upload_bytes(bucket, crop_path, crop_bytes, "image/jpeg")

        items.append(
            QuadrupleItem(
                subject=item["subject"],
                predicate=item["predicate"],
                object=item["object"],
                qualifier=item.get("qualifier"),
                bbox=BBox(x=x, y=y, w=w, h=h),
                cropped_image_path=crop_path,
            )
        )

    return ImageExtractResponse(items=items)
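The bbox clamping in `extract_quads` guards against model-returned coordinates that fall outside the image. A minimal standalone sketch of the same arithmetic (`clamp_bbox` is a hypothetical helper, not part of the service):

```python
def clamp_bbox(b: dict, img_w: int, img_h: int) -> tuple[int, int, int, int]:
    # Same clamping as image_service.extract_quads: pull the origin back
    # inside the image, then trim width/height so the box fits.
    x = max(0, min(int(b["x"]), img_w - 1))
    y = max(0, min(int(b["y"]), img_h - 1))
    w = min(int(b["w"]), img_w - x)
    h = min(int(b["h"]), img_h - y)
    return x, y, w, h

# A box that starts left of the image and overshoots its width:
print(clamp_bbox({"x": -5, "y": 10, "w": 1000, "h": 50}, 640, 480))  # (0, 10, 640, 50)
```

Note the arithmetic only clamps; a box whose `w` or `h` comes back non-positive would still produce an empty crop.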
118 app/services/qa_service.py Normal file
@@ -0,0 +1,118 @@
import base64

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import LLMParseError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.qa_models import (
    GenImageQARequest,
    GenTextQARequest,
    ImageQAPair,
    ImageQAResponse,
    QAPair,
    TextQAResponse,
)

logger = get_logger(__name__)

_DEFAULT_TEXT_PROMPT = (
    "请根据以下知识三元组生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。\n\n"
    "三元组列表:\n{triples_text}"
)

_DEFAULT_IMAGE_PROMPT = (
    "请根据图片内容和以下四元组信息生成问答对,以 JSON 数组格式返回,每条包含 question 和 answer 字段。"
)


async def gen_text_qa(req: GenTextQARequest, llm: LLMClient) -> TextQAResponse:
    cfg = get_config()
    model = req.model or cfg["models"]["default_text"]

    # Format all triples + source snippets into a single batch prompt
    triple_lines: list[str] = []
    for item in req.items:
        triple_lines.append(
            f"({item.subject}, {item.predicate}, {item.object}) — 来源: {item.source_snippet}"
        )
    triples_text = "\n".join(triple_lines)

    prompt_template = req.prompt_template or _DEFAULT_TEXT_PROMPT
    if "{triples_text}" in prompt_template:
        prompt = prompt_template.format(triples_text=triples_text)
    else:
        prompt = prompt_template + "\n\n" + triples_text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    items_raw = extract_json(raw)
    logger.info("gen_text_qa", extra={"items": len(req.items), "model": model})

    if not isinstance(items_raw, list):
        raise LLMParseError("大模型返回的问答对格式不正确")
    try:
        pairs = [QAPair(question=item["question"], answer=item["answer"]) for item in items_raw]
    except (KeyError, TypeError):
        raise LLMParseError("大模型返回的问答对格式不正确")
    return TextQAResponse(pairs=pairs)


async def gen_image_qa(
    req: GenImageQARequest,
    llm: LLMClient,
    storage: StorageClient,
) -> ImageQAResponse:
    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_vision"]

    prompt = req.prompt_template or _DEFAULT_IMAGE_PROMPT

    pairs: list[ImageQAPair] = []

    for item in req.items:
        # Download cropped image bytes from storage
        image_bytes = await storage.download_bytes(bucket, item.cropped_image_path)

        # Base64 encode inline for multimodal message
        b64 = base64.b64encode(image_bytes).decode()
        image_data_url = f"data:image/jpeg;base64,{b64}"

        # Build quad info text
        quad_text = f"{item.subject} — {item.predicate} — {item.object}"
        if item.qualifier:
            quad_text += f" ({item.qualifier})"

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                    {"type": "text", "text": f"{prompt}\n\n{quad_text}"},
                ],
            }
        ]

        raw = await llm.chat_vision(model, messages)

        items_raw = extract_json(raw)
        logger.info("gen_image_qa", extra={"path": item.cropped_image_path, "model": model})

        if not isinstance(items_raw, list):
            raise LLMParseError("大模型返回的问答对格式不正确")
        try:
            for qa in items_raw:
                pairs.append(
                    ImageQAPair(
                        question=qa["question"],
                        answer=qa["answer"],
                        image_path=item.cropped_image_path,
                    )
                )
        except (KeyError, TypeError):
            raise LLMParseError("大模型返回的问答对格式不正确")

    return ImageQAResponse(pairs=pairs)
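`gen_text_qa` accepts a caller-supplied prompt template and falls back gracefully: if the template contains the `{triples_text}` placeholder it is substituted, otherwise the triples are appended after a blank line. A standalone sketch of that branch (`fill_prompt` is a hypothetical name for illustration):

```python
def fill_prompt(template: str, triples_text: str) -> str:
    # Same fallback as qa_service.gen_text_qa: substitute the placeholder
    # when present, otherwise append the triples below the template.
    if "{triples_text}" in template:
        return template.format(triples_text=triples_text)
    return template + "\n\n" + triples_text

print(fill_prompt("Triples: {triples_text}", "(a, b, c)"))  # Triples: (a, b, c)
print(fill_prompt("Generate QA pairs.", "(a, b, c)"))       # Generate QA pairs.\n\n(a, b, c)
```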
95 app/services/text_service.py Normal file
@@ -0,0 +1,95 @@
import io

import docx
import pdfplumber

from app.clients.llm.base import LLMClient
from app.clients.storage.base import StorageClient
from app.core.config import get_config
from app.core.exceptions import UnsupportedFileTypeError
from app.core.json_utils import extract_json
from app.core.logging import get_logger
from app.models.text_models import (
    SourceOffset,
    TextExtractRequest,
    TextExtractResponse,
    TripleItem,
)

logger = get_logger(__name__)

_SUPPORTED_EXTENSIONS = {".txt", ".pdf", ".docx"}

_DEFAULT_PROMPT = (
    "请从以下文本中提取知识三元组,以 JSON 数组格式返回,每条包含字段:"
    "subject(主语)、predicate(谓语)、object(宾语)、"
    "source_snippet(原文证据片段)、source_offset({{start, end}} 字符偏移)。\n\n"
    "文本内容:\n{text}"
)


def _file_extension(file_name: str) -> str:
    idx = file_name.rfind(".")
    return file_name[idx:].lower() if idx != -1 else ""


def _parse_txt(data: bytes) -> str:
    return data.decode("utf-8", errors="replace")


def _parse_pdf(data: bytes) -> str:
    with pdfplumber.open(io.BytesIO(data)) as pdf:
        pages = [page.extract_text() or "" for page in pdf.pages]
    return "\n".join(pages)


def _parse_docx(data: bytes) -> str:
    doc = docx.Document(io.BytesIO(data))
    return "\n".join(p.text for p in doc.paragraphs)


async def extract_triples(
    req: TextExtractRequest,
    llm: LLMClient,
    storage: StorageClient,
) -> TextExtractResponse:
    ext = _file_extension(req.file_name)
    if ext not in _SUPPORTED_EXTENSIONS:
        raise UnsupportedFileTypeError(f"不支持的文件格式: {ext}")

    cfg = get_config()
    bucket = cfg["storage"]["buckets"]["source_data"]
    model = req.model or cfg["models"]["default_text"]

    data = await storage.download_bytes(bucket, req.file_path)

    if ext == ".txt":
        text = _parse_txt(data)
    elif ext == ".pdf":
        text = _parse_pdf(data)
    else:
        text = _parse_docx(data)

    prompt_template = req.prompt_template or _DEFAULT_PROMPT
    prompt = prompt_template.format(text=text) if "{text}" in prompt_template else prompt_template + "\n\n" + text

    messages = [{"role": "user", "content": prompt}]
    raw = await llm.chat(model, messages)

    logger.info("text_extract", extra={"file": req.file_name, "model": model})

    items_raw = extract_json(raw)
    items = [
        TripleItem(
            subject=item["subject"],
            predicate=item["predicate"],
            object=item["object"],
            source_snippet=item["source_snippet"],
            source_offset=SourceOffset(
                start=item["source_offset"]["start"],
                end=item["source_offset"]["end"],
            ),
        )
        for item in items_raw
    ]
    return TextExtractResponse(items=items)
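File-type dispatch hinges on `_file_extension`, which lowercases the last dot-suffix and returns an empty string when there is no dot. A standalone copy of the helper to show its edge cases:

```python
def file_extension(file_name: str) -> str:
    # Mirrors text_service._file_extension: lowercase suffix after the
    # last dot, or "" when the name has no extension at all.
    idx = file_name.rfind(".")
    return file_name[idx:].lower() if idx != -1 else ""

print(file_extension("Report.PDF"))      # .pdf
print(file_extension("notes"))           # (empty string)
print(file_extension("archive.tar.gz"))  # .gz — only the final suffix counts
```

Because only the final suffix is inspected, a name like `archive.tar.gz` dispatches on `.gz` and is rejected by `_SUPPORTED_EXTENSIONS`.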
189
app/services/video_service.py
Normal file
189
app/services/video_service.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import httpx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
|
from app.clients.storage.base import StorageClient
|
||||||
|
from app.core.config import get_config
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.video_models import ExtractFramesRequest, FrameInfo, VideoToTextRequest
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _post_callback(url: str, payload: dict) -> None:
|
||||||
|
async with httpx.AsyncClient(timeout=10) as http:
|
||||||
|
try:
|
||||||
|
await http.post(url, json=payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("callback_failed", extra={"url": url, "error": str(exc)})
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_frames_task(
|
||||||
|
req: ExtractFramesRequest,
|
||||||
|
storage: StorageClient,
|
||||||
|
callback_url: str,
|
||||||
|
) -> None:
|
||||||
|
cfg = get_config()
|
||||||
|
bucket = cfg["storage"]["buckets"]["source_data"]
|
||||||
|
threshold = cfg["video"].get("keyframe_diff_threshold", 30.0)
|
||||||
|
|
||||||
|
tmp = None
|
||||||
|
try:
|
||||||
|
video_bytes = await storage.download_bytes(bucket, req.file_path)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
|
||||||
|
f.write(video_bytes)
|
||||||
|
tmp = f.name
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(tmp)
|
||||||
|
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||||
|
frames_info: list[FrameInfo] = []
|
||||||
|
upload_index = 0
|
||||||
|
prev_gray = None
|
||||||
|
frame_idx = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
|
extract = False
|
||||||
|
if req.mode == "interval":
|
||||||
|
extract = (frame_idx % req.frame_interval == 0)
|
||||||
|
else: # keyframe
|
||||||
|
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
|
||||||
|
if prev_gray is None:
|
||||||
|
extract = True
|
||||||
|
else:
|
||||||
|
diff = np.mean(np.abs(gray - prev_gray))
|
||||||
|
extract = diff > threshold
|
||||||
|
prev_gray = gray
|
||||||
|
|
||||||
|
if extract:
|
||||||
|
time_sec = round(frame_idx / fps, 3)
|
||||||
|
_, buf = cv2.imencode(".jpg", frame)
|
||||||
|
frame_path = f"frames/{req.source_id}/{upload_index}.jpg"
|
||||||
|
await storage.upload_bytes(bucket, frame_path, buf.tobytes(), "image/jpeg")
|
||||||
|
frames_info.append(FrameInfo(
|
||||||
|
frame_index=frame_idx,
|
||||||
|
time_sec=time_sec,
|
||||||
|
frame_path=frame_path,
|
||||||
|
))
|
||||||
|
upload_index += 1
|
||||||
|
|
||||||
|
frame_idx += 1
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
logger.info("extract_frames_done", extra={
|
||||||
|
"job_id": req.job_id,
|
||||||
|
"frames": len(frames_info),
|
||||||
|
})
|
||||||
|
await _post_callback(callback_url, {
|
||||||
|
"job_id": req.job_id,
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"frames": [f.model_dump() for f in frames_info],
|
||||||
|
"output_path": None,
|
||||||
|
"error_message": None,
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("extract_frames_failed", extra={"job_id": req.job_id, "error": str(exc)})
|
||||||
|
await _post_callback(callback_url, {
|
||||||
|
"job_id": req.job_id,
|
||||||
|
"status": "FAILED",
|
||||||
|
"frames": None,
|
||||||
|
"output_path": None,
|
||||||
|
"error_message": str(exc),
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
if tmp and os.path.exists(tmp):
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
async def video_to_text_task(
|
||||||
|
req: VideoToTextRequest,
|
||||||
|
llm: LLMClient,
|
||||||
|
storage: StorageClient,
|
||||||
|
callback_url: str,
|
||||||
|
) -> None:
|
||||||
|
cfg = get_config()
|
||||||
|
bucket = cfg["storage"]["buckets"]["source_data"]
|
||||||
|
sample_count = cfg["video"].get("frame_sample_count", 8)
|
||||||
|
model = req.model or cfg["models"]["default_vision"]
|
||||||
|
|
||||||
|
tmp = None
try:
    video_bytes = await storage.download_bytes(bucket, req.file_path)

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        f.write(video_bytes)
        tmp = f.name

    cap = cv2.VideoCapture(tmp)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    start_frame = int(req.start_sec * fps)
    end_frame = int(req.end_sec * fps)
    total = max(end_frame - start_frame, 1)

    # Uniform sampling
    indices = [
        start_frame + int(i * total / sample_count)
        for i in range(sample_count)
    ]
    indices = list(dict.fromkeys(indices))  # deduplicate

    content: list[dict] = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if not ret:
            continue
        _, buf = cv2.imencode(".jpg", frame)
        b64 = base64.b64encode(buf.tobytes()).decode()
        content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}})

    cap.release()

    # Default prompt (Chinese): "Please describe the content of this video
    # in detail in Chinese and produce a structured text description."
    prompt = req.prompt_template or "请用中文详细描述这段视频的内容,生成结构化文字描述。"
    content.append({"type": "text", "text": prompt})

    messages = [{"role": "user", "content": content}]
    description = await llm.chat_vision(model, messages)

    # Upload description text
    timestamp = int(time.time())
    output_path = f"video-text/{req.source_id}/{timestamp}.txt"
    await storage.upload_bytes(
        bucket, output_path, description.encode("utf-8"), "text/plain"
    )

    logger.info("video_to_text_done", extra={"job_id": req.job_id, "output_path": output_path})
    await _post_callback(callback_url, {
        "job_id": req.job_id,
        "status": "SUCCESS",
        "frames": None,
        "output_path": output_path,
        "error_message": None,
    })

except Exception as exc:
    logger.error("video_to_text_failed", extra={"job_id": req.job_id, "error": str(exc)})
    await _post_callback(callback_url, {
        "job_id": req.job_id,
        "status": "FAILED",
        "frames": None,
        "output_path": None,
        "error_message": str(exc),
    })
finally:
    if tmp and os.path.exists(tmp):
        os.unlink(tmp)
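The uniform-sampling step above can be isolated into a small helper for clarity. This is an illustrative sketch, not part of the service's actual API; `sample_indices` is a hypothetical name:

```python
def sample_indices(start_frame: int, end_frame: int, sample_count: int) -> list[int]:
    """Spread sample_count frame indices evenly over [start_frame, end_frame),
    deduplicating while preserving order (short clips yield fewer indices)."""
    total = max(end_frame - start_frame, 1)
    indices = [start_frame + int(i * total / sample_count) for i in range(sample_count)]
    return list(dict.fromkeys(indices))  # dict.fromkeys keeps first occurrence order

print(sample_indices(0, 100, 8))  # [0, 12, 25, 37, 50, 62, 75, 87]
print(sample_indices(0, 3, 8))    # short clip collapses to [0, 1, 2]
```

The `dict.fromkeys` deduplication matters for clips shorter than `sample_count` frames, where several computed indices collide on the same frame.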
19
config.yaml
Normal file
@@ -0,0 +1,19 @@
server:
  port: 8000
  log_level: INFO

storage:
  buckets:
    source_data: "source-data"
    finetune_export: "finetune-export"

backend: {}  # callback_url injected via BACKEND_CALLBACK_URL env var

video:
  frame_sample_count: 8           # uniform frames sampled for video-to-text
  max_file_size_mb: 200           # video size limit (override with MAX_VIDEO_SIZE_MB)
  keyframe_diff_threshold: 30.0   # grayscale mean-diff threshold for keyframe detection

models:
  default_text: "glm-4-flash"
  default_vision: "glm-4v-flash"
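The `keyframe_diff_threshold` comment describes a grayscale mean-diff rule: a frame counts as a keyframe when the mean absolute pixel difference to the previous frame exceeds the threshold. A minimal sketch of that rule, assuming frames arrive as 2-D lists of 0-255 grayscale values (function names and frame representation are illustrative, not the service's implementation):

```python
def mean_abs_diff(a, b):
    """Mean absolute pixel difference between two equally sized grayscale
    frames, each a list of rows of 0-255 ints."""
    total = sum(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))
    return total / (len(a) * len(a[0]))

def detect_keyframes(frames, threshold=30.0):
    """Return indices of frames whose mean diff to the previous frame
    exceeds the threshold; frame 0 is always kept as a keyframe."""
    keys = [0]
    for i in range(1, len(frames)):
        if mean_abs_diff(frames[i - 1], frames[i]) > threshold:
            keys.append(i)
    return keys

flat = [[10, 10], [10, 10]]
bright = [[200, 200], [200, 200]]
print(detect_keyframes([flat, flat, bright, bright]))  # [0, 2]
```

In production this comparison would run on `numpy` arrays decoded by OpenCV, but the thresholding logic is the same.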
37
docker-compose.yml
Normal file
@@ -0,0 +1,37 @@
version: "3.9"

services:
  ai-service:
    build: .
    ports:
      - "8000:8000"
    env_file:
      - .env
    depends_on:
      rustfs:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 5s
      retries: 3
      start_period: 10s

  rustfs:
    image: rustfs/rustfs:latest
    ports:
      - "9000:9000"
    environment:
      RUSTFS_ACCESS_KEY: ${STORAGE_ACCESS_KEY}
      RUSTFS_SECRET_KEY: ${STORAGE_SECRET_KEY}
    volumes:
      - rustfs_data:/data
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:9000/health"]
      interval: 10s
      timeout: 3s
      retries: 5
      start_period: 5s

volumes:
  rustfs_data:
16
requirements.txt
Normal file
@@ -0,0 +1,16 @@
fastapi>=0.111.0
uvicorn[standard]>=0.29.0
pydantic>=2.7.0
zhipuai>=2.1.0
boto3>=1.34.0
pdfplumber>=0.11.0
python-docx>=1.1.0
opencv-python-headless>=4.9.0
numpy>=1.26.0
httpx>=0.27.0
python-dotenv>=1.0.0
pyyaml>=6.0.0

# Testing
pytest>=8.0.0
pytest-asyncio>=0.23.0