62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
|
|
import asyncio
|
||
|
|
|
||
|
|
from app.core.exceptions import LLMCallError
|
||
|
|
from app.core.logging import get_logger
|
||
|
|
from app.models.finetune_models import (
|
||
|
|
FinetuneStartRequest,
|
||
|
|
FinetuneStartResponse,
|
||
|
|
FinetuneStatusResponse,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
||
|
|
|
||
|
|
# Maps the raw job status string returned by the fine-tuning API to the
# status value reported in FinetuneStatusResponse. Raw statuses that are not
# listed here are handled by the fallback in get_finetune_status.
_STATUS_MAP = {
    "running": "RUNNING",
    "succeeded": "SUCCESS",
    "failed": "FAILED",
}
|
||
|
|
|
||
|
|
|
||
|
|
async def submit_finetune(req: FinetuneStartRequest, llm) -> FinetuneStartResponse:
    """Submit a fine-tune job to ZhipuAI and return the job ID.

    The client call is dispatched to the event loop's default executor so the
    (presumably blocking) SDK request does not stall the loop.

    Args:
        req: Request carrying the training file URL (``jsonl_url``), the base
            model name (``base_model``) and optional hyperparameters
            (``hyperparams``); a falsy ``hyperparams`` is sent as ``{}``.
        llm: Object exposing the underlying SDK client as ``_client``.

    Returns:
        A ``FinetuneStartResponse`` holding the ``id`` of the created job.

    Raises:
        LLMCallError: If anything in the submission path raises; the original
            exception is chained as the cause.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred way to obtain it and never creates a loop implicitly.
    loop = asyncio.get_running_loop()

    def _create_job():
        # Runs in a worker thread of the default executor.
        return llm._client.fine_tuning.jobs.create(
            training_file=req.jsonl_url,
            model=req.base_model,
            hyperparameters=req.hyperparams or {},
        )

    try:
        response = await loop.run_in_executor(None, _create_job)
        job_id = response.id
        logger.info("finetune_submit", extra={"job_id": job_id, "model": req.base_model})
        return FinetuneStartResponse(job_id=job_id)
    except Exception as exc:
        # Service boundary: log, then surface every failure as LLMCallError.
        logger.error("finetune_submit_error", extra={"error": str(exc)})
        raise LLMCallError(f"微调任务提交失败: {exc}") from exc
|
||
|
|
|
||
|
|
|
||
|
|
async def get_finetune_status(job_id: str, llm) -> FinetuneStatusResponse:
    """Retrieve fine-tune job status from ZhipuAI.

    The client call is dispatched to the event loop's default executor so the
    (presumably blocking) SDK request does not stall the loop.

    Args:
        job_id: Identifier of the fine-tune job to look up.
        llm: Object exposing the underlying SDK client as ``_client``.

    Returns:
        A ``FinetuneStatusResponse`` with the mapped status plus the
        ``progress`` and ``error_message`` attributes of the SDK response
        (``None`` when the response does not carry them).

    Raises:
        LLMCallError: If anything in the lookup path raises; the original
            exception is chained as the cause.
    """
    # Inside a coroutine a loop is always running; get_running_loop() is the
    # preferred way to obtain it and never creates a loop implicitly.
    loop = asyncio.get_running_loop()

    def _retrieve_job():
        # Runs in a worker thread of the default executor.
        return llm._client.fine_tuning.jobs.retrieve(job_id)

    try:
        response = await loop.run_in_executor(None, _retrieve_job)
        status_raw = response.status
        # Conservative fallback: any raw status missing from _STATUS_MAP is
        # reported as "RUNNING".
        # NOTE(review): a terminal status that is absent from the map would
        # also be reported as RUNNING -- confirm against the provider's
        # status list.
        status = _STATUS_MAP.get(status_raw, "RUNNING")
        # These attributes are optional on the SDK response; default to None.
        progress = getattr(response, "progress", None)
        error_message = getattr(response, "error_message", None)
        logger.info("finetune_status", extra={"job_id": job_id, "status": status})
        return FinetuneStatusResponse(
            job_id=job_id,
            status=status,
            progress=progress,
            error_message=error_message,
        )
    except Exception as exc:
        # Service boundary: log, then surface every failure as LLMCallError.
        logger.error("finetune_status_error", extra={"job_id": job_id, "error": str(exc)})
        raise LLMCallError(f"微调状态查询失败: {exc}") from exc
|