feat(US7+US8): finetune management and health check test
- app/models/finetune_models.py: FinetuneStartRequest, FinetuneStartResponse, FinetuneStatusResponse
- app/services/finetune_service.py: submit_finetune + get_finetune_status via run_in_executor; status map running→RUNNING, succeeded→SUCCESS, failed→FAILED, unknown→RUNNING; LLMCallError on SDK failure
- app/routers/finetune.py: POST /finetune/start + GET /finetune/status/{job_id} with get_llm_client dependency
- tests/test_finetune_service.py: 12 unit tests (TDD, written before implementation)
- tests/test_finetune_router.py: 6 integration tests
- tests/test_health.py: GET /health → 200 {"status":"ok"}
Full suite: 72/72 passing (was 53)
This commit is contained in:
18
app/models/finetune_models.py
Normal file
18
app/models/finetune_models.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FinetuneStartRequest(BaseModel):
    """Request body for starting a fine-tune job (POST /finetune/start)."""

    jsonl_url: str  # location of the JSONL training data passed to the SDK as training_file
    base_model: str  # identifier of the base model to fine-tune (e.g. "glm-4")
    hyperparams: dict | None = None  # optional hyperparameters; None is sent to the SDK as {}
|
||||
|
||||
|
||||
class FinetuneStartResponse(BaseModel):
    """Response for a successfully submitted fine-tune job."""

    job_id: str  # SDK-assigned job identifier, used later to poll status
|
||||
|
||||
|
||||
class FinetuneStatusResponse(BaseModel):
    """Response for a fine-tune status query (GET /finetune/status/{job_id})."""

    job_id: str  # identifier of the queried job (echoed back)
    status: str  # normalized status: "RUNNING", "SUCCESS" or "FAILED"
    progress: int | None = None  # progress as reported by the SDK, if available
    error_message: str | None = None  # SDK error detail, if the job failed
|
||||
@@ -1,3 +1,28 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.clients.llm.base import LLMClient
|
||||
from app.core.dependencies import get_llm_client
|
||||
from app.models.finetune_models import (
|
||||
FinetuneStartRequest,
|
||||
FinetuneStartResponse,
|
||||
FinetuneStatusResponse,
|
||||
)
|
||||
from app.services import finetune_service
|
||||
|
||||
# Router for fine-tune job management endpoints; mounted by the app under its
# API prefix (tests exercise these routes at /api/v1/finetune/*).
router = APIRouter(tags=["Finetune"])
|
||||
|
||||
|
||||
@router.post("/finetune/start", response_model=FinetuneStartResponse)
async def start_finetune(
    req: FinetuneStartRequest,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStartResponse:
    """Submit a new fine-tune job.

    Thin wrapper: delegates to finetune_service.submit_finetune. A service-level
    LLMCallError propagates to the app's exception handling (tests expect a 503
    response with code "LLM_CALL_ERROR").
    """
    return await finetune_service.submit_finetune(req, llm)
|
||||
|
||||
|
||||
@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
async def get_status(
    job_id: str,
    llm: LLMClient = Depends(get_llm_client),
) -> FinetuneStatusResponse:
    """Return the current status of a fine-tune job.

    Thin wrapper: delegates to finetune_service.get_finetune_status; errors
    surface as LLMCallError handled at the app level.
    """
    return await finetune_service.get_finetune_status(job_id, llm)
|
||||
|
||||
61
app/services/finetune_service.py
Normal file
61
app/services/finetune_service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import asyncio
|
||||
|
||||
from app.core.exceptions import LLMCallError
|
||||
from app.core.logging import get_logger
|
||||
from app.models.finetune_models import (
|
||||
FinetuneStartRequest,
|
||||
FinetuneStartResponse,
|
||||
FinetuneStatusResponse,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)

# Raw ZhipuAI job states → API-facing normalized states. Any raw state not
# listed here (e.g. "pending", "queued") is treated as still RUNNING — the
# conservative fallback applied in get_finetune_status.
_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}
|
||||
|
||||
|
||||
async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse:
    """Submit a fine-tune job to ZhipuAI and return the job ID.

    Args:
        req: Validated start request (training file URL, base model, optional
            hyperparams; None hyperparams are sent to the SDK as {}).
        llm: LLM client wrapper exposing the synchronous ZhipuAI SDK via
            ``_client`` — NOTE(review): relies on a private attribute; confirm
            the client intentionally exposes it.

    Raises:
        LLMCallError: If the underlying SDK call fails for any reason.
    """
    # Fix: get_event_loop() inside a coroutine is deprecated since Python 3.10;
    # get_running_loop() is the correct, unambiguous call here.
    loop = asyncio.get_running_loop()
    try:
        # The ZhipuAI SDK is synchronous, so run it in the default executor
        # to avoid blocking the event loop.
        response = await loop.run_in_executor(
            None,
            lambda: llm._client.fine_tuning.jobs.create(
                training_file=req.jsonl_url,
                model=req.base_model,
                hyperparameters=req.hyperparams or {},
            ),
        )
        job_id = response.id
        logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
        return FinetuneStartResponse(job_id=job_id)
    except Exception as exc:
        # Boundary translation: wrap any SDK failure in the domain error type.
        logger.error("finetune_submit_error", extra={"error": str(exc)})
        raise LLMCallError(f"微调任务提交失败: {exc}") from exc
