refactor: finetune through LLMClient interface + get_running_loop

- Add submit_finetune and get_finetune_status abstract methods to LLMClient base
- Implement both methods in ZhipuAIClient using asyncio.get_running_loop()
- Rewrite finetune_service to call llm.submit_finetune / llm.get_finetune_status
  instead of accessing llm._client directly, restoring interface encapsulation
- Replace asyncio.get_event_loop() with get_running_loop() in ZhipuAIClient._call
  and all four methods in RustFSClient (deprecated in Python 3.10+)
- Update test_finetune_service to mock the LLMClient interface methods as AsyncMocks
- Add two new tests in test_llm_client for submit_finetune and get_finetune_status
This commit is contained in:
wh
2026-04-10 16:43:28 +08:00
parent 603382d1fa
commit 0880e1018c
6 changed files with 130 additions and 79 deletions

View File

@@ -9,3 +9,11 @@ class LLMClient(ABC):
@abstractmethod @abstractmethod
async def chat_vision(self, model: str, messages: list[dict]) -> str: async def chat_vision(self, model: str, messages: list[dict]) -> str:
"""Send a multimodal (vision) chat request and return the response content string.""" """Send a multimodal (vision) chat request and return the response content string."""
@abstractmethod
async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
"""Submit a fine-tune job and return the job_id."""
@abstractmethod
async def get_finetune_status(self, job_id: str) -> dict:
"""Return a dict with keys: job_id, status (raw SDK string), progress (int|None), error_message (str|None)."""

View File

