/*
 * Page metadata captured from the repository web viewer (not part of the source code):
 *
 * Files
 * label_backend/src/main/java/com/label/controller/ExportController.java
 *
 * 128 lines
 * 5.5 KiB
 * Java
 * Raw Blame History
 *
 * This file contains ambiguous Unicode characters.
 *
 * This file contains Unicode characters that might be confused with other characters. If you
 * think that this is intentional, you can safely ignore this warning. Use the Escape button to
 * reveal them.
 */

package com.label.controller;
import com.label.annotation.RequireRole;
import com.label.common.auth.TokenPrincipal;
import com.label.common.result.PageResult;
import com.label.common.result.Result;
import com.label.dto.ExportBatchCreateRequest;
import com.label.dto.FinetuneJobResponse;
import com.label.entity.TrainingDataset;
import com.label.entity.ExportBatch;
import com.label.service.ExportService;
import com.label.service.FinetuneService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
/**
 * Training-data export and fine-tuning endpoints (5 endpoints, all restricted to the ADMIN role).
 *
 * <p>Every route below is mounted under the class-level {@code /label} prefix, so for example the
 * sample listing is served at {@code GET /label/api/training/samples}.
 */
@Tag(name = "导出管理", description = "训练样本查询、导出批次和微调任务")
@RestController
@RequestMapping("/label")
@RequiredArgsConstructor
public class ExportController {

    /**
     * Request attribute from which the authenticated principal is read. Presumably populated by an
     * upstream authentication interceptor/filter (not visible here) — TODO confirm.
     */
    private static final String TOKEN_PRINCIPAL_ATTRIBUTE = "__token_principal__";

    private final ExportService exportService;
    private final FinetuneService finetuneService;

    /**
     * GET /label/api/training/samples — pages through training samples that can be exported.
     *
     * @param page       1-based page number (default 1)
     * @param pageSize   page size (default 20)
     * @param sampleType optional sample-type filter (e.g. EXTRACTION, QA_GENERATION)
     * @param exported   optional filter on whether a sample has already been exported
     * @param request    current request, used to resolve the caller's principal
     * @return one page of samples, as produced by {@link ExportService}
     */
    @Operation(summary = "分页查询可导出训练样本")
    @GetMapping("/api/training/samples")
    @RequireRole("ADMIN")
    public Result<PageResult<TrainingDataset>> listSamples(
            @Parameter(description = "页码,从 1 开始", example = "1")
            @RequestParam(defaultValue = "1") int page,
            @Parameter(description = "每页条数", example = "20")
            @RequestParam(defaultValue = "20") int pageSize,
            @Parameter(description = "样本类型过滤可选值EXTRACTION、QA_GENERATION", example = "EXTRACTION")
            @RequestParam(required = false) String sampleType,
            @Parameter(description = "是否已导出过滤", example = "false")
            @RequestParam(required = false) Boolean exported,
            HttpServletRequest request) {
        return Result.success(exportService.listSamples(page, pageSize, sampleType, exported, principal(request)));
    }

    /**
     * POST /label/api/export/batch — creates an export batch from the given sample IDs.
     * Responds with HTTP 201 Created on success.
     *
     * @param body    request body carrying the sample IDs to include
     * @param request current request, used to resolve the caller's principal
     * @return the created batch, as produced by {@link ExportService}
     */
    @Operation(summary = "创建导出批次")
    @PostMapping("/api/export/batch")
    @RequireRole("ADMIN")
    @ResponseStatus(HttpStatus.CREATED)
    public Result<ExportBatch> createBatch(
            @io.swagger.v3.oas.annotations.parameters.RequestBody(
                    description = "创建训练数据导出批次请求体",
                    required = true)
            @RequestBody ExportBatchCreateRequest body,
            HttpServletRequest request) {
        return Result.success(exportService.createBatch(body.getSampleIds(), principal(request)));
    }

    /**
     * POST /label/api/export/{batchId}/finetune — submits a fine-tuning job for an export batch.
     *
     * @param batchId export batch ID
     * @param request current request, used to resolve the caller's principal
     * @return job information mapped from the service's result map
     */
    @Operation(summary = "提交微调任务")
    @PostMapping("/api/export/{batchId}/finetune")
    @RequireRole("ADMIN")
    public Result<FinetuneJobResponse> triggerFinetune(
            @Parameter(description = "导出批次 ID", example = "501")
            @PathVariable Long batchId,
            HttpServletRequest request) {
        return Result.success(toFinetuneJobResponse(finetuneService.trigger(batchId, principal(request))));
    }

