- Add submit_finetune and get_finetune_status abstract methods to LLMClient base - Implement both methods in ZhipuAIClient using asyncio.get_running_loop() - Rewrite finetune_service to call llm.submit_finetune / llm.get_finetune_status instead of accessing llm._client directly, restoring interface encapsulation - Replace asyncio.get_event_loop() with get_running_loop() in ZhipuAIClient._call and all four methods in RustFSClient (deprecated in Python 3.10+) - Update test_finetune_service to mock the LLMClient interface methods as AsyncMocks - Add two new tests in test_llm_client for submit_finetune and get_finetune_status
69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
import asyncio
|
|
|
|
from zhipuai import ZhipuAI
|
|
|
|
from app.clients.llm.base import LLMClient
|
|
from app.core.exceptions import LLMCallError
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class ZhipuAIClient(LLMClient):
    """``LLMClient`` implementation backed by the synchronous ``zhipuai`` SDK.

    Each SDK call is blocking, so it is submitted to the running event loop's
    default executor instead of being invoked directly inside the coroutine.
    Every failure is logged and re-raised as ``LLMCallError`` with the
    original exception chained as its cause.
    """

    def __init__(self, api_key: str) -> None:
        """Create the underlying SDK client.

        Args:
            api_key: Passed unchanged to ``ZhipuAI(api_key=...)``.
        """
        self._client = ZhipuAI(api_key=api_key)

    async def chat(self, model: str, messages: list[dict]) -> str:
        """Send a chat completion request and return the reply text.

        Args:
            model: Model identifier forwarded to the SDK.
            messages: Message dicts forwarded unchanged to the SDK.

        Returns:
            The ``content`` of the first choice in the response.

        Raises:
            LLMCallError: If the SDK call fails or the response has no text.
        """
        return await self._call(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Send a vision chat request; uses the same code path as ``chat``.

        Args:
            model: Model identifier forwarded to the SDK.
            messages: Message dicts forwarded unchanged to the SDK.

        Returns:
            The ``content`` of the first choice in the response.

        Raises:
            LLMCallError: If the SDK call fails or the response has no text.
        """
        return await self._call(model, messages)

    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        """Create a fine-tuning job and return its id.

        Args:
            jsonl_url: Forwarded to the SDK as ``training_file``.
            base_model: Forwarded to the SDK as ``model``.
            hyperparams: Forwarded to the SDK as ``hyperparameters``.

        Returns:
            The ``id`` attribute of the job object returned by the SDK.

        Raises:
            LLMCallError: If job creation fails for any reason.
        """
        try:
            job = await self._run_blocking(
                lambda: self._client.fine_tuning.jobs.create(
                    training_file=jsonl_url,
                    model=base_model,
                    hyperparameters=hyperparams,
                ),
            )
            return job.id
        except Exception as exc:
            # Logged for parity with _call, which also logs before wrapping.
            logger.error(
                "finetune_submit_error",
                extra={"base_model": base_model, "error": str(exc)},
            )
            raise LLMCallError(f"微调任务提交失败: {exc}") from exc

    async def get_finetune_status(self, job_id: str) -> dict:
        """Fetch the current state of a fine-tuning job.

        Args:
            job_id: Identifier passed to ``fine_tuning.jobs.retrieve``.

        Returns:
            A dict with the keys ``job_id``, ``status``, ``progress`` and
            ``error_message``. ``progress`` is converted with ``int()`` when
            the SDK object carries a non-``None`` ``progress`` attribute and
            is ``None`` otherwise; ``error_message`` is ``None`` when the SDK
            object has no such attribute.

        Raises:
            LLMCallError: If the lookup or the conversion of its result fails.
        """
        try:
            job = await self._run_blocking(
                lambda: self._client.fine_tuning.jobs.retrieve(job_id),
            )
            # Both attributes are read defensively: the code does not assume
            # the SDK object always exposes them.
            progress = getattr(job, "progress", None)
            return {
                "job_id": job.id,
                "status": job.status,
                "progress": int(progress) if progress is not None else None,
                "error_message": getattr(job, "error_message", None),
            }
        except Exception as exc:
            # Logged for parity with _call, which also logs before wrapping.
            logger.error(
                "finetune_status_error",
                extra={"job_id": job_id, "error": str(exc)},
            )
            raise LLMCallError(f"查询微调任务失败: {exc}") from exc

    async def _run_blocking(self, func):
        """Run the zero-argument callable ``func`` in the loop's default executor.

        Returns whatever ``func`` returns; exceptions raised by ``func``
        propagate to the awaiting caller.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    @staticmethod
    def _extract_content(response) -> str:
        """Return the text of the first choice in a chat completion response.

        Raises:
            ValueError: If the response has no choices or the first choice
                carries no text. Checked explicitly so the caller reports a
                meaningful reason instead of an incidental ``IndexError`` or
                ``TypeError``.
        """
        choices = response.choices
        if not choices:
            raise ValueError("响应中没有可用的 choices")
        content = choices[0].message.content
        if content is None:
            raise ValueError("响应内容为空")
        return content

    async def _call(self, model: str, messages: list[dict]) -> str:
        """Shared implementation behind ``chat`` and ``chat_vision``.

        Args:
            model: Model identifier forwarded to the SDK.
            messages: Message dicts forwarded unchanged to the SDK.

        Returns:
            The ``content`` of the first choice in the response.

        Raises:
            LLMCallError: If the SDK call fails, or the response contains no
                choices or no text content.
        """
        try:
            response = await self._run_blocking(
                lambda: self._client.chat.completions.create(
                    model=model,
                    messages=messages,
                ),
            )
            content = self._extract_content(response)
            logger.info(
                "llm_call",
                extra={"model": model, "response_len": len(content)},
            )
            return content
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"大模型调用失败: {exc}") from exc
|