@@ -19,8 +19,39 @@ class ZhipuAIClient(LLMClient):
async def chat_vision(self, model: str, messages: list[dict]) -> str: async def chat_vision(self, model: str, messages: list[dict]) -> str:
return await self._call(model, messages) return await self._call(model, messages)
async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
loop = asyncio.get_running_loop()
try:
resp = await loop.run_in_executor(
None,
lambda: self._client.fine_tuning.jobs.create(
training_file=jsonl_url,
model=base_model,
hyperparameters=hyperparams,
),
)
return resp.id
except Exception as exc:
raise LLMCallError(f"微调任务提交失败: {exc}") from exc
async def get_finetune_status(self, job_id: str) -> dict:
loop = asyncio.get_running_loop()
try:
resp = await loop.run_in_executor(
None,
lambda: self._client.fine_tuning.jobs.retrieve(job_id),
)
return {
"job_id": resp.id,
"status": resp.status,
"progress": int(resp.progress) if getattr(resp, "progress", None) is not None else None,
"error_message": getattr(resp, "error_message", None),
}
except Exception as exc:
raise LLMCallError(f"查询微调任务失败: {exc}") from exc
async def _call(self, model: str, messages: list[dict]) -> str: async def _call(self, model: str, messages: list[dict]) -> str:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
response = await loop.run_in_executor( response = await loop.run_in_executor(
None, None,

View File

@@ -21,7 +21,7 @@ class RustFSClient(StorageClient):
) )
async def download_bytes(self, bucket: str, path: str) -> bytes: async def download_bytes(self, bucket: str, path: str) -> bytes:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
resp = await loop.run_in_executor( resp = await loop.run_in_executor(
None, lambda: self._s3.get_object(Bucket=bucket, Key=path) None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
@@ -33,7 +33,7 @@ class RustFSClient(StorageClient):
async def upload_bytes( async def upload_bytes(
self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream" self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
) -> None: ) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
await loop.run_in_executor( await loop.run_in_executor(
None, None,
@@ -45,7 +45,7 @@ class RustFSClient(StorageClient):
raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc
async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str: async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
url = await loop.run_in_executor( url = await loop.run_in_executor(
None, None,
@@ -60,7 +60,7 @@ class RustFSClient(StorageClient):
raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc
async def get_object_size(self, bucket: str, path: str) -> int: async def get_object_size(self, bucket: str, path: str) -> int:
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
resp = await loop.run_in_executor( resp = await loop.run_in_executor(
None, lambda: self._s3.head_object(Bucket=bucket, Key=path) None, lambda: self._s3.head_object(Bucket=bucket, Key=path)

View File

@@ -1,6 +1,4 @@
import asyncio from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.finetune_models import ( from app.models.finetune_models import (
FinetuneStartRequest, FinetuneStartRequest,
@@ -17,45 +15,21 @@ _STATUS_MAP = {
} }
async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse: async def submit_finetune(req: FinetuneStartRequest, llm: LLMClient) -> FinetuneStartResponse:
"""Submit a fine-tune job to ZhipuAI and return the job ID.""" """Submit a fine-tune job via the LLMClient interface and return the job ID."""
loop = asyncio.get_event_loop() job_id = await llm.submit_finetune(req.jsonl_url, req.base_model, req.hyperparams or {})
try:
response = await loop.run_in_executor(
None,
lambda: llm._client.fine_tuning.jobs.create(
training_file=req.jsonl_url,
model=req.base_model,
hyperparameters=req.hyperparams or {},
),
)
job_id = response.id
logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model}) logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
return FinetuneStartResponse(job_id=job_id) return FinetuneStartResponse(job_id=job_id)
except Exception as exc:
logger.error("finetune_submit_error", extra={"error": str(exc)})
raise LLMCallError(f"微调任务提交失败: {exc}") from exc
async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse: async def get_finetune_status(job_id: str, llm: LLMClient) -> FinetuneStatusResponse:
"""Retrieve fine-tune job status from ZhipuAI.""" """Retrieve fine-tune job status via the LLMClient interface."""
loop = asyncio.get_event_loop() raw = await llm.get_finetune_status(job_id)
try: status = _STATUS_MAP.get(raw["status"], "RUNNING")
response = await loop.run_in_executor(
None,
lambda: llm._client.fine_tuning.jobs.retrieve(job_id),
)
status_raw = response.status
status = _STATUS_MAP.get(status_raw, "RUNNING") # conservative fallback
progress = getattr(response, "progress", None)
error_message = getattr(response, "error_message", None)
logger.info("finetune_status", extra={"job_id": job_id, "status": status}) logger.info("finetune_status", extra={"job_id": job_id, "status": status})
return FinetuneStatusResponse( return FinetuneStatusResponse(
job_id=job_id, job_id=raw["job_id"],
status=status, status=status,
progress=progress, progress=raw["progress"],
error_message=error_message, error_message=raw["error_message"],
) )
except Exception as exc:
logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)})
raise LLMCallError(f"微调状态查询失败: {exc}") from exc

View File

@@ -1,8 +1,8 @@
"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation.""" """Tests for finetune_service — uses LLMClient interface (no internal SDK access)."""
import asyncio
import pytest import pytest
from unittest.mock import MagicMock, AsyncMock, patch from unittest.mock import MagicMock, AsyncMock
from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError from app.core.exceptions import LLMCallError
from app.models.finetune_models import ( from app.models.finetune_models import (
FinetuneStartRequest, FinetuneStartRequest,
@@ -16,18 +16,15 @@ from app.models.finetune_models import (
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None): def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
"""Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*""" """Return a MagicMock(spec=LLMClient) with submit_finetune and get_finetune_status as AsyncMocks."""
create_resp = MagicMock() llm = MagicMock(spec=LLMClient)
create_resp.id = job_id llm.submit_finetune = AsyncMock(return_value=job_id)
llm.get_finetune_status = AsyncMock(return_value={
retrieve_resp = MagicMock() "job_id": job_id,
retrieve_resp.status = status "status": status,
retrieve_resp.progress = progress "progress": progress,
retrieve_resp.error_message = None # explicitly set to avoid MagicMock auto-attribute "error_message": None,
})
llm = MagicMock()
llm._client.fine_tuning.jobs.create.return_value = create_resp
llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp
return llm return llm
@@ -53,7 +50,7 @@ async def test_submit_finetune_returns_job_id():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_submit_finetune_calls_sdk_with_correct_params(): async def test_submit_finetune_calls_interface_with_correct_params():
from app.services.finetune_service import submit_finetune from app.services.finetune_service import submit_finetune
llm = _make_llm(job_id="glm-ft-xyz") llm = _make_llm(job_id="glm-ft-xyz")
@@ -65,16 +62,16 @@ async def test_submit_finetune_calls_sdk_with_correct_params():
await submit_finetune(req, llm) await submit_finetune(req, llm)
llm._client.fine_tuning.jobs.create.assert_called_once_with( llm.submit_finetune.assert_awaited_once_with(
training_file="s3://bucket/train.jsonl", "s3://bucket/train.jsonl",
model="glm-4", "glm-4",
hyperparameters={"n_epochs": 5}, {"n_epochs": 5},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_submit_finetune_none_hyperparams_passes_empty_dict(): async def test_submit_finetune_none_hyperparams_passes_empty_dict():
"""hyperparams=None should be passed as {} to the SDK.""" """hyperparams=None should be passed as {} to the interface."""
from app.services.finetune_service import submit_finetune from app.services.finetune_service import submit_finetune
llm = _make_llm(job_id="glm-ft-nohp") llm = _make_llm(job_id="glm-ft-nohp")
@@ -85,19 +82,19 @@ async def test_submit_finetune_none_hyperparams_passes_empty_dict():
await submit_finetune(req, llm) await submit_finetune(req, llm)
llm._client.fine_tuning.jobs.create.assert_called_once_with( llm.submit_finetune.assert_awaited_once_with(
training_file="s3://bucket/train.jsonl", "s3://bucket/train.jsonl",
model="glm-4", "glm-4",
hyperparameters={}, {},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_submit_finetune_raises_llm_call_error_on_sdk_failure(): async def test_submit_finetune_raises_llm_call_error_on_failure():
from app.services.finetune_service import submit_finetune from app.services.finetune_service import submit_finetune
llm = MagicMock() llm = MagicMock(spec=LLMClient)
llm._client.fine_tuning.jobs.create.side_effect = RuntimeError("SDK exploded") llm.submit_finetune = AsyncMock(side_effect=LLMCallError("微调任务提交失败: SDK exploded"))
req = FinetuneStartRequest( req = FinetuneStartRequest(
jsonl_url="s3://bucket/train.jsonl", jsonl_url="s3://bucket/train.jsonl",
@@ -144,11 +141,11 @@ async def test_get_finetune_status_includes_progress():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure(): async def test_get_finetune_status_raises_llm_call_error_on_failure():
from app.services.finetune_service import get_finetune_status from app.services.finetune_service import get_finetune_status
llm = MagicMock() llm = MagicMock(spec=LLMClient)
llm._client.fine_tuning.jobs.retrieve.side_effect = RuntimeError("SDK exploded") llm.get_finetune_status = AsyncMock(side_effect=LLMCallError("查询微调任务失败: SDK exploded"))
with pytest.raises(LLMCallError): with pytest.raises(LLMCallError):
await get_finetune_status("glm-ft-bad", llm) await get_finetune_status("glm-ft-bad", llm)

View File

@@ -38,3 +38,44 @@ async def test_llm_call_error_on_sdk_exception(client):
client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded") client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")
with pytest.raises(LLMCallError, match="大模型调用失败"): with pytest.raises(LLMCallError, match="大模型调用失败"):
await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}]) await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id(client):
"""submit_finetune should call the SDK and return the job id."""
resp = MagicMock()
resp.id = "glm-ft-newjob"
client._client.fine_tuning.jobs.create.return_value = resp
job_id = await client.submit_finetune(
jsonl_url="s3://bucket/train.jsonl",
base_model="glm-4",
hyperparams={"n_epochs": 2},
)
assert job_id == "glm-ft-newjob"
client._client.fine_tuning.jobs.create.assert_called_once_with(
training_file="s3://bucket/train.jsonl",
model="glm-4",
hyperparameters={"n_epochs": 2},
)
@pytest.mark.asyncio
async def test_get_finetune_status_returns_correct_dict(client):
"""get_finetune_status should return a normalized dict with progress coerced to int."""
resp = MagicMock()
resp.id = "glm-ft-abc"
resp.status = "running"
resp.progress = "75" # SDK may return string; should be coerced to int
resp.error_message = None
client._client.fine_tuning.jobs.retrieve.return_value = resp
result = await client.get_finetune_status("glm-ft-abc")
assert result == {
"job_id": "glm-ft-abc",
"status": "running",
"progress": 75,
"error_message": None,
}