137 lines
5.4 KiB
Java
137 lines
5.4 KiB
Java
package com.label.integration;
|
||
|
||
import com.label.AbstractIntegrationTest;
|
||
import com.label.service.RedisService;
|
||
import com.label.util.RedisUtil;
|
||
|
||
import org.junit.jupiter.api.*;
|
||
import org.springframework.beans.factory.annotation.Autowired;
|
||
import org.springframework.boot.test.web.client.TestRestTemplate;
|
||
import org.springframework.http.*;
|
||
|
||
import java.util.ArrayList;
|
||
import java.util.List;
|
||
import java.util.Map;
|
||
import java.util.concurrent.*;
|
||
import java.util.concurrent.atomic.AtomicInteger;
|
||
|
||
import static org.assertj.core.api.Assertions.assertThat;
|
||
|
||
/**
|
||
* 任务领取并发安全集成测试(US3)。
|
||
*
|
||
* 测试场景:10 个线程同时争抢同一 UNCLAIMED 任务。
|
||
* 期望结果:
|
||
* - 恰好 1 人成功(200 OK)
|
||
* - 其余 9 人收到 TASK_CLAIMED (409)
|
||
* - DB 中 claimed_by 唯一(只有一个用户 ID)
|
||
*
|
||
* 此测试需要 10 个不同的 userId,使用同一 DB 用户账号但不同的 Token。
|
||
*/
|
||
public class TaskClaimConcurrencyTest extends AbstractIntegrationTest {
|
||
|
||
@Autowired
|
||
private TestRestTemplate restTemplate;
|
||
|
||
@Autowired
|
||
private RedisService redisService;
|
||
|
||
private Long taskId;
|
||
private final List<String> tokens = new ArrayList<>();
|
||
|
||
@BeforeEach
|
||
void setup() {
|
||
// 创建测试任务(直接向 DB 插入一条 UNCLAIMED 任务)
|
||
jdbcTemplate.execute(
|
||
"INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " +
|
||
"file_name, file_size, bucket_name, status) " +
|
||
"VALUES (1, 1, 'TEXT', 'test/path/file.txt', 'file.txt', 100, 'test-bucket', 'PENDING')");
|
||
Long sourceId = jdbcTemplate.queryForObject(
|
||
"SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class);
|
||
|
||
jdbcTemplate.execute(
|
||
"INSERT INTO annotation_task (company_id, source_id, task_type, status) " +
|
||
"VALUES (1, " + sourceId + ", 'EXTRACTION', 'UNCLAIMED')");
|
||
taskId = jdbcTemplate.queryForObject(
|
||
"SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class);
|
||
|
||
// 创建 10 个 Annotator Token(模拟不同用户)
|
||
for (int i = 1; i <= 10; i++) {
|
||
String token = "concurrency-test-token-" + i;
|
||
tokens.add(token);
|
||
// 所有 Token 使用 userId=3(annotator01),这在真实场景不会发生
|
||
// 但在测试中用于验证并发锁机制(redis key 基于 taskId,不是 userId)
|
||
redisService.hSetAll(RedisUtil.tokenKey(token),
|
||
Map.of("userId", String.valueOf(i + 100), // 假设 userId > 100 不存在,但不影响锁逻辑
|
||
"role", "ANNOTATOR", "companyId", "1", "username", "annotator" + i),
|
||
3600L);
|
||
}
|
||
}
|
||
|
||
@AfterEach
|
||
void cleanup() {
|
||
tokens.forEach(token -> redisService.delete(RedisUtil.tokenKey(token)));
|
||
if (taskId != null) {
|
||
redisService.delete(RedisUtil.taskClaimKey(taskId));
|
||
}
|
||
}
|
||
|
||
@Test
|
||
@DisplayName("10 线程并发抢同一任务:恰好 1 人成功,其余 9 人收到 409 TASK_CLAIMED")
|
||
void concurrentClaim_onlyOneSucceeds() throws InterruptedException {
|
||
ExecutorService executor = Executors.newFixedThreadPool(10);
|
||
CountDownLatch startLatch = new CountDownLatch(1);
|
||
CountDownLatch doneLatch = new CountDownLatch(10);
|
||
|
||
AtomicInteger successCount = new AtomicInteger(0);
|
||
AtomicInteger conflictCount = new AtomicInteger(0);
|
||
|
||
for (int i = 0; i < 10; i++) {
|
||
final String token = tokens.get(i);
|
||
executor.submit(() -> {
|
||
try {
|
||
startLatch.await(); // 等待起跑信号,最大化并发
|
||
|
||
HttpHeaders headers = new HttpHeaders();
|
||
headers.set("Authorization", "Bearer " + token);
|
||
HttpEntity<Void> request = new HttpEntity<>(headers);
|
||
|
||
ResponseEntity<Map> response = restTemplate.exchange(
|
||
baseUrl("/label/api/tasks/" + taskId + "/claim"),
|
||
HttpMethod.POST, request, Map.class);
|
||
|
||
if (response.getStatusCode() == HttpStatus.OK) {
|
||
successCount.incrementAndGet();
|
||
} else if (response.getStatusCode() == HttpStatus.CONFLICT) {
|
||
conflictCount.incrementAndGet();
|
||
}
|
||
} catch (Exception e) {
|
||
conflictCount.incrementAndGet(); // 异常也算失败
|
||
} finally {
|
||
doneLatch.countDown();
|
||
}
|
||
});
|
||
}
|
||
|
||
startLatch.countDown(); // 同时放行所有线程
|
||
doneLatch.await(30, TimeUnit.SECONDS);
|
||
executor.shutdown();
|
||
|
||
// 恰好 1 人成功
|
||
assertThat(successCount.get()).isEqualTo(1);
|
||
// 其余 9 人失败(409 或异常)
|
||
assertThat(conflictCount.get()).isEqualTo(9);
|
||
|
||
// DB 中 claimed_by 有且仅有一个值
|
||
String claimedByStr = jdbcTemplate.queryForObject(
|
||
"SELECT claimed_by::text FROM annotation_task WHERE id = ?",
|
||
String.class, taskId);
|
||
assertThat(claimedByStr).isNotNull();
|
||
|
||
// DB 中状态为 IN_PROGRESS
|
||
String status = jdbcTemplate.queryForObject(
|
||
"SELECT status FROM annotation_task WHERE id = ?", String.class, taskId);
|
||
assertThat(status).isEqualTo("IN_PROGRESS");
|
||
}
|
||
}
|