"""Unit tests for ZhipuAIClient, run against a patched-out ZhipuAI SDK class."""

import pytest
from unittest.mock import MagicMock, patch

from app.clients.llm.zhipuai_client import ZhipuAIClient
from app.core.exceptions import LLMCallError


@pytest.fixture
def mock_sdk_response():
    """Build a fake SDK completion whose first choice carries a JSON string."""
    completion = MagicMock()
    first_choice = completion.choices[0]
    first_choice.message.content = '{"result": "ok"}'
    return completion


@pytest.fixture
def client():
    """Provide a ZhipuAIClient constructed while the SDK class is patched."""
    sdk_patch = patch("app.clients.llm.zhipuai_client.ZhipuAI")
    with sdk_patch:
        return ZhipuAIClient(api_key="test-key")


@pytest.mark.asyncio
async def test_chat_returns_content(client, mock_sdk_response):
    """chat() should hand back the message content of the SDK response."""
    create_mock = client._client.chat.completions.create
    create_mock.return_value = mock_sdk_response
    messages = [{"role": "user", "content": "hello"}]

    content = await client.chat("glm-4-flash", messages)

    assert content == '{"result": "ok"}'


@pytest.mark.asyncio
async def test_chat_vision_returns_content(client, mock_sdk_response):
    """chat_vision() should hand back the message content of the SDK response."""
    create_mock = client._client.chat.completions.create
    create_mock.return_value = mock_sdk_response
    messages = [{"role": "user", "content": []}]

    content = await client.chat_vision("glm-4v-flash", messages)

    assert content == '{"result": "ok"}'


@pytest.mark.asyncio
async def test_llm_call_error_on_sdk_exception(client):
    """An exception raised by the SDK should surface as LLMCallError."""
    create_mock = client._client.chat.completions.create
    create_mock.side_effect = RuntimeError("quota exceeded")
    messages = [{"role": "user", "content": "hi"}]

    with pytest.raises(LLMCallError, match="大模型调用失败"):
        await client.chat("glm-4-flash", messages)


@pytest.mark.asyncio
async def test_submit_finetune_returns_job_id(client):
    """submit_finetune() should forward its arguments to the SDK and return the job id."""
    create_job = client._client.fine_tuning.jobs.create
    create_job.return_value = MagicMock(id="glm-ft-newjob")

    returned_id = await client.submit_finetune(
        jsonl_url="s3://bucket/train.jsonl",
        base_model="glm-4",
        hyperparams={"n_epochs": 2},
    )

    assert returned_id == "glm-ft-newjob"
    # The client-facing argument names differ from the SDK keyword names.
    create_job.assert_called_once_with(
        training_file="s3://bucket/train.jsonl",
        model="glm-4",
        hyperparameters={"n_epochs": 2},
    )


@pytest.mark.asyncio
async def test_get_finetune_status_returns_correct_dict(client):
    """get_finetune_status() should return a normalized dict with an integer progress."""
    sdk_job = MagicMock(
        id="glm-ft-abc",
        status="running",
        # The SDK may report progress as a string; the client is expected
        # to coerce it to int.
        progress="75",
        error_message=None,
    )
    client._client.fine_tuning.jobs.retrieve.return_value = sdk_job

    status = await client.get_finetune_status("glm-ft-abc")

    expected = {
        "job_id": "glm-ft-abc",
        "status": "running",
        "progress": 75,
        "error_message": None,
    }
    assert status == expected