|
||||
|
||||
|
||||
async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse:
    """Retrieve fine-tune job status from ZhipuAI.

    Args:
        job_id: SDK job identifier returned by submit_finetune.
        llm: LLM client wrapper exposing the synchronous ZhipuAI SDK via
            ``_client``.

    Returns:
        FinetuneStatusResponse with a normalized status (RUNNING / SUCCESS /
        FAILED); unknown raw statuses are conservatively reported as RUNNING.

    Raises:
        LLMCallError: If the underlying SDK call fails.
    """
    # Fix: get_event_loop() inside a coroutine is deprecated since Python 3.10;
    # get_running_loop() is the correct, unambiguous call here.
    loop = asyncio.get_running_loop()
    try:
        # The ZhipuAI SDK is synchronous — run it in the default executor.
        response = await loop.run_in_executor(
            None,
            lambda: llm._client.fine_tuning.jobs.retrieve(job_id),
        )
        status_raw = response.status
        status = _STATUS_MAP.get(status_raw, "RUNNING")  # conservative fallback
        # progress / error_message are optional on the SDK response object.
        progress = getattr(response, "progress", None)
        error_message = getattr(response, "error_message", None)
        logger.info("finetune_status", extra={"job_id": job_id, "status": status})
        return FinetuneStatusResponse(
            job_id=job_id,
            status=status,
            progress=progress,
            error_message=error_message,
        )
    except Exception as exc:
        # Boundary translation: wrap any SDK failure in the domain error type.
        logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)})
        raise LLMCallError(f"微调状态查询失败: {exc}") from exc
|
||||
112
tests/test_finetune_router.py
Normal file
112
tests/test_finetune_router.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""T050: Integration tests for finetune router endpoints."""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.core.exceptions import LLMCallError
|
||||
from app.models.finetune_models import FinetuneStartResponse, FinetuneStatusResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/v1/finetune/start
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_finetune_start_returns_200_with_job_id(client):
    """POST /finetune/start returns 200 and echoes the service's job_id."""
    expected = FinetuneStartResponse(job_id="glm-ft-router-test")
    payload = {
        "jsonl_url": "s3://bucket/train.jsonl",
        "base_model": "glm-4",
        "hyperparams": {"n_epochs": 3},
    }

    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
        mock_submit.return_value = expected
        resp = client.post("/api/v1/finetune/start", json=payload)

    assert resp.status_code == 200
    assert resp.json()["job_id"] == "glm-ft-router-test"
|
||||
|
||||
|
||||
def test_finetune_start_without_hyperparams(client):
    """hyperparams is optional — the request validates without it."""
    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
        mock_submit.return_value = FinetuneStartResponse(job_id="glm-ft-nohp")

        resp = client.post(
            "/api/v1/finetune/start",
            json={"jsonl_url": "s3://bucket/train.jsonl", "base_model": "glm-4"},
        )

    assert resp.status_code == 200
    assert resp.json()["job_id"] == "glm-ft-nohp"
