Files
label_ai_service/tests/test_finetune_service.py

155 lines
4.8 KiB
Python
Raw Normal View History

"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation."""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from app.core.exceptions import LLMCallError
from app.models.finetune_models import (
FinetuneStartRequest,
FinetuneStartResponse,
FinetuneStatusResponse,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
"""Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*"""
create_resp = MagicMock()
create_resp.id = job_id
retrieve_resp = MagicMock()
retrieve_resp.status = status
retrieve_resp.progress = progress
retrieve_resp.error_message = None # explicitly set to avoid MagicMock auto-attribute
llm = MagicMock()
llm._client.fine_tuning.jobs.create.return_value = create_resp
llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp
return llm
# ---------------------------------------------------------------------------
# submit_finetune
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id():
    """submit_finetune wraps the SDK-assigned job id in a FinetuneStartResponse."""
    from app.services.finetune_service import submit_finetune

    expected_id = "glm-ft-abc123"
    fake_client = _make_llm(job_id=expected_id)
    request = FinetuneStartRequest(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
        hyperparams={"n_epochs": 3},
    )

    response = await submit_finetune(request, fake_client)

    assert isinstance(response, FinetuneStartResponse)
    assert response.job_id == expected_id
@pytest.mark.asyncio
async def test_submit_finetune_calls_sdk_with_correct_params():
    """Request fields are forwarded to the SDK under its own keyword names."""
    from app.services.finetune_service import submit_finetune

    fake_client = _make_llm(job_id="glm-ft-xyz")
    training_url = "s3://bucket/train.jsonl"
    request = FinetuneStartRequest(
        jsonl_url=training_url,
        base_model="glm-4",
        hyperparams={"n_epochs": 5},
    )

    await submit_finetune(request, fake_client)

    # jsonl_url -> training_file, base_model -> model, hyperparams -> hyperparameters
    jobs_api = fake_client._client.fine_tuning.jobs
    jobs_api.create.assert_called_once_with(
        training_file=training_url,
        model="glm-4",
        hyperparameters={"n_epochs": 5},
    )
@pytest.mark.asyncio
async def test_submit_finetune_none_hyperparams_passes_empty_dict():
    """hyperparams=None should be passed as {} to the SDK."""
    from app.services.finetune_service import submit_finetune

    fake_client = _make_llm(job_id="glm-ft-nohp")
    # hyperparams deliberately omitted from the request.
    request = FinetuneStartRequest(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
    )

    await submit_finetune(request, fake_client)

    create_mock = fake_client._client.fine_tuning.jobs.create
    create_mock.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={},
    )
@pytest.mark.asyncio
async def test_submit_finetune_raises_llm_call_error_on_sdk_failure():
    """Any exception from the SDK's create call surfaces as LLMCallError."""
    from app.services.finetune_service import submit_finetune

    failing_client = MagicMock()
    create_mock = failing_client._client.fine_tuning.jobs.create
    create_mock.side_effect = RuntimeError("SDK exploded")
    request = FinetuneStartRequest(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
    )

    with pytest.raises(LLMCallError):
        await submit_finetune(request, failing_client)
# ---------------------------------------------------------------------------
# get_finetune_status — status mapping
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize("sdk_status,expected", [
    # Explicitly known SDK statuses.
    ("running", "RUNNING"),
    ("succeeded", "SUCCESS"),
    ("failed", "FAILED"),
    # Anything unrecognised falls back conservatively to RUNNING.
    # NOTE(review): if "cancelled" is terminal upstream, mapping it to RUNNING
    # means callers polling for completion never stop — confirm this is intended.
    ("pending", "RUNNING"),
    ("queued", "RUNNING"),
    ("cancelled", "RUNNING"),
])
async def test_get_finetune_status_maps_status(sdk_status, expected):
    """SDK job statuses are translated into the service's status vocabulary."""
    from app.services.finetune_service import get_finetune_status

    queried_id = "glm-ft-test"
    fake_client = _make_llm(status=sdk_status)

    status_resp = await get_finetune_status(queried_id, fake_client)

    assert isinstance(status_resp, FinetuneStatusResponse)
    assert status_resp.status == expected
    assert status_resp.job_id == queried_id
@pytest.mark.asyncio
async def test_get_finetune_status_includes_progress():
    """The SDK-reported progress value is copied onto the status response."""
    from app.services.finetune_service import get_finetune_status

    reported_progress = 42
    fake_client = _make_llm(status="running", progress=reported_progress)

    status_resp = await get_finetune_status("glm-ft-test", fake_client)

    assert status_resp.progress == reported_progress
@pytest.mark.asyncio
async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure():
    """Any exception from the SDK's retrieve call surfaces as LLMCallError."""
    from app.services.finetune_service import get_finetune_status

    failing_client = MagicMock()
    retrieve_mock = failing_client._client.fine_tuning.jobs.retrieve
    retrieve_mock.side_effect = RuntimeError("SDK exploded")

    with pytest.raises(LLMCallError):
        await get_finetune_status("glm-ft-bad", failing_client)