Files
label_backend/src/test/java/com/label/integration/TaskClaimConcurrencyTest.java
wh 927e4f1cf3 feat(phase5): US3+US4 任务领取、提取标注与审批模块
- 任务领取(TaskClaimService):Redis SET NX + DB WHERE status=UNCLAIMED 双重并发防护
- 任务管理(TaskService/TaskController):任务池/我的任务/待审批/全部任务/创建/指派 10 端点
- 提取标注(ExtractionService/ExtractionController):AI 预标注/更新/提交/审批/驳回 5 端点
- 审批解耦(ExtractionApprovedEventListener):@TransactionalEventListener(AFTER_COMMIT) + REQUIRES_NEW
  确保 AI QA 生成在审批事务提交后独立执行,异常不回滚审批结果
- 状态实体:AnnotationTask/AnnotationTaskHistory/AnnotationResult/TrainingDataset
- 集成测试:并发领取安全(10 线程恰好 1 成功)+ 审批流(通过/自审/驳回重领)
2026-04-09 15:36:11 +08:00

136 lines
5.4 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.label.integration;
import com.label.AbstractIntegrationTest;
import com.label.common.redis.RedisKeyManager;
import com.label.common.redis.RedisService;
import org.junit.jupiter.api.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.http.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import static org.assertj.core.api.Assertions.assertThat;
/**
* 任务领取并发安全集成测试US3
*
* 测试场景10 个线程同时争抢同一 UNCLAIMED 任务。
* 期望结果:
* - 恰好 1 人成功200 OK
* - 其余 9 人收到 TASK_CLAIMED (409)
* - DB 中 claimed_by 唯一(只有一个用户 ID
*
* 此测试需要 10 个不同的 userId使用同一 DB 用户账号但不同的 Token。
*/
public class TaskClaimConcurrencyTest extends AbstractIntegrationTest {
@Autowired
private TestRestTemplate restTemplate;
@Autowired
private RedisService redisService;
private Long taskId;
private final List<String> tokens = new ArrayList<>();
@BeforeEach
void setup() {
// 创建测试任务(直接向 DB 插入一条 UNCLAIMED 任务)
jdbcTemplate.execute(
"INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " +
"file_name, file_size, bucket_name, status) " +
"VALUES (1, 1, 'TEXT', 'test/path/file.txt', 'file.txt', 100, 'test-bucket', 'PENDING')");
Long sourceId = jdbcTemplate.queryForObject(
"SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class);
jdbcTemplate.execute(
"INSERT INTO annotation_task (company_id, source_id, task_type, status) " +
"VALUES (1, " + sourceId + ", 'EXTRACTION', 'UNCLAIMED')");
taskId = jdbcTemplate.queryForObject(
"SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class);
// 创建 10 个 Annotator Token模拟不同用户
for (int i = 1; i <= 10; i++) {
String token = "concurrency-test-token-" + i;
tokens.add(token);
// 所有 Token 使用 userId=3annotator01这在真实场景不会发生
// 但在测试中用于验证并发锁机制redis key 基于 taskId不是 userId
redisService.hSetAll(RedisKeyManager.tokenKey(token),
Map.of("userId", String.valueOf(i + 100), // 假设 userId > 100 不存在,但不影响锁逻辑
"role", "ANNOTATOR", "companyId", "1", "username", "annotator" + i),
3600L);
}
}
@AfterEach
void cleanup() {
tokens.forEach(token -> redisService.delete(RedisKeyManager.tokenKey(token)));
if (taskId != null) {
redisService.delete(RedisKeyManager.taskClaimKey(taskId));
}
}
@Test
@DisplayName("10 线程并发抢同一任务:恰好 1 人成功,其余 9 人收到 409 TASK_CLAIMED")
void concurrentClaim_onlyOneSucceeds() throws InterruptedException {
ExecutorService executor = Executors.newFixedThreadPool(10);
CountDownLatch startLatch = new CountDownLatch(1);
CountDownLatch doneLatch = new CountDownLatch(10);
AtomicInteger successCount = new AtomicInteger(0);
AtomicInteger conflictCount = new AtomicInteger(0);
for (int i = 0; i < 10; i++) {
final String token = tokens.get(i);
executor.submit(() -> {
try {
startLatch.await(); // 等待起跑信号,最大化并发
HttpHeaders headers = new HttpHeaders();
headers.set("Authorization", "Bearer " + token);
HttpEntity<Void> request = new HttpEntity<>(headers);
ResponseEntity<Map> response = restTemplate.exchange(
baseUrl("/api/tasks/" + taskId + "/claim"),
HttpMethod.POST, request, Map.class);
if (response.getStatusCode() == HttpStatus.OK) {
successCount.incrementAndGet();
} else if (response.getStatusCode() == HttpStatus.CONFLICT) {
conflictCount.incrementAndGet();
}
} catch (Exception e) {
conflictCount.incrementAndGet(); // 异常也算失败
} finally {
doneLatch.countDown();
}
});
}
startLatch.countDown(); // 同时放行所有线程
doneLatch.await(30, TimeUnit.SECONDS);
executor.shutdown();
// 恰好 1 人成功
assertThat(successCount.get()).isEqualTo(1);
// 其余 9 人失败409 或异常)
assertThat(conflictCount.get()).isEqualTo(9);
// DB 中 claimed_by 有且仅有一个值
String claimedByStr = jdbcTemplate.queryForObject(
"SELECT claimed_by::text FROM annotation_task WHERE id = ?",
String.class, taskId);
assertThat(claimedByStr).isNotNull();
// DB 中状态为 IN_PROGRESS
String status = jdbcTemplate.queryForObject(
"SELECT status FROM annotation_task WHERE id = ?", String.class, taskId);
assertThat(status).isEqualTo("IN_PROGRESS");
}
}