diff --git a/app/models/finetune_models.py b/app/models/finetune_models.py
new file mode 100644
index 0000000..36ac21b
--- /dev/null
+++ b/app/models/finetune_models.py
@@ -0,0 +1,18 @@
+from pydantic import BaseModel
+
+
+class FinetuneStartRequest(BaseModel):
+    jsonl_url: str
+    base_model: str
+    hyperparams: dict | None = None
+
+
+class FinetuneStartResponse(BaseModel):
+    job_id: str
+
+
+class FinetuneStatusResponse(BaseModel):
+    job_id: str
+    status: str
+    progress: int | None = None
+    error_message: str | None = None
diff --git a/app/routers/finetune.py b/app/routers/finetune.py
index f16ec0f..47136a9 100644
--- a/app/routers/finetune.py
+++ b/app/routers/finetune.py
@@ -1,3 +1,28 @@
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+
+from app.clients.llm.base import LLMClient
+from app.core.dependencies import get_llm_client
+from app.models.finetune_models import (
+    FinetuneStartRequest,
+    FinetuneStartResponse,
+    FinetuneStatusResponse,
+)
+from app.services import finetune_service
 
 router = APIRouter(tags=["Finetune"])
+
+
+@router.post("/finetune/start", response_model=FinetuneStartResponse)
+async def start_finetune(
+    req: FinetuneStartRequest,
+    llm: LLMClient = Depends(get_llm_client),
+) -> FinetuneStartResponse:
+    return await finetune_service.submit_finetune(req, llm)
+
+
+@router.get("/finetune/status/{job_id}", response_model=FinetuneStatusResponse)
+async def get_status(
+    job_id: str,
+    llm: LLMClient = Depends(get_llm_client),
+) -> FinetuneStatusResponse:
+    return await finetune_service.get_finetune_status(job_id, llm)
diff --git a/app/services/finetune_service.py b/app/services/finetune_service.py
new file mode 100644
index 0000000..c55c389
--- /dev/null
+++ b/app/services/finetune_service.py
@@ -0,0 +1,61 @@
+import asyncio
+
+from app.core.exceptions import LLMCallError
+from app.core.logging import get_logger
+from app.models.finetune_models import (
+    FinetuneStartRequest,
+    FinetuneStartResponse,
+    FinetuneStatusResponse,
+)
+
+logger = get_logger(__name__)
+
+_STATUS_MAP = {
+    "running": "RUNNING",
+    "succeeded": "SUCCESS",
+    "failed": "FAILED",
+}
+
+
+async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse:
+    """Submit a fine-tune job to ZhipuAI and return the job ID."""
+    loop = asyncio.get_running_loop()
+    try:
+        response = await loop.run_in_executor(
+            None,
+            lambda: llm._client.fine_tuning.jobs.create(
+                training_file=req.jsonl_url,
+                model=req.base_model,
+                hyperparameters=req.hyperparams or {},
+            ),
+        )
+        job_id = response.id
+        logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
+        return FinetuneStartResponse(job_id=job_id)
+    except Exception as exc:
+        logger.error("finetune_submit_error", extra={"error": str(exc)})
+        raise LLMCallError(f"微调任务提交失败: {exc}") from exc
+
+
+async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse:
+    """Retrieve fine-tune job status from ZhipuAI."""
+    loop = asyncio.get_running_loop()
+    try:
+        response = await loop.run_in_executor(
+            None,
+            lambda: llm._client.fine_tuning.jobs.retrieve(job_id),
+        )
+        status_raw = response.status
+        status = _STATUS_MAP.get(status_raw, "RUNNING")  # conservative fallback
+        progress = getattr(response, "progress", None)
+        error_message = getattr(response, "error_message", None)
+        logger.info("finetune_status", extra={"job_id": job_id, "status": status})
+        return FinetuneStatusResponse(
+            job_id=job_id,
+            status=status,
+            progress=progress,
+            error_message=error_message,
+        )
+    except Exception as exc:
+        logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)})
+        raise LLMCallError(f"微调状态查询失败: {exc}") from exc
diff --git a/tests/test_finetune_router.py b/tests/test_finetune_router.py
new file mode 100644
index 0000000..6678195
--- /dev/null
+++ b/tests/test_finetune_router.py
@@ -0,0 +1,112 @@
+"""T050: Integration tests for finetune router endpoints."""
+import pytest
+from unittest.mock import MagicMock, patch
+
+from app.core.exceptions import LLMCallError
+from app.models.finetune_models import FinetuneStartResponse, FinetuneStatusResponse
+
+
+# ---------------------------------------------------------------------------
+# POST /api/v1/finetune/start
+# ---------------------------------------------------------------------------
+
+def test_finetune_start_returns_200_with_job_id(client):
+    start_resp = FinetuneStartResponse(job_id="glm-ft-router-test")
+
+    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
+        mock_submit.return_value = start_resp
+
+        resp = client.post(
+            "/api/v1/finetune/start",
+            json={
+                "jsonl_url": "s3://bucket/train.jsonl",
+                "base_model": "glm-4",
+                "hyperparams": {"n_epochs": 3},
+            },
+        )
+
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["job_id"] == "glm-ft-router-test"
+
+
+def test_finetune_start_without_hyperparams(client):
+    start_resp = FinetuneStartResponse(job_id="glm-ft-nohp")
+
+    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
+        mock_submit.return_value = start_resp
+
+        resp = client.post(
+            "/api/v1/finetune/start",
+            json={
+                "jsonl_url": "s3://bucket/train.jsonl",
+                "base_model": "glm-4",
+            },
+        )
+
+        assert resp.status_code == 200
+        assert resp.json()["job_id"] == "glm-ft-nohp"
+
+
+def test_finetune_start_llm_call_error_returns_503(client):
+    with patch("app.routers.finetune.finetune_service.submit_finetune") as mock_submit:
+        mock_submit.side_effect = LLMCallError("SDK failed")
+
+        resp = client.post(
+            "/api/v1/finetune/start",
+            json={
+                "jsonl_url": "s3://bucket/train.jsonl",
+                "base_model": "glm-4",
+            },
+        )
+
+        assert resp.status_code == 503
+        assert resp.json()["code"] == "LLM_CALL_ERROR"
+
+
+# ---------------------------------------------------------------------------
+# GET /api/v1/finetune/status/{job_id}
+# ---------------------------------------------------------------------------
+
+def test_finetune_status_returns_200_with_fields(client):
+    status_resp = FinetuneStatusResponse(
+        job_id="glm-ft-router-test",
+        status="RUNNING",
+        progress=30,
+    )
+
+    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
+        mock_status.return_value = status_resp
+
+        resp = client.get("/api/v1/finetune/status/glm-ft-router-test")
+
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["job_id"] == "glm-ft-router-test"
+        assert data["status"] == "RUNNING"
+        assert data["progress"] == 30
+
+
+def test_finetune_status_succeeded(client):
+    status_resp = FinetuneStatusResponse(
+        job_id="glm-ft-done",
+        status="SUCCESS",
+    )
+
+    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
+        mock_status.return_value = status_resp
+
+        resp = client.get("/api/v1/finetune/status/glm-ft-done")
+
+        assert resp.status_code == 200
+        assert resp.json()["status"] == "SUCCESS"
+
+
+def test_finetune_status_llm_call_error_returns_503(client):
+    with patch("app.routers.finetune.finetune_service.get_finetune_status") as mock_status:
+        mock_status.side_effect = LLMCallError("SDK failed")
+
+        resp = client.get("/api/v1/finetune/status/glm-ft-bad")
+
+        assert resp.status_code == 503
+        assert resp.json()["code"] == "LLM_CALL_ERROR"
diff --git a/tests/test_finetune_service.py b/tests/test_finetune_service.py
new file mode 100644
index 0000000..6d458c5
--- /dev/null
+++ b/tests/test_finetune_service.py
@@ -0,0 +1,154 @@
+"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation."""
+import asyncio
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+
+from app.core.exceptions import LLMCallError
+from app.models.finetune_models import (
+    FinetuneStartRequest,
+    FinetuneStartResponse,
+    FinetuneStatusResponse,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
+    """Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*"""
+    create_resp = MagicMock()
+    create_resp.id = job_id
+
+    retrieve_resp = MagicMock()
+    retrieve_resp.status = status
+    retrieve_resp.progress = progress
+    retrieve_resp.error_message = None  # explicitly set to avoid MagicMock auto-attribute
+
+    llm = MagicMock()
+    llm._client.fine_tuning.jobs.create.return_value = create_resp
+    llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp
+    return llm
+
+
+# ---------------------------------------------------------------------------
+# submit_finetune
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_submit_finetune_returns_job_id():
+    from app.services.finetune_service import submit_finetune
+
+    llm = _make_llm(job_id="glm-ft-abc123")
+    req = FinetuneStartRequest(
+        jsonl_url="s3://bucket/train.jsonl",
+        base_model="glm-4",
+        hyperparams={"n_epochs": 3},
+    )
+
+    result = await submit_finetune(req, llm)
+
+    assert isinstance(result, FinetuneStartResponse)
+    assert result.job_id == "glm-ft-abc123"
+
+
+@pytest.mark.asyncio
+async def test_submit_finetune_calls_sdk_with_correct_params():
+    from app.services.finetune_service import submit_finetune
+
+    llm = _make_llm(job_id="glm-ft-xyz")
+    req = FinetuneStartRequest(
+        jsonl_url="s3://bucket/train.jsonl",
+        base_model="glm-4",
+        hyperparams={"n_epochs": 5},
+    )
+
+    await submit_finetune(req, llm)
+
+    llm._client.fine_tuning.jobs.create.assert_called_once_with(
+        training_file="s3://bucket/train.jsonl",
+        model="glm-4",
+        hyperparameters={"n_epochs": 5},
+    )
+
+
+@pytest.mark.asyncio
+async def test_submit_finetune_none_hyperparams_passes_empty_dict():
+    """hyperparams=None should be passed as {} to the SDK."""
+    from app.services.finetune_service import submit_finetune
+
+    llm = _make_llm(job_id="glm-ft-nohp")
+    req = FinetuneStartRequest(
+        jsonl_url="s3://bucket/train.jsonl",
+        base_model="glm-4",
+    )
+
+    await submit_finetune(req, llm)
+
+    llm._client.fine_tuning.jobs.create.assert_called_once_with(
+        training_file="s3://bucket/train.jsonl",
+        model="glm-4",
+        hyperparameters={},
+    )
+
+
+@pytest.mark.asyncio
+async def test_submit_finetune_raises_llm_call_error_on_sdk_failure():
+    from app.services.finetune_service import submit_finetune
+
+    llm = MagicMock()
+    llm._client.fine_tuning.jobs.create.side_effect = RuntimeError("SDK exploded")
+
+    req = FinetuneStartRequest(
+        jsonl_url="s3://bucket/train.jsonl",
+        base_model="glm-4",
+    )
+
+    with pytest.raises(LLMCallError):
+        await submit_finetune(req, llm)
+
+
+# ---------------------------------------------------------------------------
+# get_finetune_status — status mapping
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sdk_status,expected", [
+    ("running", "RUNNING"),
+    ("succeeded", "SUCCESS"),
+    ("failed", "FAILED"),
+    ("pending", "RUNNING"),  # unknown → conservative RUNNING
+    ("queued", "RUNNING"),  # unknown → conservative RUNNING
+    ("cancelled", "RUNNING"),  # unknown → conservative RUNNING
+])
+async def test_get_finetune_status_maps_status(sdk_status, expected):
+    from app.services.finetune_service import get_finetune_status
+
+    llm = _make_llm(status=sdk_status)
+
+    result = await get_finetune_status("glm-ft-test", llm)
+
+    assert isinstance(result, FinetuneStatusResponse)
+    assert result.status == expected
+    assert result.job_id == "glm-ft-test"
+
+
+@pytest.mark.asyncio
+async def test_get_finetune_status_includes_progress():
+    from app.services.finetune_service import get_finetune_status
+
+    llm = _make_llm(status="running", progress=42)
+    result = await get_finetune_status("glm-ft-test", llm)
+
+    assert result.progress == 42
+
+
+@pytest.mark.asyncio
+async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure():
+    from app.services.finetune_service import get_finetune_status
+
+    llm = MagicMock()
+    llm._client.fine_tuning.jobs.retrieve.side_effect = RuntimeError("SDK exploded")
+
+    with pytest.raises(LLMCallError):
+        await get_finetune_status("glm-ft-bad", llm)
diff --git a/tests/test_health.py b/tests/test_health.py
new file mode 100644
index 0000000..0f2b3e4
--- /dev/null
+++ b/tests/test_health.py
@@ -0,0 +1,8 @@
+"""T047: Health check endpoint test — GET /health → 200 {"status": "ok"}"""
+from fastapi.testclient import TestClient
+
+
+def test_health_returns_ok(client: TestClient):
+    response = client.get("/health")
+    assert response.status_code == 200
+    assert response.json() == {"status": "ok"}