diff --git a/app/clients/llm/base.py b/app/clients/llm/base.py index 33ab1d8..a842a13 100644 --- a/app/clients/llm/base.py +++ b/app/clients/llm/base.py @@ -9,3 +9,11 @@ class LLMClient(ABC): @abstractmethod async def chat_vision(self, model: str, messages: list[dict]) -> str: """Send a multimodal (vision) chat request and return the response content string.""" + + @abstractmethod + async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str: + """Submit a fine-tune job and return the job_id.""" + + @abstractmethod + async def get_finetune_status(self, job_id: str) -> dict: + """Return a dict with keys: job_id, status (raw SDK string), progress (int|None), error_message (str|None).""" diff --git a/app/clients/llm/zhipuai_client.py b/app/clients/llm/zhipuai_client.py index a92322d..87d9928 100644 --- a/app/clients/llm/zhipuai_client.py +++ b/app/clients/llm/zhipuai_client.py @@ -19,8 +19,39 @@ class ZhipuAIClient(LLMClient): async def chat_vision(self, model: str, messages: list[dict]) -> str: return await self._call(model, messages) + async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str: + loop = asyncio.get_running_loop() + try: + resp = await loop.run_in_executor( + None, + lambda: self._client.fine_tuning.jobs.create( + training_file=jsonl_url, + model=base_model, + hyperparameters=hyperparams, + ), + ) + return resp.id + except Exception as exc: + raise LLMCallError(f"微调任务提交失败: {exc}") from exc + + async def get_finetune_status(self, job_id: str) -> dict: + loop = asyncio.get_running_loop() + try: + resp = await loop.run_in_executor( + None, + lambda: self._client.fine_tuning.jobs.retrieve(job_id), + ) + return { + "job_id": resp.id, + "status": resp.status, + "progress": int(resp.progress) if getattr(resp, "progress", None) is not None else None, + "error_message": getattr(resp, "error_message", None), + } + except Exception as exc: + raise LLMCallError(f"查询微调任务失败: {exc}") from exc + async def _call(self, model: str, messages: list[dict]) -> str: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: response = await loop.run_in_executor( None, diff --git a/app/clients/storage/rustfs_client.py b/app/clients/storage/rustfs_client.py index 8ef105a..19708d6 100644 --- a/app/clients/storage/rustfs_client.py +++ b/app/clients/storage/rustfs_client.py @@ -21,7 +21,7 @@ class RustFSClient(StorageClient): ) async def download_bytes(self, bucket: str, path: str) -> bytes: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: resp = await loop.run_in_executor( None, lambda: self._s3.get_object(Bucket=bucket, Key=path) @@ -33,7 +33,7 @@ class RustFSClient(StorageClient): async def upload_bytes( self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream" ) -> None: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: await loop.run_in_executor( None, @@ -45,7 +45,7 @@ class RustFSClient(StorageClient): raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: url = await loop.run_in_executor( None, @@ -60,7 +60,7 @@ class RustFSClient(StorageClient): raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc async def get_object_size(self, bucket: str, path: str) -> int: - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() try: resp = await loop.run_in_executor( None, lambda: self._s3.head_object(Bucket=bucket, Key=path) diff --git a/app/services/finetune_service.py b/app/services/finetune_service.py index c55c389..aca2868 100644 --- a/app/services/finetune_service.py +++ b/app/services/finetune_service.py @@ -1,6 +1,4 @@ -import asyncio - -from app.core.exceptions import LLMCallError +from app.clients.llm.base import LLMClient from app.core.logging import get_logger from app.models.finetune_models import ( FinetuneStartRequest, @@ -17,45 +15,21 @@ _STATUS_MAP = { } -async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse: - """Submit a fine-tune job to ZhipuAI and return the job ID.""" - loop = asyncio.get_event_loop() - try: - response = await loop.run_in_executor( - None, - lambda: llm._client.fine_tuning.jobs.create( - training_file=req.jsonl_url, - model=req.base_model, - hyperparameters=req.hyperparams or {}, - ), - ) - job_id = response.id - logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model}) - return FinetuneStartResponse(job_id=job_id) - except Exception as exc: - logger.error("finetune_submit_error", extra={"error": str(exc)}) - raise LLMCallError(f"微调任务提交失败: {exc}") from exc +async def submit_finetune(req: FinetuneStartRequest, llm: LLMClient) -> FinetuneStartResponse: + """Submit a fine-tune job via the LLMClient interface and return the job ID.""" + job_id = await llm.submit_finetune(req.jsonl_url, req.base_model, req.hyperparams or {}) + logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model}) + return FinetuneStartResponse(job_id=job_id) -async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse: - """Retrieve fine-tune job status from ZhipuAI.""" - loop = asyncio.get_event_loop() - try: - response = await loop.run_in_executor( - None, - lambda: llm._client.fine_tuning.jobs.retrieve(job_id), - ) - status_raw = response.status - status = _STATUS_MAP.get(status_raw, "RUNNING") # conservative fallback - progress = getattr(response, "progress", None) - error_message = getattr(response, "error_message", None) - logger.info("finetune_status", extra={"job_id": job_id, "status": status}) - return FinetuneStatusResponse( - job_id=job_id, - status=status, - progress=progress, - error_message=error_message, - ) - except Exception as exc: - logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)}) - raise LLMCallError(f"微调状态查询失败: {exc}") from exc +async def get_finetune_status(job_id: str, llm: LLMClient) -> FinetuneStatusResponse: + """Retrieve fine-tune job status via the LLMClient interface.""" + raw = await llm.get_finetune_status(job_id) + status = _STATUS_MAP.get(raw["status"], "RUNNING") + logger.info("finetune_status", extra={"job_id": job_id, "status": status}) + return FinetuneStatusResponse( + job_id=raw["job_id"], + status=status, + progress=raw["progress"], + error_message=raw["error_message"], + ) diff --git a/tests/test_finetune_service.py b/tests/test_finetune_service.py index 6d458c5..51d93dd 100644 --- a/tests/test_finetune_service.py +++ b/tests/test_finetune_service.py @@ -1,8 +1,8 @@ -"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation.""" -import asyncio +"""Tests for finetune_service — uses LLMClient interface (no internal SDK access).""" import pytest -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import MagicMock, AsyncMock +from app.clients.llm.base import LLMClient from app.core.exceptions import LLMCallError from app.models.finetune_models import ( FinetuneStartRequest, @@ -16,18 +16,15 @@ from app.models.finetune_models import ( # --------------------------------------------------------------------------- def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None): - """Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*""" - create_resp = MagicMock() - create_resp.id = job_id - - retrieve_resp = MagicMock() - retrieve_resp.status = status - retrieve_resp.progress = progress - retrieve_resp.error_message = None # explicitly set to avoid MagicMock auto-attribute - - llm = MagicMock() - llm._client.fine_tuning.jobs.create.return_value = create_resp - llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp + """Return a MagicMock(spec=LLMClient) with submit_finetune and get_finetune_status as AsyncMocks.""" + llm = MagicMock(spec=LLMClient) + llm.submit_finetune = AsyncMock(return_value=job_id) + llm.get_finetune_status = AsyncMock(return_value={ + "job_id": job_id, + "status": status, + "progress": progress, + "error_message": None, + }) return llm @@ -53,7 +50,7 @@ async def test_submit_finetune_returns_job_id(): @pytest.mark.asyncio -async def test_submit_finetune_calls_sdk_with_correct_params(): +async def test_submit_finetune_calls_interface_with_correct_params(): from app.services.finetune_service import submit_finetune llm = _make_llm(job_id="glm-ft-xyz") @@ -65,16 +62,16 @@ async def test_submit_finetune_calls_sdk_with_correct_params(): await submit_finetune(req, llm) - llm._client.fine_tuning.jobs.create.assert_called_once_with( - training_file="s3://bucket/train.jsonl", - model="glm-4", - hyperparameters={"n_epochs": 5}, + llm.submit_finetune.assert_awaited_once_with( + "s3://bucket/train.jsonl", + "glm-4", + {"n_epochs": 5}, ) @pytest.mark.asyncio async def test_submit_finetune_none_hyperparams_passes_empty_dict(): - """hyperparams=None should be passed as {} to the SDK.""" + """hyperparams=None should be passed as {} to the interface.""" from app.services.finetune_service import submit_finetune llm = _make_llm(job_id="glm-ft-nohp") @@ -85,19 +82,19 @@ async def test_submit_finetune_none_hyperparams_passes_empty_dict(): await submit_finetune(req, llm) - llm._client.fine_tuning.jobs.create.assert_called_once_with( - training_file="s3://bucket/train.jsonl", - model="glm-4", - hyperparameters={}, + llm.submit_finetune.assert_awaited_once_with( + "s3://bucket/train.jsonl", + "glm-4", + {}, ) @pytest.mark.asyncio -async def test_submit_finetune_raises_llm_call_error_on_sdk_failure(): +async def test_submit_finetune_raises_llm_call_error_on_failure(): from app.services.finetune_service import submit_finetune - llm = MagicMock() - llm._client.fine_tuning.jobs.create.side_effect = RuntimeError("SDK exploded") + llm = MagicMock(spec=LLMClient) + llm.submit_finetune = AsyncMock(side_effect=LLMCallError("微调任务提交失败: SDK exploded")) req = FinetuneStartRequest( jsonl_url="s3://bucket/train.jsonl", @@ -144,11 +141,11 @@ async def test_get_finetune_status_includes_progress(): @pytest.mark.asyncio -async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure(): +async def test_get_finetune_status_raises_llm_call_error_on_failure(): from app.services.finetune_service import get_finetune_status - llm = MagicMock() - llm._client.fine_tuning.jobs.retrieve.side_effect = RuntimeError("SDK exploded") + llm = MagicMock(spec=LLMClient) + llm.get_finetune_status = AsyncMock(side_effect=LLMCallError("查询微调任务失败: SDK exploded")) with pytest.raises(LLMCallError): await get_finetune_status("glm-ft-bad", llm) diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py index 39a586b..e5d0734 100644 --- a/tests/test_llm_client.py +++ b/tests/test_llm_client.py @@ -38,3 +38,44 @@ async def test_llm_call_error_on_sdk_exception(client): client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded") with pytest.raises(LLMCallError, match="大模型调用失败"): await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}]) + + +@pytest.mark.asyncio +async def test_submit_finetune_returns_job_id(client): + """submit_finetune should call the SDK and return the job id.""" + resp = MagicMock() + resp.id = "glm-ft-newjob" + client._client.fine_tuning.jobs.create.return_value = resp + + job_id = await client.submit_finetune( + jsonl_url="s3://bucket/train.jsonl", + base_model="glm-4", + hyperparams={"n_epochs": 2}, + ) + + assert job_id == "glm-ft-newjob" + client._client.fine_tuning.jobs.create.assert_called_once_with( + training_file="s3://bucket/train.jsonl", + model="glm-4", + hyperparameters={"n_epochs": 2}, + ) + + +@pytest.mark.asyncio +async def test_get_finetune_status_returns_correct_dict(client): + """get_finetune_status should return a normalized dict with progress coerced to int.""" + resp = MagicMock() + resp.id = "glm-ft-abc" + resp.status = "running" + resp.progress = "75" # SDK may return string; should be coerced to int + resp.error_message = None + client._client.fine_tuning.jobs.retrieve.return_value = resp + + result = await client.get_finetune_status("glm-ft-abc") + + assert result == { + "job_id": "glm-ft-abc", + "status": "running", + "progress": 75, + "error_message": None, + }