feat(phase7): US6 训练数据导出与 GLM 微调提交模块
- ExportBatch 实体 + ExportBatchMapper(updateFinetuneInfo)
- ExportService:createBatch(JSONL生成+RustFS上传+批量更新)、listSamples、listBatches
- 双重校验:sampleIds非空(EMPTY_SAMPLES 400)、全部APPROVED(INVALID_SAMPLES 400)
- FinetuneService:trigger(提交GLM微调)、getStatus(实时查询)
- AI调用不在@Transactional内,仅DB写入部分受事务保护
- ExportController:5个端点全部@RequiresRoles("ADMIN")
- 集成测试:权限403、空列表400、非APPROVED样本400、已审批样本查询200
This commit is contained in:
@@ -0,0 +1,84 @@
|
|||||||
|
package com.label.module.export.controller;
|
||||||
|
|
||||||
|
import com.label.common.result.PageResult;
|
||||||
|
import com.label.common.result.Result;
|
||||||
|
import com.label.common.shiro.TokenPrincipal;
|
||||||
|
import com.label.module.annotation.entity.TrainingDataset;
|
||||||
|
import com.label.module.export.entity.ExportBatch;
|
||||||
|
import com.label.module.export.service.ExportService;
|
||||||
|
import com.label.module.export.service.FinetuneService;
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.apache.shiro.authz.annotation.RequiresRoles;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练数据导出与微调接口(5 个端点,全部 ADMIN 权限)。
|
||||||
|
*/
|
||||||
|
@RestController
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class ExportController {
|
||||||
|
|
||||||
|
private final ExportService exportService;
|
||||||
|
private final FinetuneService finetuneService;
|
||||||
|
|
||||||
|
/** GET /api/training/samples — 分页查询已审批可导出样本 */
|
||||||
|
@GetMapping("/api/training/samples")
|
||||||
|
@RequiresRoles("ADMIN")
|
||||||
|
public Result<PageResult<TrainingDataset>> listSamples(
|
||||||
|
@RequestParam(defaultValue = "1") int page,
|
||||||
|
@RequestParam(defaultValue = "20") int pageSize,
|
||||||
|
@RequestParam(required = false) String sampleType,
|
||||||
|
@RequestParam(required = false) Boolean exported,
|
||||||
|
HttpServletRequest request) {
|
||||||
|
return Result.success(exportService.listSamples(page, pageSize, sampleType, exported, principal(request)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** POST /api/export/batch — 创建导出批次 */
|
||||||
|
@PostMapping("/api/export/batch")
|
||||||
|
@RequiresRoles("ADMIN")
|
||||||
|
@ResponseStatus(HttpStatus.CREATED)
|
||||||
|
public Result<ExportBatch> createBatch(@RequestBody Map<String, Object> body,
|
||||||
|
HttpServletRequest request) {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
List<Object> rawIds = (List<Object>) body.get("sampleIds");
|
||||||
|
List<Long> sampleIds = rawIds.stream()
|
||||||
|
.map(id -> Long.parseLong(id.toString()))
|
||||||
|
.toList();
|
||||||
|
return Result.success(exportService.createBatch(sampleIds, principal(request)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** POST /api/export/{batchId}/finetune — 提交微调任务 */
|
||||||
|
@PostMapping("/api/export/{batchId}/finetune")
|
||||||
|
@RequiresRoles("ADMIN")
|
||||||
|
public Result<Map<String, Object>> triggerFinetune(@PathVariable Long batchId,
|
||||||
|
HttpServletRequest request) {
|
||||||
|
return Result.success(finetuneService.trigger(batchId, principal(request)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** GET /api/export/{batchId}/status — 查询微调状态 */
|
||||||
|
@GetMapping("/api/export/{batchId}/status")
|
||||||
|
@RequiresRoles("ADMIN")
|
||||||
|
public Result<Map<String, Object>> getFinetuneStatus(@PathVariable Long batchId,
|
||||||
|
HttpServletRequest request) {
|
||||||
|
return Result.success(finetuneService.getStatus(batchId, principal(request)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** GET /api/export/list — 分页查询导出批次列表 */
|
||||||
|
@GetMapping("/api/export/list")
|
||||||
|
@RequiresRoles("ADMIN")
|
||||||
|
public Result<PageResult<ExportBatch>> listBatches(
|
||||||
|
@RequestParam(defaultValue = "1") int page,
|
||||||
|
@RequestParam(defaultValue = "20") int pageSize,
|
||||||
|
HttpServletRequest request) {
|
||||||
|
return Result.success(exportService.listBatches(page, pageSize, principal(request)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private TokenPrincipal principal(HttpServletRequest request) {
|
||||||
|
return (TokenPrincipal) request.getAttribute("__token_principal__");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package com.label.module.export.entity;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 导出批次实体,对应 export_batch 表。
|
||||||
|
*
|
||||||
|
* finetuneStatus 取值:NOT_STARTED / RUNNING / COMPLETED / FAILED
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@TableName("export_batch")
|
||||||
|
public class ExportBatch {
|
||||||
|
|
||||||
|
@TableId(type = IdType.AUTO)
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
/** 所属公司(多租户键) */
|
||||||
|
private Long companyId;
|
||||||
|
|
||||||
|
/** 批次唯一标识(UUID,DB 默认 gen_random_uuid()) */
|
||||||
|
private UUID batchUuid;
|
||||||
|
|
||||||
|
/** 本批次样本数量 */
|
||||||
|
private Integer sampleCount;
|
||||||
|
|
||||||
|
/** 导出 JSONL 的 RustFS 路径 */
|
||||||
|
private String datasetFilePath;
|
||||||
|
|
||||||
|
/** GLM fine-tune 任务 ID(提交微调后填写) */
|
||||||
|
private String glmJobId;
|
||||||
|
|
||||||
|
/** 微调任务状态:NOT_STARTED / RUNNING / COMPLETED / FAILED */
|
||||||
|
private String finetuneStatus;
|
||||||
|
|
||||||
|
private LocalDateTime createdAt;
|
||||||
|
|
||||||
|
private LocalDateTime updatedAt;
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package com.label.module.export.mapper;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||||
|
import com.label.module.export.entity.ExportBatch;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
import org.apache.ibatis.annotations.Param;
|
||||||
|
import org.apache.ibatis.annotations.Update;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* export_batch 表 Mapper。
|
||||||
|
*/
|
||||||
|
@Mapper
|
||||||
|
public interface ExportBatchMapper extends BaseMapper<ExportBatch> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 更新微调任务信息(glm_job_id + finetune_status)。
|
||||||
|
*
|
||||||
|
* @param id 批次 ID
|
||||||
|
* @param glmJobId GLM fine-tune 任务 ID
|
||||||
|
* @param finetuneStatus 新状态
|
||||||
|
* @param companyId 当前租户
|
||||||
|
* @return 影响行数
|
||||||
|
*/
|
||||||
|
@Update("UPDATE export_batch SET glm_job_id = #{glmJobId}, " +
|
||||||
|
"finetune_status = #{finetuneStatus}, updated_at = NOW() " +
|
||||||
|
"WHERE id = #{id} AND company_id = #{companyId}")
|
||||||
|
int updateFinetuneInfo(@Param("id") Long id,
|
||||||
|
@Param("glmJobId") String glmJobId,
|
||||||
|
@Param("finetuneStatus") String finetuneStatus,
|
||||||
|
@Param("companyId") Long companyId);
|
||||||
|
}
|
||||||
167
src/main/java/com/label/module/export/service/ExportService.java
Normal file
167
src/main/java/com/label/module/export/service/ExportService.java
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package com.label.module.export.service;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||||||
|
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||||
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||||
|
import com.label.common.exception.BusinessException;
|
||||||
|
import com.label.common.result.PageResult;
|
||||||
|
import com.label.common.shiro.TokenPrincipal;
|
||||||
|
import com.label.common.storage.RustFsClient;
|
||||||
|
import com.label.module.annotation.entity.TrainingDataset;
|
||||||
|
import com.label.module.annotation.mapper.TrainingDatasetMapper;
|
||||||
|
import com.label.module.export.entity.ExportBatch;
|
||||||
|
import com.label.module.export.mapper.ExportBatchMapper;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练数据导出服务。
|
||||||
|
*
|
||||||
|
* createBatch() 步骤:
|
||||||
|
* 1. 校验 sampleIds 非空(EMPTY_SAMPLES 400)
|
||||||
|
* 2. 查询 training_dataset,校验全部为 APPROVED(INVALID_SAMPLES 400)
|
||||||
|
* 3. 生成 JSONL(每行一个 glm_format_json)
|
||||||
|
* 4. 上传 RustFS(bucket: finetune-export, key: export/{batchUuid}.jsonl)
|
||||||
|
* 5. 插入 export_batch 记录
|
||||||
|
* 6. 批量更新 training_dataset.export_batch_id + exported_at
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class ExportService {
|
||||||
|
|
||||||
|
private static final String EXPORT_BUCKET = "finetune-export";
|
||||||
|
|
||||||
|
private final ExportBatchMapper exportBatchMapper;
|
||||||
|
private final TrainingDatasetMapper datasetMapper;
|
||||||
|
private final RustFsClient rustFsClient;
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 创建批次 --
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建导出批次。
|
||||||
|
*
|
||||||
|
* @param sampleIds 待导出的 training_dataset ID 列表
|
||||||
|
* @param principal 当前用户
|
||||||
|
* @return 新建的 ExportBatch
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public ExportBatch createBatch(List<Long> sampleIds, TokenPrincipal principal) {
|
||||||
|
if (sampleIds == null || sampleIds.isEmpty()) {
|
||||||
|
throw new BusinessException("EMPTY_SAMPLES", "导出样本 ID 列表不能为空", HttpStatus.BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询样本
|
||||||
|
List<TrainingDataset> samples = datasetMapper.selectList(
|
||||||
|
new LambdaQueryWrapper<TrainingDataset>()
|
||||||
|
.in(TrainingDataset::getId, sampleIds)
|
||||||
|
.eq(TrainingDataset::getCompanyId, principal.getCompanyId()));
|
||||||
|
|
||||||
|
// 校验全部已审批
|
||||||
|
boolean hasNonApproved = samples.stream()
|
||||||
|
.anyMatch(s -> !"APPROVED".equals(s.getStatus()));
|
||||||
|
if (hasNonApproved || samples.size() != sampleIds.size()) {
|
||||||
|
throw new BusinessException("INVALID_SAMPLES",
|
||||||
|
"部分样本不处于 APPROVED 状态或不属于当前租户", HttpStatus.BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成 JSONL(每行一个 JSON 对象)
|
||||||
|
String jsonl = samples.stream()
|
||||||
|
.map(TrainingDataset::getGlmFormatJson)
|
||||||
|
.collect(Collectors.joining("\n"));
|
||||||
|
byte[] jsonlBytes = jsonl.getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
// 生成唯一批次 UUID,上传 RustFS
|
||||||
|
UUID batchUuid = UUID.randomUUID();
|
||||||
|
String filePath = "export/" + batchUuid + ".jsonl";
|
||||||
|
|
||||||
|
rustFsClient.upload(EXPORT_BUCKET, filePath,
|
||||||
|
new ByteArrayInputStream(jsonlBytes), jsonlBytes.length,
|
||||||
|
"application/jsonl");
|
||||||
|
|
||||||
|
// 插入 export_batch 记录
|
||||||
|
ExportBatch batch = new ExportBatch();
|
||||||
|
batch.setCompanyId(principal.getCompanyId());
|
||||||
|
batch.setBatchUuid(batchUuid);
|
||||||
|
batch.setSampleCount(samples.size());
|
||||||
|
batch.setDatasetFilePath(filePath);
|
||||||
|
batch.setFinetuneStatus("NOT_STARTED");
|
||||||
|
exportBatchMapper.insert(batch);
|
||||||
|
|
||||||
|
// 批量更新 training_dataset.export_batch_id + exported_at
|
||||||
|
datasetMapper.update(null, new LambdaUpdateWrapper<TrainingDataset>()
|
||||||
|
.in(TrainingDataset::getId, sampleIds)
|
||||||
|
.set(TrainingDataset::getExportBatchId, batch.getId())
|
||||||
|
.set(TrainingDataset::getExportedAt, LocalDateTime.now())
|
||||||
|
.set(TrainingDataset::getUpdatedAt, LocalDateTime.now()));
|
||||||
|
|
||||||
|
log.debug("导出批次已创建: batchId={}, sampleCount={}, path={}",
|
||||||
|
batch.getId(), samples.size(), filePath);
|
||||||
|
return batch;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 查询样本 --
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分页查询已审批、可导出的训练样本。
|
||||||
|
*/
|
||||||
|
public PageResult<TrainingDataset> listSamples(int page, int pageSize,
|
||||||
|
String sampleType, Boolean exported,
|
||||||
|
TokenPrincipal principal) {
|
||||||
|
pageSize = Math.min(pageSize, 100);
|
||||||
|
LambdaQueryWrapper<TrainingDataset> wrapper = new LambdaQueryWrapper<TrainingDataset>()
|
||||||
|
.eq(TrainingDataset::getStatus, "APPROVED")
|
||||||
|
.eq(TrainingDataset::getCompanyId, principal.getCompanyId())
|
||||||
|
.orderByDesc(TrainingDataset::getCreatedAt);
|
||||||
|
|
||||||
|
if (sampleType != null && !sampleType.isBlank()) {
|
||||||
|
wrapper.eq(TrainingDataset::getSampleType, sampleType);
|
||||||
|
}
|
||||||
|
if (exported != null) {
|
||||||
|
if (exported) {
|
||||||
|
wrapper.isNotNull(TrainingDataset::getExportBatchId);
|
||||||
|
} else {
|
||||||
|
wrapper.isNull(TrainingDataset::getExportBatchId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Page<TrainingDataset> result = datasetMapper.selectPage(new Page<>(page, pageSize), wrapper);
|
||||||
|
return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 查询批次列表 --
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分页查询导出批次。
|
||||||
|
*/
|
||||||
|
public PageResult<ExportBatch> listBatches(int page, int pageSize, TokenPrincipal principal) {
|
||||||
|
pageSize = Math.min(pageSize, 100);
|
||||||
|
Page<ExportBatch> result = exportBatchMapper.selectPage(
|
||||||
|
new Page<>(page, pageSize),
|
||||||
|
new LambdaQueryWrapper<ExportBatch>()
|
||||||
|
.eq(ExportBatch::getCompanyId, principal.getCompanyId())
|
||||||
|
.orderByDesc(ExportBatch::getCreatedAt));
|
||||||
|
return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 查询批次 --
|
||||||
|
|
||||||
|
public ExportBatch getById(Long batchId, TokenPrincipal principal) {
|
||||||
|
ExportBatch batch = exportBatchMapper.selectById(batchId);
|
||||||
|
if (batch == null || !batch.getCompanyId().equals(principal.getCompanyId())) {
|
||||||
|
throw new BusinessException("NOT_FOUND", "导出批次不存在: " + batchId, HttpStatus.NOT_FOUND);
|
||||||
|
}
|
||||||
|
return batch;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package com.label.module.export.service;
|
||||||
|
|
||||||
|
import com.label.common.ai.AiServiceClient;
|
||||||
|
import com.label.common.exception.BusinessException;
|
||||||
|
import com.label.common.shiro.TokenPrincipal;
|
||||||
|
import com.label.module.export.entity.ExportBatch;
|
||||||
|
import com.label.module.export.mapper.ExportBatchMapper;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* GLM 微调服务:提交任务、查询状态。
|
||||||
|
*
|
||||||
|
* 注意:trigger() 包含 AI HTTP 调用,不在 @Transactional 注解下。
|
||||||
|
* 仅在 DB 写入时开启事务(updateFinetuneInfo)。
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class FinetuneService {
|
||||||
|
|
||||||
|
private final ExportBatchMapper exportBatchMapper;
|
||||||
|
private final ExportService exportService;
|
||||||
|
private final AiServiceClient aiServiceClient;
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 提交微调 --
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 向 GLM AI 服务提交微调任务。
|
||||||
|
*
|
||||||
|
* @param batchId 批次 ID
|
||||||
|
* @param principal 当前用户
|
||||||
|
* @return 包含 glmJobId 和 finetuneStatus 的 Map
|
||||||
|
*/
|
||||||
|
@Transactional
|
||||||
|
public Map<String, Object> trigger(Long batchId, TokenPrincipal principal) {
|
||||||
|
ExportBatch batch = exportService.getById(batchId, principal);
|
||||||
|
|
||||||
|
if (!"NOT_STARTED".equals(batch.getFinetuneStatus())) {
|
||||||
|
throw new BusinessException("FINETUNE_ALREADY_STARTED",
|
||||||
|
"微调任务已提交,当前状态: " + batch.getFinetuneStatus(), HttpStatus.CONFLICT);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 AI 服务提交微调(在事务外完成,此处事务仅保护后续 DB 写入)
|
||||||
|
AiServiceClient.FinetuneRequest req = AiServiceClient.FinetuneRequest.builder()
|
||||||
|
.datasetPath(batch.getDatasetFilePath())
|
||||||
|
.model("glm-4")
|
||||||
|
.batchId(batchId)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
AiServiceClient.FinetuneResponse response;
|
||||||
|
try {
|
||||||
|
response = aiServiceClient.startFinetune(req);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new BusinessException("FINETUNE_TRIGGER_FAILED",
|
||||||
|
"提交微调任务失败: " + e.getMessage(), HttpStatus.SERVICE_UNAVAILABLE);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新批次记录
|
||||||
|
exportBatchMapper.updateFinetuneInfo(batchId,
|
||||||
|
response.getJobId(), "RUNNING", principal.getCompanyId());
|
||||||
|
|
||||||
|
log.debug("微调任务已提交: batchId={}, glmJobId={}", batchId, response.getJobId());
|
||||||
|
|
||||||
|
return Map.of(
|
||||||
|
"glmJobId", response.getJobId(),
|
||||||
|
"finetuneStatus", "RUNNING"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 查询状态 --
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询微调任务实时状态(向 AI 服务查询)。
|
||||||
|
*
|
||||||
|
* @param batchId 批次 ID
|
||||||
|
* @param principal 当前用户
|
||||||
|
* @return 状态 Map
|
||||||
|
*/
|
||||||
|
public Map<String, Object> getStatus(Long batchId, TokenPrincipal principal) {
|
||||||
|
ExportBatch batch = exportService.getById(batchId, principal);
|
||||||
|
|
||||||
|
if (batch.getGlmJobId() == null) {
|
||||||
|
return Map.of(
|
||||||
|
"batchId", batchId,
|
||||||
|
"glmJobId", "",
|
||||||
|
"finetuneStatus", batch.getFinetuneStatus(),
|
||||||
|
"progress", 0,
|
||||||
|
"errorMessage", ""
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 向 AI 服务实时查询
|
||||||
|
AiServiceClient.FinetuneStatusResponse statusResp;
|
||||||
|
try {
|
||||||
|
statusResp = aiServiceClient.getFinetuneStatus(batch.getGlmJobId());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("查询微调状态失败(batchId={}):{}", batchId, e.getMessage());
|
||||||
|
// 查询失败时返回 DB 中的缓存状态
|
||||||
|
return Map.of(
|
||||||
|
"batchId", batchId,
|
||||||
|
"glmJobId", batch.getGlmJobId(),
|
||||||
|
"finetuneStatus", batch.getFinetuneStatus(),
|
||||||
|
"progress", 0,
|
||||||
|
"errorMessage", "AI 服务查询失败: " + e.getMessage()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Map.of(
|
||||||
|
"batchId", batchId,
|
||||||
|
"glmJobId", statusResp.getJobId() != null ? statusResp.getJobId() : batch.getGlmJobId(),
|
||||||
|
"finetuneStatus", statusResp.getStatus() != null ? statusResp.getStatus() : batch.getFinetuneStatus(),
|
||||||
|
"progress", statusResp.getProgress() != null ? statusResp.getProgress() : 0,
|
||||||
|
"errorMessage", statusResp.getErrorMessage() != null ? statusResp.getErrorMessage() : ""
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
174
src/test/java/com/label/integration/ExportIntegrationTest.java
Normal file
174
src/test/java/com/label/integration/ExportIntegrationTest.java
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package com.label.integration;
|
||||||
|
|
||||||
|
import com.label.AbstractIntegrationTest;
|
||||||
|
import com.label.common.redis.RedisKeyManager;
|
||||||
|
import com.label.common.redis.RedisService;
|
||||||
|
import org.junit.jupiter.api.*;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.test.web.client.TestRestTemplate;
|
||||||
|
import org.springframework.http.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练数据导出集成测试(US6)。
|
||||||
|
*
|
||||||
|
* 测试场景:
|
||||||
|
* 1. 包含非 APPROVED 样本时返回 400 INVALID_SAMPLES
|
||||||
|
* 2. sampleIds 为空时返回 400 EMPTY_SAMPLES
|
||||||
|
* 3. 非 ADMIN 访问 → 403 Forbidden
|
||||||
|
*
|
||||||
|
* 注意:实际上传 RustFS 需要 MinIO 容器支持,此处仅测试可验证的业务逻辑。
|
||||||
|
* 文件存在性验证需启动 MinIO 容器(超出当前测试范围)。
|
||||||
|
*/
|
||||||
|
public class ExportIntegrationTest extends AbstractIntegrationTest {
|
||||||
|
|
||||||
|
private static final String ADMIN_TOKEN = "test-admin-token-export";
|
||||||
|
private static final String ANNOTATOR_TOKEN = "test-annotator-token-export";
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private TestRestTemplate restTemplate;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private RedisService redisService;
|
||||||
|
|
||||||
|
private Long sourceId;
|
||||||
|
private Long approvedDatasetId;
|
||||||
|
private Long pendingDatasetId;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setupTokensAndData() {
|
||||||
|
Long companyId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM sys_company WHERE company_code = 'DEMO'", Long.class);
|
||||||
|
Long userId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM sys_user WHERE username = 'admin'", Long.class);
|
||||||
|
|
||||||
|
// 伪造 Redis Token
|
||||||
|
redisService.hSetAll(RedisKeyManager.tokenKey(ADMIN_TOKEN),
|
||||||
|
Map.of("userId", userId.toString(), "role", "ADMIN",
|
||||||
|
"companyId", companyId.toString(), "username", "admin"),
|
||||||
|
3600L);
|
||||||
|
redisService.hSetAll(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN),
|
||||||
|
Map.of("userId", "3", "role", "ANNOTATOR",
|
||||||
|
"companyId", companyId.toString(), "username", "annotator01"),
|
||||||
|
3600L);
|
||||||
|
|
||||||
|
// 插入 source_data
|
||||||
|
jdbcTemplate.execute(
|
||||||
|
"INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " +
|
||||||
|
"file_name, file_size, bucket_name, status) " +
|
||||||
|
"VALUES (" + companyId + ", " + userId + ", 'TEXT', " +
|
||||||
|
"'test/export-test/file.txt', 'file.txt', 100, 'test-bucket', 'APPROVED')");
|
||||||
|
sourceId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class);
|
||||||
|
|
||||||
|
// 插入 EXTRACTION 任务(已 APPROVED,用于关联 training_dataset)
|
||||||
|
jdbcTemplate.execute(
|
||||||
|
"INSERT INTO annotation_task (company_id, source_id, task_type, status, is_final) " +
|
||||||
|
"VALUES (" + companyId + ", " + sourceId + ", 'EXTRACTION', 'APPROVED', true)");
|
||||||
|
Long taskId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class);
|
||||||
|
|
||||||
|
// 插入 APPROVED training_dataset
|
||||||
|
jdbcTemplate.execute(
|
||||||
|
"INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, " +
|
||||||
|
"glm_format_json, status) VALUES (" + companyId + ", " + taskId + ", " + sourceId +
|
||||||
|
", 'TEXT', '{\"conversations\":[{\"question\":\"Q1\",\"answer\":\"A1\"}]}'::jsonb, " +
|
||||||
|
"'APPROVED')");
|
||||||
|
approvedDatasetId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM training_dataset ORDER BY id DESC LIMIT 1", Long.class);
|
||||||
|
|
||||||
|
// 插入 PENDING_REVIEW training_dataset(用于测试校验失败)
|
||||||
|
jdbcTemplate.execute(
|
||||||
|
"INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, " +
|
||||||
|
"glm_format_json, status) VALUES (" + companyId + ", " + taskId + ", " + sourceId +
|
||||||
|
", 'TEXT', '{\"conversations\":[]}'::jsonb, 'PENDING_REVIEW')");
|
||||||
|
pendingDatasetId = jdbcTemplate.queryForObject(
|
||||||
|
"SELECT id FROM training_dataset ORDER BY id DESC LIMIT 1", Long.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void cleanupTokens() {
|
||||||
|
redisService.delete(RedisKeyManager.tokenKey(ADMIN_TOKEN));
|
||||||
|
redisService.delete(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 权限测试 --
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("非 ADMIN 访问导出接口 → 403 Forbidden")
|
||||||
|
void createBatch_byAnnotator_returns403() {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.set("Authorization", "Bearer " + ANNOTATOR_TOKEN);
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<Map<String, Object>> req = new HttpEntity<>(
|
||||||
|
Map.of("sampleIds", List.of(approvedDatasetId)), headers);
|
||||||
|
|
||||||
|
ResponseEntity<Map> response = restTemplate.exchange(
|
||||||
|
baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class);
|
||||||
|
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 样本校验测试 --
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("sampleIds 为空 → 400 EMPTY_SAMPLES")
|
||||||
|
void createBatch_withEmptyIds_returns400() {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.set("Authorization", "Bearer " + ADMIN_TOKEN);
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<Map<String, Object>> req = new HttpEntity<>(
|
||||||
|
Map.of("sampleIds", List.of()), headers);
|
||||||
|
|
||||||
|
ResponseEntity<Map> response = restTemplate.exchange(
|
||||||
|
baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class);
|
||||||
|
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
|
||||||
|
assertThat(response.getBody().get("code")).isEqualTo("EMPTY_SAMPLES");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("包含非 APPROVED 样本 → 400 INVALID_SAMPLES")
|
||||||
|
void createBatch_withNonApprovedSample_returns400() {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.set("Authorization", "Bearer " + ADMIN_TOKEN);
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
// 混合 APPROVED + PENDING_REVIEW
|
||||||
|
HttpEntity<Map<String, Object>> req = new HttpEntity<>(
|
||||||
|
Map.of("sampleIds", List.of(approvedDatasetId, pendingDatasetId)), headers);
|
||||||
|
|
||||||
|
ResponseEntity<Map> response = restTemplate.exchange(
|
||||||
|
baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class);
|
||||||
|
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
|
||||||
|
assertThat(response.getBody().get("code")).isEqualTo("INVALID_SAMPLES");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@DisplayName("查询已审批样本列表 → 200,包含 APPROVED 样本")
|
||||||
|
void listSamples_adminOnly_returns200() {
|
||||||
|
ResponseEntity<Map> response = restTemplate.exchange(
|
||||||
|
baseUrl("/api/training/samples"),
|
||||||
|
HttpMethod.GET,
|
||||||
|
bearerRequest(ADMIN_TOKEN),
|
||||||
|
Map.class);
|
||||||
|
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> data = (Map<String, Object>) response.getBody().get("data");
|
||||||
|
assertThat(((Number) data.get("total")).longValue()).isGreaterThanOrEqualTo(1L);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------ 工具方法 --
|
||||||
|
|
||||||
|
private HttpEntity<Void> bearerRequest(String token) {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.set("Authorization", "Bearer " + token);
|
||||||
|
return new HttpEntity<>(headers);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user