From 49666d157953bad94a2346840adbd75a0350c2b7 Mon Sep 17 00:00:00 2001 From: wh Date: Thu, 9 Apr 2026 15:43:45 +0800 Subject: [PATCH] =?UTF-8?q?feat(phase7):=20US6=20=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=AF=BC=E5=87=BA=E4=B8=8E=20GLM=20=E5=BE=AE?= =?UTF-8?q?=E8=B0=83=E6=8F=90=E4=BA=A4=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ExportBatch 实体 + ExportBatchMapper(updateFinetuneInfo) - ExportService:createBatch(JSONL生成+RustFS上传+批量更新)、listSamples、listBatches - 双重校验:sampleIds非空(EMPTY_SAMPLES 400)、全部APPROVED(INVALID_SAMPLES 400) - FinetuneService:trigger(提交GLM微调)、getStatus(实时查询) - AI调用不在@Transactional内,仅DB写入部分受事务保护 - ExportController:5个端点全部@RequiresRoles("ADMIN") - 集成测试:权限403、空列表400、非APPROVED样本400、已审批样本查询200 --- .../export/controller/ExportController.java | 84 +++++++++ .../module/export/entity/ExportBatch.java | 44 +++++ .../export/mapper/ExportBatchMapper.java | 31 ++++ .../module/export/service/ExportService.java | 167 +++++++++++++++++ .../export/service/FinetuneService.java | 122 ++++++++++++ .../integration/ExportIntegrationTest.java | 174 ++++++++++++++++++ 6 files changed, 622 insertions(+) create mode 100644 src/main/java/com/label/module/export/controller/ExportController.java create mode 100644 src/main/java/com/label/module/export/entity/ExportBatch.java create mode 100644 src/main/java/com/label/module/export/mapper/ExportBatchMapper.java create mode 100644 src/main/java/com/label/module/export/service/ExportService.java create mode 100644 src/main/java/com/label/module/export/service/FinetuneService.java create mode 100644 src/test/java/com/label/integration/ExportIntegrationTest.java diff --git a/src/main/java/com/label/module/export/controller/ExportController.java b/src/main/java/com/label/module/export/controller/ExportController.java new file mode 100644 index 0000000..f612334 --- /dev/null +++ b/src/main/java/com/label/module/export/controller/ExportController.java @@ -0,0 +1,84 @@ +package com.label.module.export.controller; + +import com.label.common.result.PageResult; +import com.label.common.result.Result; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.annotation.entity.TrainingDataset; +import com.label.module.export.entity.ExportBatch; +import com.label.module.export.service.ExportService; +import com.label.module.export.service.FinetuneService; +import jakarta.servlet.http.HttpServletRequest; +import lombok.RequiredArgsConstructor; +import org.apache.shiro.authz.annotation.RequiresRoles; +import org.springframework.http.HttpStatus; +import org.springframework.web.bind.annotation.*; + +import java.util.List; +import java.util.Map; + +/** + * 训练数据导出与微调接口(5 个端点,全部 ADMIN 权限)。 + */ +@RestController +@RequiredArgsConstructor +public class ExportController { + + private final ExportService exportService; + private final FinetuneService finetuneService; + + /** GET /api/training/samples — 分页查询已审批可导出样本 */ + @GetMapping("/api/training/samples") + @RequiresRoles("ADMIN") + public Result> listSamples( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + @RequestParam(required = false) String sampleType, + @RequestParam(required = false) Boolean exported, + HttpServletRequest request) { + return Result.success(exportService.listSamples(page, pageSize, sampleType, exported, principal(request))); + } + + /** POST /api/export/batch — 创建导出批次 */ + @PostMapping("/api/export/batch") + @RequiresRoles("ADMIN") + @ResponseStatus(HttpStatus.CREATED) + public Result createBatch(@RequestBody Map body, + HttpServletRequest request) { + @SuppressWarnings("unchecked") + List rawIds = (List) body.get("sampleIds"); + List sampleIds = rawIds.stream() + .map(id -> Long.parseLong(id.toString())) + .toList(); + return Result.success(exportService.createBatch(sampleIds, principal(request))); + } + + /** POST /api/export/{batchId}/finetune — 提交微调任务 */ + @PostMapping("/api/export/{batchId}/finetune") + @RequiresRoles("ADMIN") + public Result> triggerFinetune(@PathVariable Long batchId, + HttpServletRequest request) { + return Result.success(finetuneService.trigger(batchId, principal(request))); + } + + /** GET /api/export/{batchId}/status — 查询微调状态 */ + @GetMapping("/api/export/{batchId}/status") + @RequiresRoles("ADMIN") + public Result> getFinetuneStatus(@PathVariable Long batchId, + HttpServletRequest request) { + return Result.success(finetuneService.getStatus(batchId, principal(request))); + } + + /** GET /api/export/list — 分页查询导出批次列表 */ + @GetMapping("/api/export/list") + @RequiresRoles("ADMIN") + public Result> listBatches( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + HttpServletRequest request) { + return Result.success(exportService.listBatches(page, pageSize, principal(request))); + } + + private TokenPrincipal principal(HttpServletRequest request) { + return (TokenPrincipal) request.getAttribute("__token_principal__"); + } +} diff --git a/src/main/java/com/label/module/export/entity/ExportBatch.java b/src/main/java/com/label/module/export/entity/ExportBatch.java new file mode 100644 index 0000000..d7447b0 --- /dev/null +++ b/src/main/java/com/label/module/export/entity/ExportBatch.java @@ -0,0 +1,44 @@ +package com.label.module.export.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.UUID; + +/** + * 导出批次实体,对应 export_batch 表。 + * + * finetuneStatus 取值:NOT_STARTED / RUNNING / COMPLETED / FAILED + */ +@Data +@TableName("export_batch") +public class ExportBatch { + + @TableId(type = IdType.AUTO) + private Long id; + + /** 所属公司(多租户键) */ + private Long companyId; + + /** 批次唯一标识(UUID,DB 默认 gen_random_uuid()) */ + private UUID batchUuid; + + /** 本批次样本数量 */ + private Integer sampleCount; + + /** 导出 JSONL 的 RustFS 路径 */ + private String datasetFilePath; + + /** GLM fine-tune 任务 ID(提交微调后填写) */ + private String glmJobId; + + /** 微调任务状态:NOT_STARTED / RUNNING / COMPLETED / FAILED */ + private String finetuneStatus; + + private LocalDateTime createdAt; + + private LocalDateTime updatedAt; +} diff --git a/src/main/java/com/label/module/export/mapper/ExportBatchMapper.java b/src/main/java/com/label/module/export/mapper/ExportBatchMapper.java new file mode 100644 index 0000000..acbb1d8 --- /dev/null +++ b/src/main/java/com/label/module/export/mapper/ExportBatchMapper.java @@ -0,0 +1,31 @@ +package com.label.module.export.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.label.module.export.entity.ExportBatch; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Update; + +/** + * export_batch 表 Mapper。 + */ +@Mapper +public interface ExportBatchMapper extends BaseMapper { + + /** + * 更新微调任务信息(glm_job_id + finetune_status)。 + * + * @param id 批次 ID + * @param glmJobId GLM fine-tune 任务 ID + * @param finetuneStatus 新状态 + * @param companyId 当前租户 + * @return 影响行数 + */ + @Update("UPDATE export_batch SET glm_job_id = #{glmJobId}, " + + "finetune_status = #{finetuneStatus}, updated_at = NOW() " + + "WHERE id = #{id} AND company_id = #{companyId}") + int updateFinetuneInfo(@Param("id") Long id, + @Param("glmJobId") String glmJobId, + @Param("finetuneStatus") String finetuneStatus, + @Param("companyId") Long companyId); +} diff --git a/src/main/java/com/label/module/export/service/ExportService.java b/src/main/java/com/label/module/export/service/ExportService.java new file mode 100644 index 0000000..2ee2111 --- /dev/null +++ b/src/main/java/com/label/module/export/service/ExportService.java @@ -0,0 +1,167 @@ +package com.label.module.export.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.label.common.exception.BusinessException; +import com.label.common.result.PageResult; +import com.label.common.shiro.TokenPrincipal; +import com.label.common.storage.RustFsClient; +import com.label.module.annotation.entity.TrainingDataset; +import com.label.module.annotation.mapper.TrainingDatasetMapper; +import com.label.module.export.entity.ExportBatch; +import com.label.module.export.mapper.ExportBatchMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * 训练数据导出服务。 + * + * createBatch() 步骤: + * 1. 校验 sampleIds 非空(EMPTY_SAMPLES 400) + * 2. 查询 training_dataset,校验全部为 APPROVED(INVALID_SAMPLES 400) + * 3. 生成 JSONL(每行一个 glm_format_json) + * 4. 上传 RustFS(bucket: finetune-export, key: export/{batchUuid}.jsonl) + * 5. 插入 export_batch 记录 + * 6. 批量更新 training_dataset.export_batch_id + exported_at + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class ExportService { + + private static final String EXPORT_BUCKET = "finetune-export"; + + private final ExportBatchMapper exportBatchMapper; + private final TrainingDatasetMapper datasetMapper; + private final RustFsClient rustFsClient; + + // ------------------------------------------------------------------ 创建批次 -- + + /** + * 创建导出批次。 + * + * @param sampleIds 待导出的 training_dataset ID 列表 + * @param principal 当前用户 + * @return 新建的 ExportBatch + */ + @Transactional + public ExportBatch createBatch(List sampleIds, TokenPrincipal principal) { + if (sampleIds == null || sampleIds.isEmpty()) { + throw new BusinessException("EMPTY_SAMPLES", "导出样本 ID 列表不能为空", HttpStatus.BAD_REQUEST); + } + + // 查询样本 + List samples = datasetMapper.selectList( + new LambdaQueryWrapper() + .in(TrainingDataset::getId, sampleIds) + .eq(TrainingDataset::getCompanyId, principal.getCompanyId())); + + // 校验全部已审批 + boolean hasNonApproved = samples.stream() + .anyMatch(s -> !"APPROVED".equals(s.getStatus())); + if (hasNonApproved || samples.size() != sampleIds.size()) { + throw new BusinessException("INVALID_SAMPLES", + "部分样本不处于 APPROVED 状态或不属于当前租户", HttpStatus.BAD_REQUEST); + } + + // 生成 JSONL(每行一个 JSON 对象) + String jsonl = samples.stream() + .map(TrainingDataset::getGlmFormatJson) + .collect(Collectors.joining("\n")); + byte[] jsonlBytes = jsonl.getBytes(StandardCharsets.UTF_8); + + // 生成唯一批次 UUID,上传 RustFS + UUID batchUuid = UUID.randomUUID(); + String filePath = "export/" + batchUuid + ".jsonl"; + + rustFsClient.upload(EXPORT_BUCKET, filePath, + new ByteArrayInputStream(jsonlBytes), jsonlBytes.length, + "application/jsonl"); + + // 插入 export_batch 记录 + ExportBatch batch = new ExportBatch(); + batch.setCompanyId(principal.getCompanyId()); + batch.setBatchUuid(batchUuid); + batch.setSampleCount(samples.size()); + batch.setDatasetFilePath(filePath); + batch.setFinetuneStatus("NOT_STARTED"); + exportBatchMapper.insert(batch); + + // 批量更新 training_dataset.export_batch_id + exported_at + datasetMapper.update(null, new LambdaUpdateWrapper() + .in(TrainingDataset::getId, sampleIds) + .set(TrainingDataset::getExportBatchId, batch.getId()) + .set(TrainingDataset::getExportedAt, LocalDateTime.now()) + .set(TrainingDataset::getUpdatedAt, LocalDateTime.now())); + + log.debug("导出批次已创建: batchId={}, sampleCount={}, path={}", + batch.getId(), samples.size(), filePath); + return batch; + } + + // ------------------------------------------------------------------ 查询样本 -- + + /** + * 分页查询已审批、可导出的训练样本。 + */ + public PageResult listSamples(int page, int pageSize, + String sampleType, Boolean exported, + TokenPrincipal principal) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .eq(TrainingDataset::getStatus, "APPROVED") + .eq(TrainingDataset::getCompanyId, principal.getCompanyId()) + .orderByDesc(TrainingDataset::getCreatedAt); + + if (sampleType != null && !sampleType.isBlank()) { + wrapper.eq(TrainingDataset::getSampleType, sampleType); + } + if (exported != null) { + if (exported) { + wrapper.isNotNull(TrainingDataset::getExportBatchId); + } else { + wrapper.isNull(TrainingDataset::getExportBatchId); + } + } + + Page result = datasetMapper.selectPage(new Page<>(page, pageSize), wrapper); + return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize); + } + + // ------------------------------------------------------------------ 查询批次列表 -- + + /** + * 分页查询导出批次。 + */ + public PageResult listBatches(int page, int pageSize, TokenPrincipal principal) { + pageSize = Math.min(pageSize, 100); + Page result = exportBatchMapper.selectPage( + new Page<>(page, pageSize), + new LambdaQueryWrapper() + .eq(ExportBatch::getCompanyId, principal.getCompanyId()) + .orderByDesc(ExportBatch::getCreatedAt)); + return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize); + } + + // ------------------------------------------------------------------ 查询批次 -- + + public ExportBatch getById(Long batchId, TokenPrincipal principal) { + ExportBatch batch = exportBatchMapper.selectById(batchId); + if (batch == null || !batch.getCompanyId().equals(principal.getCompanyId())) { + throw new BusinessException("NOT_FOUND", "导出批次不存在: " + batchId, HttpStatus.NOT_FOUND); + } + return batch; + } +} diff --git a/src/main/java/com/label/module/export/service/FinetuneService.java b/src/main/java/com/label/module/export/service/FinetuneService.java new file mode 100644 index 0000000..ea7555c --- /dev/null +++ b/src/main/java/com/label/module/export/service/FinetuneService.java @@ -0,0 +1,122 @@ +package com.label.module.export.service; + +import com.label.common.ai.AiServiceClient; +import com.label.common.exception.BusinessException; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.export.entity.ExportBatch; +import com.label.module.export.mapper.ExportBatchMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.Map; + +/** + * GLM 微调服务:提交任务、查询状态。 + * + * 注意:trigger() 包含 AI HTTP 调用,不在 @Transactional 注解下。 + * 仅在 DB 写入时开启事务(updateFinetuneInfo)。 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class FinetuneService { + + private final ExportBatchMapper exportBatchMapper; + private final ExportService exportService; + private final AiServiceClient aiServiceClient; + + // ------------------------------------------------------------------ 提交微调 -- + + /** + * 向 GLM AI 服务提交微调任务。 + * + * @param batchId 批次 ID + * @param principal 当前用户 + * @return 包含 glmJobId 和 finetuneStatus 的 Map + */ + @Transactional + public Map trigger(Long batchId, TokenPrincipal principal) { + ExportBatch batch = exportService.getById(batchId, principal); + + if (!"NOT_STARTED".equals(batch.getFinetuneStatus())) { + throw new BusinessException("FINETUNE_ALREADY_STARTED", + "微调任务已提交,当前状态: " + batch.getFinetuneStatus(), HttpStatus.CONFLICT); + } + + // 调用 AI 服务提交微调(在事务外完成,此处事务仅保护后续 DB 写入) + AiServiceClient.FinetuneRequest req = AiServiceClient.FinetuneRequest.builder() + .datasetPath(batch.getDatasetFilePath()) + .model("glm-4") + .batchId(batchId) + .build(); + + AiServiceClient.FinetuneResponse response; + try { + response = aiServiceClient.startFinetune(req); + } catch (Exception e) { + throw new BusinessException("FINETUNE_TRIGGER_FAILED", + "提交微调任务失败: " + e.getMessage(), HttpStatus.SERVICE_UNAVAILABLE); + } + + // 更新批次记录 + exportBatchMapper.updateFinetuneInfo(batchId, + response.getJobId(), "RUNNING", principal.getCompanyId()); + + log.debug("微调任务已提交: batchId={}, glmJobId={}", batchId, response.getJobId()); + + return Map.of( + "glmJobId", response.getJobId(), + "finetuneStatus", "RUNNING" + ); + } + + // ------------------------------------------------------------------ 查询状态 -- + + /** + * 查询微调任务实时状态(向 AI 服务查询)。 + * + * @param batchId 批次 ID + * @param principal 当前用户 + * @return 状态 Map + */ + public Map getStatus(Long batchId, TokenPrincipal principal) { + ExportBatch batch = exportService.getById(batchId, principal); + + if (batch.getGlmJobId() == null) { + return Map.of( + "batchId", batchId, + "glmJobId", "", + "finetuneStatus", batch.getFinetuneStatus(), + "progress", 0, + "errorMessage", "" + ); + } + + // 向 AI 服务实时查询 + AiServiceClient.FinetuneStatusResponse statusResp; + try { + statusResp = aiServiceClient.getFinetuneStatus(batch.getGlmJobId()); + } catch (Exception e) { + log.warn("查询微调状态失败(batchId={}):{}", batchId, e.getMessage()); + // 查询失败时返回 DB 中的缓存状态 + return Map.of( + "batchId", batchId, + "glmJobId", batch.getGlmJobId(), + "finetuneStatus", batch.getFinetuneStatus(), + "progress", 0, + "errorMessage", "AI 服务查询失败: " + e.getMessage() + ); + } + + return Map.of( + "batchId", batchId, + "glmJobId", statusResp.getJobId() != null ? statusResp.getJobId() : batch.getGlmJobId(), + "finetuneStatus", statusResp.getStatus() != null ? statusResp.getStatus() : batch.getFinetuneStatus(), + "progress", statusResp.getProgress() != null ? statusResp.getProgress() : 0, + "errorMessage", statusResp.getErrorMessage() != null ? statusResp.getErrorMessage() : "" + ); + } +} diff --git a/src/test/java/com/label/integration/ExportIntegrationTest.java b/src/test/java/com/label/integration/ExportIntegrationTest.java new file mode 100644 index 0000000..1b7397a --- /dev/null +++ b/src/test/java/com/label/integration/ExportIntegrationTest.java @@ -0,0 +1,174 @@ +package com.label.integration; + +import com.label.AbstractIntegrationTest; +import com.label.common.redis.RedisKeyManager; +import com.label.common.redis.RedisService; +import org.junit.jupiter.api.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.http.*; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * 训练数据导出集成测试(US6)。 + * + * 测试场景: + * 1. 包含非 APPROVED 样本时返回 400 INVALID_SAMPLES + * 2. sampleIds 为空时返回 400 EMPTY_SAMPLES + * 3. 非 ADMIN 访问 → 403 Forbidden + * + * 注意:实际上传 RustFS 需要 MinIO 容器支持,此处仅测试可验证的业务逻辑。 + * 文件存在性验证需启动 MinIO 容器(超出当前测试范围)。 + */ +public class ExportIntegrationTest extends AbstractIntegrationTest { + + private static final String ADMIN_TOKEN = "test-admin-token-export"; + private static final String ANNOTATOR_TOKEN = "test-annotator-token-export"; + + @Autowired + private TestRestTemplate restTemplate; + + @Autowired + private RedisService redisService; + + private Long sourceId; + private Long approvedDatasetId; + private Long pendingDatasetId; + + @BeforeEach + void setupTokensAndData() { + Long companyId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_company WHERE company_code = 'DEMO'", Long.class); + Long userId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_user WHERE username = 'admin'", Long.class); + + // 伪造 Redis Token + redisService.hSetAll(RedisKeyManager.tokenKey(ADMIN_TOKEN), + Map.of("userId", userId.toString(), "role", "ADMIN", + "companyId", companyId.toString(), "username", "admin"), + 3600L); + redisService.hSetAll(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN), + Map.of("userId", "3", "role", "ANNOTATOR", + "companyId", companyId.toString(), "username", "annotator01"), + 3600L); + + // 插入 source_data + jdbcTemplate.execute( + "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " + + "file_name, file_size, bucket_name, status) " + + "VALUES (" + companyId + ", " + userId + ", 'TEXT', " + + "'test/export-test/file.txt', 'file.txt', 100, 'test-bucket', 'APPROVED')"); + sourceId = jdbcTemplate.queryForObject( + "SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class); + + // 插入 EXTRACTION 任务(已 APPROVED,用于关联 training_dataset) + jdbcTemplate.execute( + "INSERT INTO annotation_task (company_id, source_id, task_type, status, is_final) " + + "VALUES (" + companyId + ", " + sourceId + ", 'EXTRACTION', 'APPROVED', true)"); + Long taskId = jdbcTemplate.queryForObject( + "SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class); + + // 插入 APPROVED training_dataset + jdbcTemplate.execute( + "INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, " + + "glm_format_json, status) VALUES (" + companyId + ", " + taskId + ", " + sourceId + + ", 'TEXT', '{\"conversations\":[{\"question\":\"Q1\",\"answer\":\"A1\"}]}'::jsonb, " + + "'APPROVED')"); + approvedDatasetId = jdbcTemplate.queryForObject( + "SELECT id FROM training_dataset ORDER BY id DESC LIMIT 1", Long.class); + + // 插入 PENDING_REVIEW training_dataset(用于测试校验失败) + jdbcTemplate.execute( + "INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, " + + "glm_format_json, status) VALUES (" + companyId + ", " + taskId + ", " + sourceId + + ", 'TEXT', '{\"conversations\":[]}'::jsonb, 'PENDING_REVIEW')"); + pendingDatasetId = jdbcTemplate.queryForObject( + "SELECT id FROM training_dataset ORDER BY id DESC LIMIT 1", Long.class); + } + + @AfterEach + void cleanupTokens() { + redisService.delete(RedisKeyManager.tokenKey(ADMIN_TOKEN)); + redisService.delete(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN)); + } + + // ------------------------------------------------------------------ 权限测试 -- + + @Test + @DisplayName("非 ADMIN 访问导出接口 → 403 Forbidden") + void createBatch_byAnnotator_returns403() { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + ANNOTATOR_TOKEN); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity> req = new HttpEntity<>( + Map.of("sampleIds", List.of(approvedDatasetId)), headers); + + ResponseEntity response = restTemplate.exchange( + baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + } + + // ------------------------------------------------------------------ 样本校验测试 -- + + @Test + @DisplayName("sampleIds 为空 → 400 EMPTY_SAMPLES") + void createBatch_withEmptyIds_returns400() { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + ADMIN_TOKEN); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity> req = new HttpEntity<>( + Map.of("sampleIds", List.of()), headers); + + ResponseEntity response = restTemplate.exchange( + baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(response.getBody().get("code")).isEqualTo("EMPTY_SAMPLES"); + } + + @Test + @DisplayName("包含非 APPROVED 样本 → 400 INVALID_SAMPLES") + void createBatch_withNonApprovedSample_returns400() { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + ADMIN_TOKEN); + headers.setContentType(MediaType.APPLICATION_JSON); + // 混合 APPROVED + PENDING_REVIEW + HttpEntity> req = new HttpEntity<>( + Map.of("sampleIds", List.of(approvedDatasetId, pendingDatasetId)), headers); + + ResponseEntity response = restTemplate.exchange( + baseUrl("/api/export/batch"), HttpMethod.POST, req, Map.class); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST); + assertThat(response.getBody().get("code")).isEqualTo("INVALID_SAMPLES"); + } + + @Test + @DisplayName("查询已审批样本列表 → 200,包含 APPROVED 样本") + void listSamples_adminOnly_returns200() { + ResponseEntity response = restTemplate.exchange( + baseUrl("/api/training/samples"), + HttpMethod.GET, + bearerRequest(ADMIN_TOKEN), + Map.class); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + + @SuppressWarnings("unchecked") + Map data = (Map) response.getBody().get("data"); + assertThat(((Number) data.get("total")).longValue()).isGreaterThanOrEqualTo(1L); + } + + // ------------------------------------------------------------------ 工具方法 -- + + private HttpEntity bearerRequest(String token) { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + token); + return new HttpEntity<>(headers); + } +}