    /**
     * GET /label/api/export/{batchId}/status — queries the fine-tuning status of an export batch.
     *
     * @param batchId export batch ID
     * @param request current request, used to resolve the caller's principal
     * @return job information mapped from the service's result map
     */
    @Operation(summary = "查询微调状态")
    @GetMapping("/api/export/{batchId}/status")
    @RequireRole("ADMIN")
    public Result<FinetuneJobResponse> getFinetuneStatus(
            @Parameter(description = "导出批次 ID", example = "501")
            @PathVariable Long batchId,
            HttpServletRequest request) {
        return Result.success(toFinetuneJobResponse(finetuneService.getStatus(batchId, principal(request))));
    }

    /**
     * GET /label/api/export/list — pages through export batches.
     *
     * @param page     1-based page number (default 1)
     * @param pageSize page size (default 20)
     * @param request  current request, used to resolve the caller's principal
     * @return one page of batches, as produced by {@link ExportService}
     */
    @Operation(summary = "分页查询导出批次")
    @GetMapping("/api/export/list")
    @RequireRole("ADMIN")
    public Result<PageResult<ExportBatch>> listBatches(
            @Parameter(description = "页码,从 1 开始", example = "1")
            @RequestParam(defaultValue = "1") int page,
            @Parameter(description = "每页条数", example = "20")
            @RequestParam(defaultValue = "20") int pageSize,
            HttpServletRequest request) {
        return Result.success(exportService.listBatches(page, pageSize, principal(request)));
    }

    /**
     * Maps the untyped result map returned by {@link FinetuneService} onto a response DTO.
     * Reads the keys {@code batchId}, {@code glmJobId}, {@code finetuneStatus}, {@code progress}
     * and {@code errorMessage}; a missing or null key leaves the corresponding field null.
     */
    private FinetuneJobResponse toFinetuneJobResponse(Map<String, Object> values) {
        FinetuneJobResponse response = new FinetuneJobResponse();
        response.setBatchId(asLong(values.get("batchId")));
        response.setGlmJobId(asString(values.get("glmJobId")));
        response.setFinetuneStatus(asString(values.get("finetuneStatus")));
        response.setProgress(asInteger(values.get("progress")));
        response.setErrorMessage(asString(values.get("errorMessage")));
        return response;
    }

    /**
     * Converts a loosely-typed value to a {@link Long}.
     *
     * <p>{@code null} and blank strings map to {@code null}. {@link Number}s are converted
     * numerically, so a {@code Double} such as {@code 501.0} no longer fails the way
     * {@code Long.parseLong("501.0")} did. Any fractional part is truncated. Other values are
     * parsed from their trimmed string form.
     *
     * @throws NumberFormatException if a non-numeric value's text is not a valid long
     */
    private Long asLong(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        String text = value.toString().trim();
        return text.isEmpty() ? null : Long.parseLong(text);
    }

    /**
     * Converts a loosely-typed value to an {@link Integer}.
     *
     * <p>{@code null} and blank strings map to {@code null}. {@link Number}s are converted
     * numerically, with any fractional part truncated. Other values are parsed from their trimmed
     * string form.
     *
     * @throws NumberFormatException if a non-numeric value's text is not a valid int
     * @throws ArithmeticException   if a numeric value lies outside the {@code int} range
     */
    private Integer asInteger(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            // toIntExact rejects out-of-range values instead of silently wrapping around.
            return Math.toIntExact(((Number) value).longValue());
        }
        String text = value.toString().trim();
        return text.isEmpty() ? null : Integer.parseInt(text);
    }

    /** Returns {@code value.toString()}, or {@code null} when {@code value} is null. */
    private String asString(Object value) {
        return value == null ? null : value.toString();
    }

    /**
     * Reads the authenticated principal from the request attributes.
     *
     * @return the principal, or {@code null} if the attribute is absent
     * @throws ClassCastException if the attribute holds something other than a {@link TokenPrincipal}
     */
    private TokenPrincipal principal(HttpServletRequest request) {
        return (TokenPrincipal) request.getAttribute(TOKEN_PRINCIPAL_ATTRIBUTE);
    }
}