|
||||
|
||||
|
||||
def test_finetune_start_llm_call_error_returns_503(client):
    """A service-level LLMCallError maps to HTTP 503 with code LLM_CALL_ERROR."""
    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
        mock_submit.side_effect = LLMCallError("SDK failed")

        resp = client.post(
            "/api/v1/finetune/start",
            json={"jsonl_url": "s3://bucket/train.jsonl", "base_model": "glm-4"},
        )

    assert resp.status_code == 503
    assert resp.json()["code"] == "LLM_CALL_ERROR"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/v1/finetune/status/{job_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_finetune_status_returns_200_with_fields(client):
    """GET /finetune/status/{job_id} echoes job_id, status and progress."""
    stub = FinetuneStatusResponse(
        job_id="glm-ft-router-test",
        status="RUNNING",
        progress=30,
    )

    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
        mock_status.return_value = stub
        resp = client.get("/api/v1/finetune/status/glm-ft-router-test")

    assert resp.status_code == 200
    body = resp.json()
    assert body["job_id"] == "glm-ft-router-test"
    assert body["status"] == "RUNNING"
    assert body["progress"] == 30
|
||||
|
||||
|
||||
def test_finetune_status_succeeded(client):
    """A finished job is reported with status SUCCESS."""
    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
        mock_status.return_value = FinetuneStatusResponse(
            job_id="glm-ft-done",
            status="SUCCESS",
        )
        resp = client.get("/api/v1/finetune/status/glm-ft-done")

    assert resp.status_code == 200
    assert resp.json()["status"] == "SUCCESS"
|
||||
|
||||
|
||||
def test_finetune_status_llm_call_error_returns_503(client):
    """Status-lookup failures surface as HTTP 503 with code LLM_CALL_ERROR."""
    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
        mock_status.side_effect = LLMCallError("SDK failed")
        resp = client.get("/api/v1/finetune/status/glm-ft-bad")

    assert resp.status_code == 503
    assert resp.json()["code"] == "LLM_CALL_ERROR"
|
||||
154
tests/test_finetune_service.py
Normal file
154
tests/test_finetune_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation."""
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from app.core.exceptions import LLMCallError
|
||||
from app.models.finetune_models import (
|
||||
FinetuneStartRequest,
|
||||
FinetuneStartResponse,
|
||||
FinetuneStatusResponse,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
|
||||
"""Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*"""
|
||||
create_resp = MagicMock()
|
||||
create_resp.id = job_id
|
||||
|
||||
retrieve_resp = MagicMock()
|
||||
retrieve_resp.status = status
|
||||
retrieve_resp.progress = progress
|
||||
retrieve_resp.error_message = None # explicitly set to avoid MagicMock auto-attribute
|
||||
|
||||
llm = MagicMock()
|
||||
llm._client.fine_tuning.jobs.create.return_value = create_resp
|
||||
llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp
|
||||
return llm
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# submit_finetune
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id():
    """submit_finetune wraps the SDK job id in a FinetuneStartResponse."""
    from app.services.finetune_service import submit_finetune

    request = FinetuneStartRequest(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
        hyperparams={"n_epochs": 3},
    )

    response = await submit_finetune(request, _make_llm(job_id="glm-ft-abc123"))

    assert isinstance(response, FinetuneStartResponse)
    assert response.job_id == "glm-ft-abc123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_submit_finetune_calls_sdk_with_correct_params():
    """The SDK create() call receives the request fields verbatim."""
    from app.services.finetune_service import submit_finetune

    mock_llm = _make_llm(job_id="glm-ft-xyz")

    await submit_finetune(
        FinetuneStartRequest(
            jsonl_url="s3://bucket/train.jsonl",
            base_model="glm-4",
            hyperparams={"n_epochs": 5},
        ),
        mock_llm,
    )

    mock_llm._client.fine_tuning.jobs.create.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={"n_epochs": 5},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_submit_finetune_none_hyperparams_passes_empty_dict():
    """hyperparams=None should be passed as {} to the SDK."""
    from app.services.finetune_service import submit_finetune

    mock_llm = _make_llm(job_id="glm-ft-nohp")

    await submit_finetune(
        FinetuneStartRequest(
            jsonl_url="s3://bucket/train.jsonl",
            base_model="glm-4",
        ),
        mock_llm,
    )

    mock_llm._client.fine_tuning.jobs.create.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={},
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_submit_finetune_raises_llm_call_error_on_sdk_failure():
    """Any SDK exception is translated into LLMCallError."""
    from app.services.finetune_service import submit_finetune

    broken_llm = MagicMock()
    broken_llm._client.fine_tuning.jobs.create.side_effect = RuntimeError("SDK exploded")

    with pytest.raises(LLMCallError):
        await submit_finetune(
            FinetuneStartRequest(
                jsonl_url="s3://bucket/train.jsonl",
                base_model="glm-4",
            ),
            broken_llm,
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_finetune_status — status mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("sdk_status,expected", [
    ("running", "RUNNING"),
    ("succeeded", "SUCCESS"),
    ("failed", "FAILED"),
    ("pending", "RUNNING"),  # unknown → conservative RUNNING
    ("queued", "RUNNING"),  # unknown → conservative RUNNING
    ("cancelled", "RUNNING"),  # unknown → conservative RUNNING
])
async def test_get_finetune_status_maps_status(sdk_status, expected):
    """Raw SDK statuses map to coarse API statuses; unknowns fall back to RUNNING."""
    from app.services.finetune_service import get_finetune_status

    result = await get_finetune_status("glm-ft-test", _make_llm(status=sdk_status))

    assert isinstance(result, FinetuneStatusResponse)
    assert result.job_id == "glm-ft-test"
    assert result.status == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_finetune_status_includes_progress():
    """Progress reported by the SDK is passed through unchanged."""
    from app.services.finetune_service import get_finetune_status

    outcome = await get_finetune_status(
        "glm-ft-test", _make_llm(status="running", progress=42)
    )

    assert outcome.progress == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure():
    """A failing retrieve() call is translated into LLMCallError."""
    from app.services.finetune_service import get_finetune_status

    broken_llm = MagicMock()
    broken_llm._client.fine_tuning.jobs.retrieve.side_effect = RuntimeError("SDK exploded")

    with pytest.raises(LLMCallError):
        await get_finetune_status("glm-ft-bad", broken_llm)
|
||||
8
tests/test_health.py
Normal file
8
tests/test_health.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""T047: Health check endpoint test — GET /health → 200 {"status": "ok"}"""
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_health_returns_ok(client: TestClient):
    """GET /health responds 200 with {"status": "ok"}."""
    resp = client.get("/health")

    assert resp.status_code == 200
    assert resp.json() == {"status": "ok"}
|
||||
Reference in New Issue
Block a user