package com.label.integration; import com.label.AbstractIntegrationTest; import com.label.service.RedisService; import com.label.util.RedisKeyManager; import org.junit.jupiter.api.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.web.client.TestRestTemplate; import org.springframework.http.*; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; /** * 任务领取并发安全集成测试(US3)。 * * 测试场景:10 个线程同时争抢同一 UNCLAIMED 任务。 * 期望结果: * - 恰好 1 人成功(200 OK) * - 其余 9 人收到 TASK_CLAIMED (409) * - DB 中 claimed_by 唯一(只有一个用户 ID) * * 此测试需要 10 个不同的 userId,使用同一 DB 用户账号但不同的 Token。 */ public class TaskClaimConcurrencyTest extends AbstractIntegrationTest { @Autowired private TestRestTemplate restTemplate; @Autowired private RedisService redisService; private Long taskId; private final List tokens = new ArrayList<>(); @BeforeEach void setup() { // 创建测试任务(直接向 DB 插入一条 UNCLAIMED 任务) jdbcTemplate.execute( "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " + "file_name, file_size, bucket_name, status) " + "VALUES (1, 1, 'TEXT', 'test/path/file.txt', 'file.txt', 100, 'test-bucket', 'PENDING')"); Long sourceId = jdbcTemplate.queryForObject( "SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class); jdbcTemplate.execute( "INSERT INTO annotation_task (company_id, source_id, task_type, status) " + "VALUES (1, " + sourceId + ", 'EXTRACTION', 'UNCLAIMED')"); taskId = jdbcTemplate.queryForObject( "SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class); // 创建 10 个 Annotator Token(模拟不同用户) for (int i = 1; i <= 10; i++) { String token = "concurrency-test-token-" + i; tokens.add(token); // 所有 Token 使用 userId=3(annotator01),这在真实场景不会发生 // 但在测试中用于验证并发锁机制(redis key 基于 taskId,不是 userId) redisService.hSetAll(RedisKeyManager.tokenKey(token), Map.of("userId", String.valueOf(i + 100), // 假设 userId > 100 不存在,但不影响锁逻辑 "role", "ANNOTATOR", "companyId", "1", "username", "annotator" + i), 3600L); } } @AfterEach void cleanup() { tokens.forEach(token -> redisService.delete(RedisKeyManager.tokenKey(token))); if (taskId != null) { redisService.delete(RedisKeyManager.taskClaimKey(taskId)); } } @Test @DisplayName("10 线程并发抢同一任务:恰好 1 人成功,其余 9 人收到 409 TASK_CLAIMED") void concurrentClaim_onlyOneSucceeds() throws InterruptedException { ExecutorService executor = Executors.newFixedThreadPool(10); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(10); AtomicInteger successCount = new AtomicInteger(0); AtomicInteger conflictCount = new AtomicInteger(0); for (int i = 0; i < 10; i++) { final String token = tokens.get(i); executor.submit(() -> { try { startLatch.await(); // 等待起跑信号,最大化并发 HttpHeaders headers = new HttpHeaders(); headers.set("Authorization", "Bearer " + token); HttpEntity request = new HttpEntity<>(headers); ResponseEntity response = restTemplate.exchange( baseUrl("/api/tasks/" + taskId + "/claim"), HttpMethod.POST, request, Map.class); if (response.getStatusCode() == HttpStatus.OK) { successCount.incrementAndGet(); } else if (response.getStatusCode() == HttpStatus.CONFLICT) { conflictCount.incrementAndGet(); } } catch (Exception e) { conflictCount.incrementAndGet(); // 异常也算失败 } finally { doneLatch.countDown(); } }); } startLatch.countDown(); // 同时放行所有线程 doneLatch.await(30, TimeUnit.SECONDS); executor.shutdown(); // 恰好 1 人成功 assertThat(successCount.get()).isEqualTo(1); // 其余 9 人失败(409 或异常) assertThat(conflictCount.get()).isEqualTo(9); // DB 中 claimed_by 有且仅有一个值 String claimedByStr = jdbcTemplate.queryForObject( "SELECT claimed_by::text FROM annotation_task WHERE id = ?", String.class, taskId); assertThat(claimedByStr).isNotNull(); // DB 中状态为 IN_PROGRESS String status = jdbcTemplate.queryForObject( "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); assertThat(status).isEqualTo("IN_PROGRESS"); } }