refactor: finetune through LLMClient interface + get_running_loop
- Add submit_finetune and get_finetune_status abstract methods to the LLMClient base class
- Implement both methods in ZhipuAIClient using asyncio.get_running_loop()
- Rewrite finetune_service to call llm.submit_finetune / llm.get_finetune_status instead of accessing llm._client directly, restoring interface encapsulation
- Replace asyncio.get_event_loop() (deprecated in Python 3.10+) with asyncio.get_running_loop() in ZhipuAIClient._call and in all four RustFSClient methods
- Update test_finetune_service to mock the LLMClient interface methods as AsyncMocks
- Add two new tests in test_llm_client for submit_finetune and get_finetune_status
This commit is contained in:
@@ -9,3 +9,11 @@ class LLMClient(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
||||||
"""Send a multimodal (vision) chat request and return the response content string."""
|
"""Send a multimodal (vision) chat request and return the response content string."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
|
||||||
|
"""Submit a fine-tune job and return the job_id."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_finetune_status(self, job_id: str) -> dict:
|
||||||
|
"""Return a dict with keys: job_id, status (raw SDK string), progress (int|None), error_message (str|None)."""
|
||||||
|
|||||||
@@ -19,8 +19,39 @@ class ZhipuAIClient(LLMClient):
|
|||||||
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
async def chat_vision(self, model: str, messages: list[dict]) -> str:
|
||||||
return await self._call(model, messages)
|
return await self._call(model, messages)
|
||||||
|
|
||||||
|
async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
resp = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._client.fine_tuning.jobs.create(
|
||||||
|
training_file=jsonl_url,
|
||||||
|
model=base_model,
|
||||||
|
hyperparameters=hyperparams,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return resp.id
|
||||||
|
except Exception as exc:
|
||||||
|
raise LLMCallError(f"微调任务提交失败: {exc}") from exc
|
||||||
|
|
||||||
|
async def get_finetune_status(self, job_id: str) -> dict:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
resp = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self._client.fine_tuning.jobs.retrieve(job_id),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"job_id": resp.id,
|
||||||
|
"status": resp.status,
|
||||||
|
"progress": int(resp.progress) if getattr(resp, "progress", None) is not None else None,
|
||||||
|
"error_message": getattr(resp, "error_message", None),
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
raise LLMCallError(f"查询微调任务失败: {exc}") from exc
|
||||||
|
|
||||||
async def _call(self, model: str, messages: list[dict]) -> str:
|
async def _call(self, model: str, messages: list[dict]) -> str:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
response = await loop.run_in_executor(
|
response = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class RustFSClient(StorageClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def download_bytes(self, bucket: str, path: str) -> bytes:
|
async def download_bytes(self, bucket: str, path: str) -> bytes:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
resp = await loop.run_in_executor(
|
resp = await loop.run_in_executor(
|
||||||
None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
|
None, lambda: self._s3.get_object(Bucket=bucket, Key=path)
|
||||||
@@ -33,7 +33,7 @@ class RustFSClient(StorageClient):
|
|||||||
async def upload_bytes(
|
async def upload_bytes(
|
||||||
self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
|
self, bucket: str, path: str, data: bytes, content_type: str = "application/octet-stream"
|
||||||
) -> None:
|
) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -45,7 +45,7 @@ class RustFSClient(StorageClient):
|
|||||||
raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc
|
raise StorageError(f"存储上传失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
|
|
||||||
async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
|
async def get_presigned_url(self, bucket: str, path: str, expires: int = 3600) -> str:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
url = await loop.run_in_executor(
|
url = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -60,7 +60,7 @@ class RustFSClient(StorageClient):
|
|||||||
raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc
|
raise StorageError(f"生成预签名 URL 失败 [{bucket}/{path}]: {exc}") from exc
|
||||||
|
|
||||||
async def get_object_size(self, bucket: str, path: str) -> int:
|
async def get_object_size(self, bucket: str, path: str) -> int:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
resp = await loop.run_in_executor(
|
resp = await loop.run_in_executor(
|
||||||
None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
|
None, lambda: self._s3.head_object(Bucket=bucket, Key=path)
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import asyncio
|
from app.clients.llm.base import LLMClient
|
||||||
|
|
||||||
from app.core.exceptions import LLMCallError
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.finetune_models import (
|
from app.models.finetune_models import (
|
||||||
FinetuneStartRequest,
|
FinetuneStartRequest,
|
||||||
@@ -17,45 +15,21 @@ _STATUS_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse:
|
async def submit_finetune(req: FinetuneStartRequest, llm: LLMClient) -> FinetuneStartResponse:
|
||||||
"""Submit a fine-tune job to ZhipuAI and return the job ID."""
|
"""Submit a fine-tune job via the LLMClient interface and return the job ID."""
|
||||||
loop = asyncio.get_event_loop()
|
job_id = await llm.submit_finetune(req.jsonl_url, req.base_model, req.hyperparams or {})
|
||||||
try:
|
logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
|
||||||
response = await loop.run_in_executor(
|
return FinetuneStartResponse(job_id=job_id)
|
||||||
None,
|
|
||||||
lambda: llm._client.fine_tuning.jobs.create(
|
|
||||||
training_file=req.jsonl_url,
|
|
||||||
model=req.base_model,
|
|
||||||
hyperparameters=req.hyperparams or {},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
job_id = response.id
|
|
||||||
logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
|
|
||||||
return FinetuneStartResponse(job_id=job_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("finetune_submit_error", extra={"error": str(exc)})
|
|
||||||
raise LLMCallError(f"微调任务提交失败: {exc}") from exc
|
|
||||||
|
|
||||||
|
|
||||||
async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse:
|
async def get_finetune_status(job_id: str, llm: LLMClient) -> FinetuneStatusResponse:
|
||||||
"""Retrieve fine-tune job status from ZhipuAI."""
|
"""Retrieve fine-tune job status via the LLMClient interface."""
|
||||||
loop = asyncio.get_event_loop()
|
raw = await llm.get_finetune_status(job_id)
|
||||||
try:
|
status = _STATUS_MAP.get(raw["status"], "RUNNING")
|
||||||
response = await loop.run_in_executor(
|
logger.info("finetune_status", extra={"job_id": job_id, "status": status})
|
||||||
None,
|
return FinetuneStatusResponse(
|
||||||
lambda: llm._client.fine_tuning.jobs.retrieve(job_id),
|
job_id=raw["job_id"],
|
||||||
)
|
status=status,
|
||||||
status_raw = response.status
|
progress=raw["progress"],
|
||||||
status = _STATUS_MAP.get(status_raw, "RUNNING") # conservative fallback
|
error_message=raw["error_message"],
|
||||||
progress = getattr(response, "progress", None)
|
)
|
||||||
error_message = getattr(response, "error_message", None)
|
|
||||||
logger.info("finetune_status", extra={"job_id": job_id, "status": status})
|
|
||||||
return FinetuneStatusResponse(
|
|
||||||
job_id=job_id,
|
|
||||||
status=status,
|
|
||||||
progress=progress,
|
|
||||||
error_message=error_message,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)})
|
|
||||||
raise LLMCallError(f"微调状态查询失败: {exc}") from exc
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""T046: Tests for finetune_service — written FIRST (TDD), must FAIL before implementation."""
|
"""Tests for finetune_service — uses LLMClient interface (no internal SDK access)."""
|
||||||
import asyncio
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, AsyncMock, patch
|
from unittest.mock import MagicMock, AsyncMock
|
||||||
|
|
||||||
|
from app.clients.llm.base import LLMClient
|
||||||
from app.core.exceptions import LLMCallError
|
from app.core.exceptions import LLMCallError
|
||||||
from app.models.finetune_models import (
|
from app.models.finetune_models import (
|
||||||
FinetuneStartRequest,
|
FinetuneStartRequest,
|
||||||
@@ -16,18 +16,15 @@ from app.models.finetune_models import (
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
|
def _make_llm(job_id: str = "glm-ft-test", status: str = "running", progress: int | None = None):
|
||||||
"""Return a mock that looks like ZhipuAIClient with ._client.fine_tuning.jobs.*"""
|
"""Return a MagicMock(spec=LLMClient) with submit_finetune and get_finetune_status as AsyncMocks."""
|
||||||
create_resp = MagicMock()
|
llm = MagicMock(spec=LLMClient)
|
||||||
create_resp.id = job_id
|
llm.submit_finetune = AsyncMock(return_value=job_id)
|
||||||
|
llm.get_finetune_status = AsyncMock(return_value={
|
||||||
retrieve_resp = MagicMock()
|
"job_id": job_id,
|
||||||
retrieve_resp.status = status
|
"status": status,
|
||||||
retrieve_resp.progress = progress
|
"progress": progress,
|
||||||
retrieve_resp.error_message = None # explicitly set to avoid MagicMock auto-attribute
|
"error_message": None,
|
||||||
|
})
|
||||||
llm = MagicMock()
|
|
||||||
llm._client.fine_tuning.jobs.create.return_value = create_resp
|
|
||||||
llm._client.fine_tuning.jobs.retrieve.return_value = retrieve_resp
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +50,7 @@ async def test_submit_finetune_returns_job_id():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_finetune_calls_sdk_with_correct_params():
|
async def test_submit_finetune_calls_interface_with_correct_params():
|
||||||
from app.services.finetune_service import submit_finetune
|
from app.services.finetune_service import submit_finetune
|
||||||
|
|
||||||
llm = _make_llm(job_id="glm-ft-xyz")
|
llm = _make_llm(job_id="glm-ft-xyz")
|
||||||
@@ -65,16 +62,16 @@ async def test_submit_finetune_calls_sdk_with_correct_params():
|
|||||||
|
|
||||||
await submit_finetune(req, llm)
|
await submit_finetune(req, llm)
|
||||||
|
|
||||||
llm._client.fine_tuning.jobs.create.assert_called_once_with(
|
llm.submit_finetune.assert_awaited_once_with(
|
||||||
training_file="s3://bucket/train.jsonl",
|
"s3://bucket/train.jsonl",
|
||||||
model="glm-4",
|
"glm-4",
|
||||||
hyperparameters={"n_epochs": 5},
|
{"n_epochs": 5},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_finetune_none_hyperparams_passes_empty_dict():
|
async def test_submit_finetune_none_hyperparams_passes_empty_dict():
|
||||||
"""hyperparams=None should be passed as {} to the SDK."""
|
"""hyperparams=None should be passed as {} to the interface."""
|
||||||
from app.services.finetune_service import submit_finetune
|
from app.services.finetune_service import submit_finetune
|
||||||
|
|
||||||
llm = _make_llm(job_id="glm-ft-nohp")
|
llm = _make_llm(job_id="glm-ft-nohp")
|
||||||
@@ -85,19 +82,19 @@ async def test_submit_finetune_none_hyperparams_passes_empty_dict():
|
|||||||
|
|
||||||
await submit_finetune(req, llm)
|
await submit_finetune(req, llm)
|
||||||
|
|
||||||
llm._client.fine_tuning.jobs.create.assert_called_once_with(
|
llm.submit_finetune.assert_awaited_once_with(
|
||||||
training_file="s3://bucket/train.jsonl",
|
"s3://bucket/train.jsonl",
|
||||||
model="glm-4",
|
"glm-4",
|
||||||
hyperparameters={},
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_finetune_raises_llm_call_error_on_sdk_failure():
|
async def test_submit_finetune_raises_llm_call_error_on_failure():
|
||||||
from app.services.finetune_service import submit_finetune
|
from app.services.finetune_service import submit_finetune
|
||||||
|
|
||||||
llm = MagicMock()
|
llm = MagicMock(spec=LLMClient)
|
||||||
llm._client.fine_tuning.jobs.create.side_effect = RuntimeError("SDK exploded")
|
llm.submit_finetune = AsyncMock(side_effect=LLMCallError("微调任务提交失败: SDK exploded"))
|
||||||
|
|
||||||
req = FinetuneStartRequest(
|
req = FinetuneStartRequest(
|
||||||
jsonl_url="s3://bucket/train.jsonl",
|
jsonl_url="s3://bucket/train.jsonl",
|
||||||
@@ -144,11 +141,11 @@ async def test_get_finetune_status_includes_progress():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_finetune_status_raises_llm_call_error_on_sdk_failure():
|
async def test_get_finetune_status_raises_llm_call_error_on_failure():
|
||||||
from app.services.finetune_service import get_finetune_status
|
from app.services.finetune_service import get_finetune_status
|
||||||
|
|
||||||
llm = MagicMock()
|
llm = MagicMock(spec=LLMClient)
|
||||||
llm._client.fine_tuning.jobs.retrieve.side_effect = RuntimeError("SDK exploded")
|
llm.get_finetune_status = AsyncMock(side_effect=LLMCallError("查询微调任务失败: SDK exploded"))
|
||||||
|
|
||||||
with pytest.raises(LLMCallError):
|
with pytest.raises(LLMCallError):
|
||||||
await get_finetune_status("glm-ft-bad", llm)
|
await get_finetune_status("glm-ft-bad", llm)
|
||||||
|
|||||||
@@ -38,3 +38,44 @@ async def test_llm_call_error_on_sdk_exception(client):
|
|||||||
client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")
|
client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")
|
||||||
with pytest.raises(LLMCallError, match="大模型调用失败"):
|
with pytest.raises(LLMCallError, match="大模型调用失败"):
|
||||||
await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
|
await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_finetune_returns_job_id(client):
|
||||||
|
"""submit_finetune should call the SDK and return the job id."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.id = "glm-ft-newjob"
|
||||||
|
client._client.fine_tuning.jobs.create.return_value = resp
|
||||||
|
|
||||||
|
job_id = await client.submit_finetune(
|
||||||
|
jsonl_url="s3://bucket/train.jsonl",
|
||||||
|
base_model="glm-4",
|
||||||
|
hyperparams={"n_epochs": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert job_id == "glm-ft-newjob"
|
||||||
|
client._client.fine_tuning.jobs.create.assert_called_once_with(
|
||||||
|
training_file="s3://bucket/train.jsonl",
|
||||||
|
model="glm-4",
|
||||||
|
hyperparameters={"n_epochs": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_finetune_status_returns_correct_dict(client):
|
||||||
|
"""get_finetune_status should return a normalized dict with progress coerced to int."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.id = "glm-ft-abc"
|
||||||
|
resp.status = "running"
|
||||||
|
resp.progress = "75" # SDK may return string; should be coerced to int
|
||||||
|
resp.error_message = None
|
||||||
|
client._client.fine_tuning.jobs.retrieve.return_value = resp
|
||||||
|
|
||||||
|
result = await client.get_finetune_status("glm-ft-abc")
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"job_id": "glm-ft-abc",
|
||||||
|
"status": "running",
|
||||||
|
"progress": 75,
|
||||||
|
"error_message": None,
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user