Files
label_ai_service/tests/test_llm_client.py
wh 0880e1018c refactor: finetune through LLMClient interface + get_running_loop
- Add submit_finetune and get_finetune_status abstract methods to LLMClient base
- Implement both methods in ZhipuAIClient using asyncio.get_running_loop()
- Rewrite finetune_service to call llm.submit_finetune / llm.get_finetune_status
  instead of accessing llm._client directly, restoring interface encapsulation
- Replace asyncio.get_event_loop() with get_running_loop() in ZhipuAIClient._call
  and all four methods in RustFSClient (deprecated in Python 3.10+)
- Update test_finetune_service to mock the LLMClient interface methods as AsyncMocks
- Add two new tests in test_llm_client for submit_finetune and get_finetune_status
2026-04-10 16:43:28 +08:00

82 lines
2.6 KiB
Python

import pytest
from unittest.mock import MagicMock, patch
from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.exceptions import LLMCallError
@pytest.fixture
def mock_sdk_response():
    """Fake ZhipuAI chat response whose first choice carries a JSON payload."""
    fake = MagicMock()
    # MagicMock auto-creates the choices[0].message chain on first access.
    fake.choices[0].message.content = '{"result": "ok"}'
    return fake
@pytest.fixture
def client():
    """ZhipuAIClient built with the real SDK class patched out.

    The patch only needs to cover construction; afterwards the tests
    replace ``client._client`` attributes directly.
    """
    with patch("app.clients.llm.zhipuai_client.ZhipuAI"):
        return ZhipuAIClient(api_key="test-key")
@pytest.mark.asyncio
async def test_chat_returns_content(client, mock_sdk_response):
    """chat() should unwrap and return the first choice's message content."""
    create = client._client.chat.completions.create
    create.return_value = mock_sdk_response
    messages = [{"role": "user", "content": "hello"}]
    reply = await client.chat("glm-4-flash", messages)
    assert reply == '{"result": "ok"}'
@pytest.mark.asyncio
async def test_chat_vision_returns_content(client, mock_sdk_response):
    """chat_vision() should return the message content, same as chat()."""
    create = client._client.chat.completions.create
    create.return_value = mock_sdk_response
    messages = [{"role": "user", "content": []}]
    reply = await client.chat_vision("glm-4v-flash", messages)
    assert reply == '{"result": "ok"}'
@pytest.mark.asyncio
async def test_llm_call_error_on_sdk_exception(client):
    """A raw SDK failure must surface as the domain-level LLMCallError."""
    sdk_error = RuntimeError("quota exceeded")
    client._client.chat.completions.create.side_effect = sdk_error
    with pytest.raises(LLMCallError, match="大模型调用失败"):
        await client.chat("glm-4-flash", [{"role": "user", "content": "hi"}])
@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id(client):
    """submit_finetune should call the SDK and return the job id."""
    sdk_job = MagicMock()
    sdk_job.id = "glm-ft-newjob"
    create = client._client.fine_tuning.jobs.create
    create.return_value = sdk_job

    returned = await client.submit_finetune(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
        hyperparams={"n_epochs": 2},
    )

    assert returned == "glm-ft-newjob"
    # Interface args must be translated to the SDK's keyword names.
    create.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={"n_epochs": 2},
    )
@pytest.mark.asyncio
async def test_get_finetune_status_returns_correct_dict(client):
    """get_finetune_status should return a normalized dict with progress coerced to int."""
    sdk_job = MagicMock()
    sdk_job.id = "glm-ft-abc"
    sdk_job.status = "running"
    sdk_job.progress = "75"  # SDK may return string; should be coerced to int
    sdk_job.error_message = None
    client._client.fine_tuning.jobs.retrieve.return_value = sdk_job

    status = await client.get_finetune_status("glm-ft-abc")

    expected = {
        "job_id": "glm-ft-abc",
        "status": "running",
        "progress": 75,
        "error_message": None,
    }
    assert status == expected