Files
label_backend/src/test/java/com/label/integration/ExportIntegrationTest.java
wh 49666d1579 feat(phase7): US6 训练数据导出与 GLM 微调提交模块
- ExportBatch 实体 + ExportBatchMapper(updateFinetuneInfo)
- ExportService:createBatch(JSONL生成+RustFS上传+批量更新)、listSamples、listBatches
  - 双重校验:sampleIds非空(EMPTY_SAMPLES 400)、全部APPROVED(INVALID_SAMPLES 400)
- FinetuneService:trigger(提交GLM微调)、getStatus(实时查询)
  - AI调用不在@Transactional内,仅DB写入部分受事务保护
- ExportController:5个端点全部@RequiresRoles("ADMIN")
- 集成测试:权限403、空列表400、非APPROVED样本400、已审批样本查询200
2026-04-09 15:43:45 +08:00

175 lines
7.6 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.label.integration;
import com.label.AbstractIntegrationTest;
import com.label.common.redis.RedisKeyManager;
import com.label.common.redis.RedisService;
import org.junit.jupiter.api.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.http.*;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration tests for training-data export (US6).
 *
 * <p>Scenarios covered:
 * <ol>
 *   <li>A non-ADMIN caller on the export endpoint gets 403 Forbidden.</li>
 *   <li>An empty {@code sampleIds} list gets 400 {@code EMPTY_SAMPLES}.</li>
 *   <li>A {@code sampleIds} list containing a non-APPROVED sample gets 400 {@code INVALID_SAMPLES}.</li>
 *   <li>An ADMIN listing approved samples gets 200 with {@code data.total >= 1}.</li>
 * </ol>
 *
 * <p>Note: actually uploading to RustFS requires a MinIO container, so only the business-rule
 * validation reachable without object storage is tested here. Verifying that the exported file
 * exists would require starting MinIO and is outside the scope of this test.
 *
 * <p>Before each test, fresh fixture rows are inserted: one source_data row, one APPROVED
 * EXTRACTION annotation_task, and two training_dataset rows (one APPROVED, one PENDING_REVIEW).
 * Forged Redis session tokens are also created for an ADMIN and an ANNOTATOR. The rows and
 * tokens are removed again after each test.
 */
public class ExportIntegrationTest extends AbstractIntegrationTest {

    private static final String ADMIN_TOKEN = "test-admin-token-export";
    private static final String ANNOTATOR_TOKEN = "test-annotator-token-export";

    /**
     * User id written into the forged annotator session.
     * NOTE(review): hard-coded; presumably matches the seeded "annotator01" user — confirm.
     */
    private static final String ANNOTATOR_USER_ID = "3";

    /** TTL argument passed to {@code RedisService.hSetAll} for the forged session tokens. */
    private static final long TOKEN_TTL = 3600L;

    @Autowired
    private TestRestTemplate restTemplate;

    @Autowired
    private RedisService redisService;

    /** Ids of the fixture rows created in {@link #setupTokensAndData()}; null if not created. */
    private Long sourceId;
    private Long taskId;
    private Long approvedDatasetId;
    private Long pendingDatasetId;

    /**
     * Forges ADMIN and ANNOTATOR session tokens in Redis and inserts the fixture rows.
     * Ids are captured with {@code INSERT ... RETURNING id}. The alternative of querying the
     * newest id afterwards could return a row inserted concurrently by someone else.
     */
    @BeforeEach
    void setupTokensAndData() {
        Long companyId = jdbcTemplate.queryForObject(
                "SELECT id FROM sys_company WHERE company_code = 'DEMO'", Long.class);
        Long userId = jdbcTemplate.queryForObject(
                "SELECT id FROM sys_user WHERE username = 'admin'", Long.class);

        // Forge Redis session tokens so requests authenticate without going through login.
        redisService.hSetAll(RedisKeyManager.tokenKey(ADMIN_TOKEN),
                Map.of("userId", userId.toString(), "role", "ADMIN",
                        "companyId", companyId.toString(), "username", "admin"),
                TOKEN_TTL);
        redisService.hSetAll(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN),
                Map.of("userId", ANNOTATOR_USER_ID, "role", "ANNOTATOR",
                        "companyId", companyId.toString(), "username", "annotator01"),
                TOKEN_TTL);

        sourceId = jdbcTemplate.queryForObject(
                "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, "
                        + "file_name, file_size, bucket_name, status) "
                        + "VALUES (?, ?, 'TEXT', 'test/export-test/file.txt', 'file.txt', 100, "
                        + "'test-bucket', 'APPROVED') RETURNING id",
                Long.class, companyId, userId);

        // Approved EXTRACTION task that the training_dataset rows below are linked to.
        taskId = jdbcTemplate.queryForObject(
                "INSERT INTO annotation_task (company_id, source_id, task_type, status, is_final) "
                        + "VALUES (?, ?, 'EXTRACTION', 'APPROVED', true) RETURNING id",
                Long.class, companyId, sourceId);

