from typing import Any

import httpx

from app.clients.llm.base import LLMClient
from app.core.exceptions import LLMCallError
from app.core.logging import get_logger
# Module-level logger obtained from the project's logging factory,
# keyed to this module's dotted import path.
logger = get_logger(__name__)
|
class QwenClient(LLMClient):
    """LLM client for Qwen (Alibaba DashScope).

    Chat completions go through the OpenAI-compatible endpoint rooted at
    ``base_url`` (``/chat/completions``); fine-tune management goes through
    ``fine_tune_base_url`` (``/fine-tunes``), which defaults to ``base_url``
    with ``/compatible-mode/v1`` swapped for ``/api/v1``.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1",
        fine_tune_base_url: str | None = None,
        transport: httpx.BaseTransport | None = None,
    ) -> None:
        """Store connection settings; no network activity happens here.

        Args:
            api_key: DashScope API key, sent as a Bearer token on every call.
            base_url: OpenAI-compatible endpoint root (trailing ``/`` stripped).
            fine_tune_base_url: Root for fine-tune endpoints. When omitted it is
                derived from ``base_url`` by replacing ``/compatible-mode/v1``
                with ``/api/v1``.
            transport: Optional httpx transport override, mainly for tests.
        """
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        # Fine-tune jobs live under the native DashScope API root rather than
        # the OpenAI-compatible one, so derive it when not supplied explicitly.
        self._fine_tune_base_url = (
            fine_tune_base_url.rstrip("/")
            if fine_tune_base_url
            else self._base_url.replace("/compatible-mode/v1", "/api/v1")
        )
        self._transport = transport

    async def chat(self, model: str, messages: list[dict]) -> str:
        """Run a text chat completion and return the reply content."""
        return await self._chat(model, messages)

    async def chat_vision(self, model: str, messages: list[dict]) -> str:
        """Run a vision chat completion; uses the same wire call as ``chat``."""
        return await self._chat(model, messages)

    async def submit_finetune(self, jsonl_url: str, base_model: str, hyperparams: dict) -> str:
        """Submit a fine-tune job and return its job id.

        Downloads the JSONL training file from ``jsonl_url``, uploads it to the
        files endpoint, then creates the fine-tune job against
        ``_fine_tune_base_url``.

        Raises:
            LLMCallError: on any failure, including a response missing a job id.
        """
        try:
            file_bytes = await self._download_training_file(jsonl_url)
            file_id = await self._upload_training_file(file_bytes)
            payload: dict[str, Any] = {
                "model": base_model,
                "training_file_ids": [file_id],
            }
            # Only send hyper parameters when the caller supplied a non-empty dict.
            if hyperparams:
                payload["hyper_parameters"] = hyperparams
            data = await self._post_json(self._fine_tune_base_url, "/fine-tunes", payload)
            # The job id may appear nested under "output" or at the top level.
            output = data.get("output", {})
            job_id = output.get("job_id") or data.get("job_id")
            if not job_id:
                raise LLMCallError("千问微调任务提交失败: 缺少 job_id")
            return job_id
        except LLMCallError:
            raise
        except Exception as exc:
            raise LLMCallError(f"千问微调任务提交失败: {exc}") from exc

    async def get_finetune_status(self, job_id: str) -> dict:
        """Fetch fine-tune job state, normalized to a flat dict.

        Returns:
            Dict with ``job_id``, lower-cased ``status``, ``progress`` and
            ``error_message`` keys; fields absent from the response are None
            (``status`` falls back to ``""``).

        Raises:
            LLMCallError: if the HTTP call or response parsing fails.
        """
        try:
            data = await self._get_json(self._fine_tune_base_url, f"/fine-tunes/{job_id}")
            output = data.get("output", {})
            return {
                "job_id": output.get("job_id") or job_id,
                "status": output.get("status", "").lower(),
                "progress": output.get("progress"),
                "error_message": output.get("message"),
            }
        except LLMCallError:
            raise
        except Exception as exc:
            raise LLMCallError(f"查询千问微调任务失败: {exc}") from exc

    async def _chat(self, model: str, messages: list[dict]) -> str:
        """POST a chat completion and return the first choice's content.

        Multimodal responses may carry content as a list of parts; those are
        flattened into a single string.

        Raises:
            LLMCallError: if the request fails or the response is malformed.
        """
        try:
            data = await self._post_json(
                self._base_url,
                "/chat/completions",
                {"model": model, "messages": messages},
            )
            content = data["choices"][0]["message"]["content"]
            if isinstance(content, list):
                # Multimodal replies arrive as parts; concatenate their text.
                content = "".join(
                    part.get("text", "") if isinstance(part, dict) else str(part)
                    for part in content
                )
            # BUGFIX: the list branch previously returned before this log line,
            # so multimodal calls were never recorded. Log on the single exit
            # path instead. (A non-str, non-list content still raises here and
            # is wrapped into LLMCallError below, as before.)
            logger.info("llm_call", extra={"model": model, "response_len": len(content)})
            return content
        except LLMCallError:
            raise
        except Exception as exc:
            logger.error("llm_call_error", extra={"model": model, "error": str(exc)})
            raise LLMCallError(f"千问大模型调用失败: {exc}") from exc

    async def _download_training_file(self, jsonl_url: str) -> bytes:
        """Download the JSONL training file and return its raw bytes.

        Raises httpx errors on failure; callers wrap them into LLMCallError.
        """
        async with self._build_client() as client:
            response = await client.get(jsonl_url)
            response.raise_for_status()
            return response.content

    async def _upload_training_file(self, file_bytes: bytes) -> str:
        """Upload training bytes to the files endpoint and return the file id.

        Raises:
            LLMCallError: if the response lacks an ``id`` field.
        """
        async with self._build_client(base_url=self._base_url) as client:
            response = await client.post(
                "/files",
                data={"purpose": "fine-tune"},
                files={"file": ("training.jsonl", file_bytes, "application/jsonl")},
            )
            response.raise_for_status()
            data = response.json()
            file_id = data.get("id")
            if not file_id:
                raise LLMCallError("千问训练文件上传失败: 缺少 file id")
            return file_id

    async def _post_json(self, base_url: str, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path`` and return the parsed body."""
        async with self._build_client(base_url=base_url) as client:
            response = await client.post(path, json=payload)
            response.raise_for_status()
            return response.json()

    async def _get_json(self, base_url: str, path: str) -> dict[str, Any]:
        """GET ``base_url + path`` and return the parsed JSON body."""
        async with self._build_client(base_url=base_url) as client:
            response = await client.get(path)
            response.raise_for_status()
            return response.json()

    def _build_client(self, base_url: str | None = None) -> httpx.AsyncClient:
        """Create a short-lived AsyncClient with auth header and 60s timeout.

        Args:
            base_url: Overrides ``self._base_url`` for this client instance.
        """
        return httpx.AsyncClient(
            base_url=base_url or self._base_url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
            },
            transport=self._transport,
            timeout=60,
        )
|