- Add submit_finetune and get_finetune_status abstract methods to LLMClient base
- Implement both methods in ZhipuAIClient using asyncio.get_running_loop()
- Rewrite finetune_service to call llm.submit_finetune / llm.get_finetune_status instead of accessing llm._client directly, restoring interface encapsulation
- Replace asyncio.get_event_loop() with get_running_loop() in ZhipuAIClient._call and all four methods in RustFSClient (deprecated in Python 3.10+)
- Update test_finetune_service to mock the LLMClient interface methods as AsyncMocks
- Add two new tests in test_llm_client for submit_finetune and get_finetune_status
82 lines
2.6 KiB
Python
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from app.clients.llm.zhipuai_client import ZhipuAIClient
|
|
from app.core.exceptions import LLMCallError
|
|
|
|
|
|
@pytest.fixture
def mock_sdk_response():
    """Fake ZhipuAI SDK completion whose first choice carries a JSON payload."""
    fake = MagicMock()
    # MagicMock auto-creates the choices[0].message chain on first access.
    fake.choices[0].message.content = '{"result": "ok"}'
    return fake
|
|
|
|
|
|
@pytest.fixture
def client():
    """ZhipuAIClient instance built with the real SDK constructor patched out."""
    with patch("app.clients.llm.zhipuai_client.ZhipuAI"):
        instance = ZhipuAIClient(api_key="test-key")
    return instance
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_returns_content(client, mock_sdk_response):
    """chat() should surface the first choice's message content verbatim."""
    client._client.chat.completions.create.return_value = mock_sdk_response

    answer = await client.chat("glm-4-flash", [{"role": "user", "content": "hello"}])

    assert answer == '{"result": "ok"}'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_vision_returns_content(client, mock_sdk_response):
    """chat_vision() should surface the first choice's message content verbatim."""
    client._client.chat.completions.create.return_value = mock_sdk_response

    answer = await client.chat_vision("glm-4v-flash", [{"role": "user", "content": []}])

    assert answer == '{"result": "ok"}'
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_llm_call_error_on_sdk_exception(client):
    """Any exception raised by the SDK must be wrapped into LLMCallError."""
    client._client.chat.completions.create.side_effect = RuntimeError("quota exceeded")

    with pytest.raises(LLMCallError, match="大模型调用失败"):
        await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id(client):
    """submit_finetune should forward its arguments to the SDK and return the job id."""
    sdk_job = MagicMock()
    sdk_job.id = "glm-ft-newjob"
    client._client.fine_tuning.jobs.create.return_value = sdk_job

    returned_id = await client.submit_finetune(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
        hyperparams={"n_epochs": 2},
    )

    assert returned_id == "glm-ft-newjob"
    # The client-side names must map onto the SDK's keyword arguments exactly once.
    client._client.fine_tuning.jobs.create.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={"n_epochs": 2},
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_finetune_status_returns_correct_dict(client):
    """get_finetune_status should normalize the SDK job into a plain dict,
    coercing progress to int."""
    sdk_job = MagicMock()
    sdk_job.id = "glm-ft-abc"
    sdk_job.status = "running"
    sdk_job.progress = "75"  # SDK may return string; should be coerced to int
    sdk_job.error_message = None
    client._client.fine_tuning.jobs.retrieve.return_value = sdk_job

    status = await client.get_finetune_status("glm-ft-abc")

    assert status == {
        "job_id": "glm-ft-abc",
        "status": "running",
        "progress": 75,
        "error_message": None,
    }
|