        // APPROVED sample: valid export input.
        approvedDatasetId = jdbcTemplate.queryForObject(
                "INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, "
                        + "glm_format_json, status) VALUES (?, ?, ?, 'TEXT', "
                        + "'{\"conversations\":[{\"question\":\"Q1\",\"answer\":\"A1\"}]}'::jsonb, "
                        + "'APPROVED') RETURNING id",
                Long.class, companyId, taskId, sourceId);

        // PENDING_REVIEW sample: used to trigger the INVALID_SAMPLES validation failure.
        pendingDatasetId = jdbcTemplate.queryForObject(
                "INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, "
                        + "glm_format_json, status) VALUES (?, ?, ?, 'TEXT', "
                        + "'{\"conversations\":[]}'::jsonb, 'PENDING_REVIEW') RETURNING id",
                Long.class, companyId, taskId, sourceId);
    }

    /** Removes the forged Redis session tokens. */
    @AfterEach
    void cleanupTokens() {
        redisService.delete(RedisKeyManager.tokenKey(ADMIN_TOKEN));
        redisService.delete(RedisKeyManager.tokenKey(ANNOTATOR_TOKEN));
    }

    /**
     * Deletes the fixture rows inserted by {@link #setupTokensAndData()} so they do not leak
     * into other tests. Children are deleted before parents, because training_dataset
     * references annotation_task and source_data. Null ids (setup failed partway) are skipped.
     */
    @AfterEach
    void cleanupTestData() {
        if (pendingDatasetId != null) {
            jdbcTemplate.update("DELETE FROM training_dataset WHERE id = ?", pendingDatasetId);
        }
        if (approvedDatasetId != null) {
            jdbcTemplate.update("DELETE FROM training_dataset WHERE id = ?", approvedDatasetId);
        }
        if (taskId != null) {
            jdbcTemplate.update("DELETE FROM annotation_task WHERE id = ?", taskId);
        }
        if (sourceId != null) {
            jdbcTemplate.update("DELETE FROM source_data WHERE id = ?", sourceId);
        }
    }

    // ------------------------------------------------------------------ Authorization tests --

    @Test
    @DisplayName("非 ADMIN 访问导出接口 → 403 Forbidden")
    void createBatch_byAnnotator_returns403() {
        ResponseEntity<Map> response = postCreateBatch(ANNOTATOR_TOKEN, List.of(approvedDatasetId));

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
    }

    // ------------------------------------------------------------------ Sample validation tests --

    @Test
    @DisplayName("sampleIds 为空 → 400 EMPTY_SAMPLES")
    void createBatch_withEmptyIds_returns400() {
        ResponseEntity<Map> response = postCreateBatch(ADMIN_TOKEN, List.of());

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
        assertThat(response.getBody()).isNotNull();
        assertThat(response.getBody().get("code")).isEqualTo("EMPTY_SAMPLES");
    }

    @Test
    @DisplayName("包含非 APPROVED 样本 → 400 INVALID_SAMPLES")
    void createBatch_withNonApprovedSample_returns400() {
        // Mix one APPROVED and one PENDING_REVIEW sample: a single bad id must reject the batch.
        ResponseEntity<Map> response =
                postCreateBatch(ADMIN_TOKEN, List.of(approvedDatasetId, pendingDatasetId));

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
        assertThat(response.getBody()).isNotNull();
        assertThat(response.getBody().get("code")).isEqualTo("INVALID_SAMPLES");
    }

    @Test
    @DisplayName("查询已审批样本列表 → 200包含 APPROVED 样本")
    void listSamples_adminOnly_returns200() {
        ResponseEntity<Map> response = restTemplate.exchange(
                baseUrl("/api/training/samples"),
                HttpMethod.GET,
                bearerRequest(ADMIN_TOKEN),
                Map.class);

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
        assertThat(response.getBody()).isNotNull();
        @SuppressWarnings("unchecked")
        Map<String, Object> data = (Map<String, Object>) response.getBody().get("data");
        assertThat(data).isNotNull();
        assertThat(((Number) data.get("total")).longValue()).isGreaterThanOrEqualTo(1L);
    }

    // ------------------------------------------------------------------ Helpers --

    /**
     * POSTs a create-batch request for the given sample ids as the given token's user.
     *
     * @param token     forged session token to authenticate with
     * @param sampleIds training_dataset ids to send as {@code sampleIds}
     * @return the raw response, with the JSON body decoded into a {@link Map}
     */
    private ResponseEntity<Map> postCreateBatch(String token, List<Long> sampleIds) {
        HttpHeaders headers = bearerHeaders(token);
        headers.setContentType(MediaType.APPLICATION_JSON);
        HttpEntity<Map<String, Object>> request =
                new HttpEntity<>(Map.of("sampleIds", sampleIds), headers);
        return restTemplate.exchange(
                baseUrl("/api/export/batch"), HttpMethod.POST, request, Map.class);
    }

    /** Builds a body-less request carrying {@code Authorization: Bearer <token>}. */
    private HttpEntity<Void> bearerRequest(String token) {
        return new HttpEntity<>(bearerHeaders(token));
    }

    /** Returns fresh headers containing only {@code Authorization: Bearer <token>}. */
    private static HttpHeaders bearerHeaders(String token) {
        HttpHeaders headers = new HttpHeaders();
        headers.setBearerAuth(token);
        return headers;
    }
}