From 927e4f1cf3add63cf5ee9f05b033aed66d41fc35 Mon Sep 17 00:00:00 2001 From: wh Date: Thu, 9 Apr 2026 15:36:11 +0800 Subject: [PATCH] =?UTF-8?q?feat(phase5):=20US3+US4=20=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E9=A2=86=E5=8F=96=E3=80=81=E6=8F=90=E5=8F=96=E6=A0=87=E6=B3=A8?= =?UTF-8?q?=E4=B8=8E=E5=AE=A1=E6=89=B9=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 任务领取(TaskClaimService):Redis SET NX + DB WHERE status=UNCLAIMED 双重并发防护 - 任务管理(TaskService/TaskController):任务池/我的任务/待审批/全部任务/创建/指派 10 端点 - 提取标注(ExtractionService/ExtractionController):AI 预标注/更新/提交/审批/驳回 5 端点 - 审批解耦(ExtractionApprovedEventListener):@TransactionalEventListener(AFTER_COMMIT) + REQUIRES_NEW 确保 AI QA 生成在审批事务提交后独立执行,异常不回滚审批结果 - 状态实体:AnnotationTask/AnnotationTaskHistory/AnnotationResult/TrainingDataset - 集成测试:并发领取安全(10 线程恰好 1 成功)+ 审批流(通过/自审/驳回重领) --- .../controller/ExtractionController.java | 73 +++++ .../annotation/entity/AnnotationResult.java | 32 +++ .../annotation/entity/TrainingDataset.java | 46 +++ .../event/ExtractionApprovedEvent.java | 31 ++ .../mapper/AnnotationResultMapper.java | 36 +++ .../mapper/TrainingDatasetMapper.java | 36 +++ .../ExtractionApprovedEventListener.java | 131 +++++++++ .../annotation/service/ExtractionService.java | 272 ++++++++++++++++++ .../task/controller/TaskController.java | 126 ++++++++ .../label/module/task/dto/TaskResponse.java | 26 ++ .../module/task/entity/AnnotationTask.java | 59 ++++ .../task/entity/AnnotationTaskHistory.java | 43 +++ .../task/mapper/AnnotationTaskMapper.java | 30 ++ .../module/task/mapper/TaskHistoryMapper.java | 14 + .../module/task/service/TaskClaimService.java | 171 +++++++++++ .../module/task/service/TaskService.java | 201 +++++++++++++ .../ExtractionApprovalIntegrationTest.java | 217 ++++++++++++++ .../integration/TaskClaimConcurrencyTest.java | 135 +++++++++ 18 files changed, 1679 insertions(+) create mode 100644 src/main/java/com/label/module/annotation/controller/ExtractionController.java create mode 100644 src/main/java/com/label/module/annotation/entity/AnnotationResult.java create mode 100644 src/main/java/com/label/module/annotation/entity/TrainingDataset.java create mode 100644 src/main/java/com/label/module/annotation/event/ExtractionApprovedEvent.java create mode 100644 src/main/java/com/label/module/annotation/mapper/AnnotationResultMapper.java create mode 100644 src/main/java/com/label/module/annotation/mapper/TrainingDatasetMapper.java create mode 100644 src/main/java/com/label/module/annotation/service/ExtractionApprovedEventListener.java create mode 100644 src/main/java/com/label/module/annotation/service/ExtractionService.java create mode 100644 src/main/java/com/label/module/task/controller/TaskController.java create mode 100644 src/main/java/com/label/module/task/dto/TaskResponse.java create mode 100644 src/main/java/com/label/module/task/entity/AnnotationTask.java create mode 100644 src/main/java/com/label/module/task/entity/AnnotationTaskHistory.java create mode 100644 src/main/java/com/label/module/task/mapper/AnnotationTaskMapper.java create mode 100644 src/main/java/com/label/module/task/mapper/TaskHistoryMapper.java create mode 100644 src/main/java/com/label/module/task/service/TaskClaimService.java create mode 100644 src/main/java/com/label/module/task/service/TaskService.java create mode 100644 src/test/java/com/label/integration/ExtractionApprovalIntegrationTest.java create mode 100644 src/test/java/com/label/integration/TaskClaimConcurrencyTest.java diff --git a/src/main/java/com/label/module/annotation/controller/ExtractionController.java b/src/main/java/com/label/module/annotation/controller/ExtractionController.java new file mode 100644 index 0000000..f202dbe --- /dev/null +++ b/src/main/java/com/label/module/annotation/controller/ExtractionController.java @@ -0,0 +1,73 @@ +package com.label.module.annotation.controller; + +import com.label.common.result.Result; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.annotation.service.ExtractionService; +import jakarta.servlet.http.HttpServletRequest; +import lombok.RequiredArgsConstructor; +import org.apache.shiro.authz.annotation.RequiresRoles; +import org.springframework.web.bind.annotation.*; + +import java.util.Map; + +/** + * 提取阶段标注工作台接口(5 个端点)。 + */ +@RestController +@RequestMapping("/api/extraction") +@RequiredArgsConstructor +public class ExtractionController { + + private final ExtractionService extractionService; + + /** GET /api/extraction/{taskId} — 获取当前标注结果 */ + @GetMapping("/{taskId}") + @RequiresRoles("ANNOTATOR") + public Result> getResult(@PathVariable Long taskId, + HttpServletRequest request) { + return Result.success(extractionService.getResult(taskId, principal(request))); + } + + /** PUT /api/extraction/{taskId} — 更新标注结果(整体覆盖) */ + @PutMapping("/{taskId}") + @RequiresRoles("ANNOTATOR") + public Result updateResult(@PathVariable Long taskId, + @RequestBody String resultJson, + HttpServletRequest request) { + extractionService.updateResult(taskId, resultJson, principal(request)); + return Result.success(null); + } + + /** POST /api/extraction/{taskId}/submit — 提交标注结果 */ + @PostMapping("/{taskId}/submit") + @RequiresRoles("ANNOTATOR") + public Result submit(@PathVariable Long taskId, + HttpServletRequest request) { + extractionService.submit(taskId, principal(request)); + return Result.success(null); + } + + /** POST /api/extraction/{taskId}/approve — 审批通过(REVIEWER) */ + @PostMapping("/{taskId}/approve") + @RequiresRoles("REVIEWER") + public Result approve(@PathVariable Long taskId, + HttpServletRequest request) { + extractionService.approve(taskId, principal(request)); + return Result.success(null); + } + + /** POST /api/extraction/{taskId}/reject — 驳回(REVIEWER) */ + @PostMapping("/{taskId}/reject") + @RequiresRoles("REVIEWER") + public Result reject(@PathVariable Long taskId, + @RequestBody Map body, + HttpServletRequest request) { + String reason = body != null ? body.get("reason") : null; + extractionService.reject(taskId, reason, principal(request)); + return Result.success(null); + } + + private TokenPrincipal principal(HttpServletRequest request) { + return (TokenPrincipal) request.getAttribute("__token_principal__"); + } +} diff --git a/src/main/java/com/label/module/annotation/entity/AnnotationResult.java b/src/main/java/com/label/module/annotation/entity/AnnotationResult.java new file mode 100644 index 0000000..6b9dce1 --- /dev/null +++ b/src/main/java/com/label/module/annotation/entity/AnnotationResult.java @@ -0,0 +1,32 @@ +package com.label.module.annotation.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 标注结果实体,对应 annotation_result 表。 + * resultJson 存储 JSONB 格式的标注内容(整体替换语义)。 + */ +@Data +@TableName("annotation_result") +public class AnnotationResult { + + @TableId(type = IdType.AUTO) + private Long id; + + private Long taskId; + + /** 所属公司(多租户键) */ + private Long companyId; + + /** 标注结果 JSON(JSONB,整体覆盖) */ + private String resultJson; + + private LocalDateTime createdAt; + + private LocalDateTime updatedAt; +} diff --git a/src/main/java/com/label/module/annotation/entity/TrainingDataset.java b/src/main/java/com/label/module/annotation/entity/TrainingDataset.java new file mode 100644 index 0000000..feafa45 --- /dev/null +++ b/src/main/java/com/label/module/annotation/entity/TrainingDataset.java @@ -0,0 +1,46 @@ +package com.label.module.annotation.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 训练数据集实体,对应 training_dataset 表。 + * + * status 取值:PENDING_REVIEW / APPROVED / REJECTED + * sampleType 取值:TEXT / IMAGE / VIDEO_FRAME + */ +@Data +@TableName("training_dataset") +public class TrainingDataset { + + @TableId(type = IdType.AUTO) + private Long id; + + /** 所属公司(多租户键) */ + private Long companyId; + + private Long taskId; + + private Long sourceId; + + /** 样本类型:TEXT / IMAGE / VIDEO_FRAME */ + private String sampleType; + + /** GLM fine-tune 格式的 JSON 字符串(JSONB) */ + private String glmFormatJson; + + /** 状态:PENDING_REVIEW / APPROVED / REJECTED */ + private String status; + + private Long exportBatchId; + + private LocalDateTime exportedAt; + + private LocalDateTime createdAt; + + private LocalDateTime updatedAt; +} diff --git a/src/main/java/com/label/module/annotation/event/ExtractionApprovedEvent.java b/src/main/java/com/label/module/annotation/event/ExtractionApprovedEvent.java new file mode 100644 index 0000000..d59f332 --- /dev/null +++ b/src/main/java/com/label/module/annotation/event/ExtractionApprovedEvent.java @@ -0,0 +1,31 @@ +package com.label.module.annotation.event; + +import lombok.Getter; +import org.springframework.context.ApplicationEvent; + +/** + * 提取任务审批通过事件。 + * 由 ExtractionService.approve() 在事务提交前发布(@TransactionalEventListener 在 AFTER_COMMIT 处理)。 + * + * 设计约束:AI 调用禁止在审批事务内执行,必须通过此事件解耦。 + */ +@Getter +public class ExtractionApprovedEvent extends ApplicationEvent { + + private final Long taskId; + private final Long sourceId; + /** 资料类型:TEXT / IMAGE,决定调用哪个 AI 生成接口 */ + private final String sourceType; + private final Long companyId; + private final Long reviewerId; + + public ExtractionApprovedEvent(Object source, Long taskId, Long sourceId, + String sourceType, Long companyId, Long reviewerId) { + super(source); + this.taskId = taskId; + this.sourceId = sourceId; + this.sourceType = sourceType; + this.companyId = companyId; + this.reviewerId = reviewerId; + } +} diff --git a/src/main/java/com/label/module/annotation/mapper/AnnotationResultMapper.java b/src/main/java/com/label/module/annotation/mapper/AnnotationResultMapper.java new file mode 100644 index 0000000..4290c62 --- /dev/null +++ b/src/main/java/com/label/module/annotation/mapper/AnnotationResultMapper.java @@ -0,0 +1,36 @@ +package com.label.module.annotation.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.label.module.annotation.entity.AnnotationResult; +import org.apache.ibatis.annotations.*; + +/** + * annotation_result 表 Mapper。 + */ +@Mapper +public interface AnnotationResultMapper extends BaseMapper { + + /** + * 整体覆盖标注结果 JSON(JSONB 字段)。 + * + * @param taskId 任务 ID + * @param resultJson 新的 JSON 字符串(整体替换) + * @param companyId 当前租户 + * @return 影响行数 + */ + @Update("UPDATE annotation_result " + + "SET result_json = #{resultJson}::jsonb, updated_at = NOW() " + + "WHERE task_id = #{taskId} AND company_id = #{companyId}") + int updateResultJson(@Param("taskId") Long taskId, + @Param("resultJson") String resultJson, + @Param("companyId") Long companyId); + + /** + * 按任务 ID 查询标注结果。 + * + * @param taskId 任务 ID + * @return 标注结果(不存在则返回 null) + */ + @Select("SELECT * FROM annotation_result WHERE task_id = #{taskId}") + AnnotationResult selectByTaskId(@Param("taskId") Long taskId); +} diff --git a/src/main/java/com/label/module/annotation/mapper/TrainingDatasetMapper.java b/src/main/java/com/label/module/annotation/mapper/TrainingDatasetMapper.java new file mode 100644 index 0000000..94eefde --- /dev/null +++ b/src/main/java/com/label/module/annotation/mapper/TrainingDatasetMapper.java @@ -0,0 +1,36 @@ +package com.label.module.annotation.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.label.module.annotation.entity.TrainingDataset; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Update; +import org.apache.ibatis.annotations.Delete; + +/** + * training_dataset 表 Mapper。 + */ +@Mapper +public interface TrainingDatasetMapper extends BaseMapper { + + /** + * 按任务 ID 将训练样本状态改为 APPROVED。 + * + * @param taskId 任务 ID + * @param companyId 当前租户 + * @return 影响行数 + */ + @Update("UPDATE training_dataset SET status = 'APPROVED', updated_at = NOW() " + + "WHERE task_id = #{taskId} AND company_id = #{companyId}") + int approveByTaskId(@Param("taskId") Long taskId, @Param("companyId") Long companyId); + + /** + * 按任务 ID 删除训练样本(驳回时清除候选数据)。 + * + * @param taskId 任务 ID + * @param companyId 当前租户 + * @return 影响行数 + */ + @Delete("DELETE FROM training_dataset WHERE task_id = #{taskId} AND company_id = #{companyId}") + int deleteByTaskId(@Param("taskId") Long taskId, @Param("companyId") Long companyId); +} diff --git a/src/main/java/com/label/module/annotation/service/ExtractionApprovedEventListener.java b/src/main/java/com/label/module/annotation/service/ExtractionApprovedEventListener.java new file mode 100644 index 0000000..8143330 --- /dev/null +++ b/src/main/java/com/label/module/annotation/service/ExtractionApprovedEventListener.java @@ -0,0 +1,131 @@ +package com.label.module.annotation.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.common.ai.AiServiceClient; +import com.label.common.context.CompanyContext; +import com.label.module.annotation.entity.TrainingDataset; +import com.label.module.annotation.event.ExtractionApprovedEvent; +import com.label.module.annotation.mapper.AnnotationResultMapper; +import com.label.module.annotation.mapper.TrainingDatasetMapper; +import com.label.module.source.entity.SourceData; +import com.label.module.source.mapper.SourceDataMapper; +import com.label.module.task.service.TaskClaimService; +import com.label.module.task.service.TaskService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; +import org.springframework.transaction.event.TransactionPhase; +import org.springframework.transaction.event.TransactionalEventListener; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * 提取审批通过后的异步处理器。 + * + * 设计约束(关键): + * - @TransactionalEventListener(AFTER_COMMIT):确保在审批事务提交后才触发 AI 调用 + * - @Transactional(REQUIRES_NEW):在独立新事务中写 DB,与审批事务完全隔离 + * - 异常不会回滚审批事务(已提交),但会在日志中记录 + * + * 处理流程: + * 1. 调用 AI 生成候选问答对(Text/Image 走不同端点) + * 2. 写入 training_dataset(status=PENDING_REVIEW) + * 3. 创建 QA_GENERATION 任务(status=UNCLAIMED) + * 4. 更新 source_data 状态为 QA_REVIEW + */ +@Slf4j +@Component +@RequiredArgsConstructor +public class ExtractionApprovedEventListener { + + private final TrainingDatasetMapper datasetMapper; + private final SourceDataMapper sourceDataMapper; + private final TaskService taskService; + private final AiServiceClient aiServiceClient; + private final ObjectMapper objectMapper; + + @Value("${rustfs.bucket:label-source-data}") + private String bucket; + + @TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT) + @Transactional(propagation = Propagation.REQUIRES_NEW) + public void onExtractionApproved(ExtractionApprovedEvent event) { + log.debug("处理提取审批通过事件: taskId={}, sourceId={}", event.getTaskId(), event.getSourceId()); + + // 设置多租户上下文(新事务中 ThreadLocal 已清除) + CompanyContext.set(event.getCompanyId()); + try { + processEvent(event); + } catch (Exception e) { + log.error("处理审批通过事件失败(taskId={}):{}", event.getTaskId(), e.getMessage(), e); + // 不向上抛出,审批操作已提交,此处失败不回滚审批 + } finally { + CompanyContext.clear(); + } + } + + private void processEvent(ExtractionApprovedEvent event) { + SourceData source = sourceDataMapper.selectById(event.getSourceId()); + if (source == null) { + log.warn("资料不存在,跳过后续处理: sourceId={}", event.getSourceId()); + return; + } + + // 1. 调用 AI 生成候选问答对 + AiServiceClient.ExtractionRequest req = AiServiceClient.ExtractionRequest.builder() + .sourceId(source.getId()) + .filePath(source.getFilePath()) + .bucket(bucket) + .build(); + + List> qaPairs; + try { + AiServiceClient.QaGenResponse response = "IMAGE".equals(source.getDataType()) + ? aiServiceClient.genImageQa(req) + : aiServiceClient.genTextQa(req); + qaPairs = response != null && response.getQaPairs() != null + ? response.getQaPairs() : Collections.emptyList(); + } catch (Exception e) { + log.warn("AI 问答生成失败(taskId={}):{},将使用空问答对", event.getTaskId(), e.getMessage()); + qaPairs = Collections.emptyList(); + } + + // 2. 写入 training_dataset(PENDING_REVIEW) + String sampleType = "IMAGE".equals(source.getDataType()) ? "IMAGE" : "TEXT"; + String glmJson = buildGlmJson(qaPairs); + + TrainingDataset dataset = new TrainingDataset(); + dataset.setCompanyId(event.getCompanyId()); + dataset.setTaskId(event.getTaskId()); + dataset.setSourceId(event.getSourceId()); + dataset.setSampleType(sampleType); + dataset.setGlmFormatJson(glmJson); + dataset.setStatus("PENDING_REVIEW"); + datasetMapper.insert(dataset); + + // 3. 创建 QA_GENERATION 任务(UNCLAIMED) + taskService.createTask(event.getSourceId(), "QA_GENERATION", event.getCompanyId()); + + // 4. 更新 source_data 状态为 QA_REVIEW + sourceDataMapper.updateStatus(event.getSourceId(), "QA_REVIEW", event.getCompanyId()); + + log.debug("审批通过后续处理完成: taskId={}, 新 QA 任务已创建", event.getTaskId()); + } + + /** + * 将 AI 生成的问答对列表转换为 GLM fine-tune 格式 JSON。 + */ + private String buildGlmJson(List> qaPairs) { + try { + return objectMapper.writeValueAsString(Map.of("conversations", qaPairs)); + } catch (Exception e) { + log.error("构建 GLM JSON 失败", e); + return "{\"conversations\":[]}"; + } + } +} diff --git a/src/main/java/com/label/module/annotation/service/ExtractionService.java b/src/main/java/com/label/module/annotation/service/ExtractionService.java new file mode 100644 index 0000000..ca96195 --- /dev/null +++ b/src/main/java/com/label/module/annotation/service/ExtractionService.java @@ -0,0 +1,272 @@ +package com.label.module.annotation.service; + +import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.common.ai.AiServiceClient; +import com.label.common.exception.BusinessException; +import com.label.common.shiro.TokenPrincipal; +import com.label.common.statemachine.StateValidator; +import com.label.common.statemachine.TaskStatus; +import com.label.module.annotation.entity.AnnotationResult; +import com.label.module.annotation.entity.TrainingDataset; +import com.label.module.annotation.event.ExtractionApprovedEvent; +import com.label.module.annotation.mapper.AnnotationResultMapper; +import com.label.module.annotation.mapper.TrainingDatasetMapper; +import com.label.module.source.entity.SourceData; +import com.label.module.source.mapper.SourceDataMapper; +import com.label.module.task.entity.AnnotationTask; +import com.label.module.task.mapper.AnnotationTaskMapper; +import com.label.module.task.service.TaskClaimService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.time.LocalDateTime; +import java.util.Collections; +import java.util.Map; + +/** + * 提取阶段标注服务:AI 预标注、更新结果、提交、审批、驳回。 + * + * 关键设计: + * - approve() 内禁止直接调用 AI,通过 ExtractionApprovedEvent 解耦(AFTER_COMMIT) + * - 所有写操作包裹在 @Transactional 中,确保任务状态和历史的一致性 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class ExtractionService { + + private final AnnotationTaskMapper taskMapper; + private final AnnotationResultMapper resultMapper; + private final TrainingDatasetMapper datasetMapper; + private final SourceDataMapper sourceDataMapper; + private final TaskClaimService taskClaimService; + private final AiServiceClient aiServiceClient; + private final ApplicationEventPublisher eventPublisher; + private final ObjectMapper objectMapper; + + @Value("${rustfs.bucket:label-source-data}") + private String bucket; + + // ------------------------------------------------------------------ AI 预标注 -- + + /** + * AI 辅助预标注:调用 AI 服务,将结果写入 annotation_result。 + * 注:此方法在 @Transactional 外调用(AI 调用不应在事务内),由控制器直接调用。 + */ + public void aiPreAnnotate(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + SourceData source = sourceDataMapper.selectById(task.getSourceId()); + if (source == null) { + throw new BusinessException("NOT_FOUND", "关联资料不存在", HttpStatus.NOT_FOUND); + } + + // 调用 AI 服务(在事务外,避免长时间持有 DB 连接) + AiServiceClient.ExtractionRequest req = AiServiceClient.ExtractionRequest.builder() + .sourceId(source.getId()) + .filePath(source.getFilePath()) + .bucket(bucket) + .build(); + + AiServiceClient.ExtractionResponse aiResponse; + try { + if ("IMAGE".equals(source.getDataType())) { + aiResponse = aiServiceClient.extractImage(req); + } else { + aiResponse = aiServiceClient.extractText(req); + } + } catch (Exception e) { + log.warn("AI 预标注调用失败(任务 {}):{}", taskId, e.getMessage()); + // AI 失败不阻塞流程,写入空结果 + aiResponse = new AiServiceClient.ExtractionResponse(); + aiResponse.setItems(Collections.emptyList()); + } + + // 将 AI 结果写入 annotation_result(UPSERT 语义) + writeOrUpdateResult(taskId, principal.getCompanyId(), aiResponse.getItems()); + } + + // ------------------------------------------------------------------ 更新结果 -- + + /** + * 人工更新标注结果(整体覆盖,PUT 语义)。 + * + * @param taskId 任务 ID + * @param resultJson 新的标注结果 JSON 字符串 + * @param principal 当前用户 + */ + @Transactional + public void updateResult(Long taskId, String resultJson, TokenPrincipal principal) { + validateAndGetTask(taskId, principal.getCompanyId()); + + // 校验 JSON 格式 + try { + objectMapper.readTree(resultJson); + } catch (Exception e) { + throw new BusinessException("INVALID_JSON", "标注结果 JSON 格式不合法", HttpStatus.BAD_REQUEST); + } + + int updated = resultMapper.updateResultJson(taskId, resultJson, principal.getCompanyId()); + if (updated == 0) { + // 不存在则新建 + AnnotationResult result = new AnnotationResult(); + result.setTaskId(taskId); + result.setCompanyId(principal.getCompanyId()); + result.setResultJson(resultJson); + resultMapper.insert(result); + } + } + + // ------------------------------------------------------------------ 提交 -- + + /** + * 提交提取结果(IN_PROGRESS → SUBMITTED)。 + */ + @Transactional + public void submit(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.SUBMITTED); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "SUBMITTED") + .set(AnnotationTask::getSubmittedAt, LocalDateTime.now())); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + task.getStatus(), "SUBMITTED", + principal.getUserId(), principal.getRole(), null); + } + + // ------------------------------------------------------------------ 审批通过 -- + + /** + * 审批通过(SUBMITTED → APPROVED)。 + * + * 两阶段: + * 1. 同步事务:is_final=true,状态推进,写历史 + * 2. 事务提交后(AFTER_COMMIT):AI 生成问答对 → training_dataset → QA 任务 → source_data 状态 + * + * 注:AI 调用严禁在此事务内执行。 + */ + @Transactional + public void approve(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + // 自审校验 + if (principal.getUserId().equals(task.getClaimedBy())) { + throw new BusinessException("SELF_REVIEW_FORBIDDEN", + "不允许审批自己提交的任务", HttpStatus.FORBIDDEN); + } + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.APPROVED); + + // 标记为最终结果 + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "APPROVED") + .set(AnnotationTask::getIsFinal, true) + .set(AnnotationTask::getCompletedAt, LocalDateTime.now())); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + "SUBMITTED", "APPROVED", + principal.getUserId(), principal.getRole(), null); + + // 获取资料信息,用于事件 + SourceData source = sourceDataMapper.selectById(task.getSourceId()); + String sourceType = source != null ? source.getDataType() : "TEXT"; + + // 发布事件(@TransactionalEventListener(AFTER_COMMIT) 处理 AI 调用) + eventPublisher.publishEvent(new ExtractionApprovedEvent( + this, taskId, task.getSourceId(), sourceType, + principal.getCompanyId(), principal.getUserId())); + } + + // ------------------------------------------------------------------ 驳回 -- + + /** + * 驳回提取结果(SUBMITTED → REJECTED)。 + */ + @Transactional + public void reject(Long taskId, String reason, TokenPrincipal principal) { + if (reason == null || reason.isBlank()) { + throw new BusinessException("REASON_REQUIRED", "驳回原因不能为空", HttpStatus.BAD_REQUEST); + } + + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + // 自审校验 + if (principal.getUserId().equals(task.getClaimedBy())) { + throw new BusinessException("SELF_REVIEW_FORBIDDEN", + "不允许驳回自己提交的任务", HttpStatus.FORBIDDEN); + } + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.REJECTED); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "REJECTED") + .set(AnnotationTask::getRejectReason, reason)); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + "SUBMITTED", "REJECTED", + principal.getUserId(), principal.getRole(), reason); + } + + // ------------------------------------------------------------------ 查询 -- + + /** + * 获取当前标注结果。 + */ + public Map getResult(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + AnnotationResult result = resultMapper.selectByTaskId(taskId); + SourceData source = sourceDataMapper.selectById(task.getSourceId()); + + return Map.of( + "taskId", taskId, + "sourceType", source != null ? source.getDataType() : "", + "sourceFilePath", source != null && source.getFilePath() != null ? source.getFilePath() : "", + "isFinal", task.getIsFinal() != null && task.getIsFinal(), + "resultJson", result != null ? result.getResultJson() : "[]" + ); + } + + // ------------------------------------------------------------------ 私有工具 -- + + /** + * 校验任务存在性(多租户自动过滤)。 + */ + private AnnotationTask validateAndGetTask(Long taskId, Long companyId) { + AnnotationTask task = taskMapper.selectById(taskId); + if (task == null) { + throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND); + } + return task; + } + + private void writeOrUpdateResult(Long taskId, Long companyId, java.util.List items) { + try { + String json = objectMapper.writeValueAsString(Map.of("items", items != null ? items : Collections.emptyList())); + int updated = resultMapper.updateResultJson(taskId, json, companyId); + if (updated == 0) { + AnnotationResult result = new AnnotationResult(); + result.setTaskId(taskId); + result.setCompanyId(companyId); + result.setResultJson(json); + resultMapper.insert(result); + } + } catch (Exception e) { + log.error("写入 AI 预标注结果失败: taskId={}", taskId, e); + } + } +} diff --git a/src/main/java/com/label/module/task/controller/TaskController.java b/src/main/java/com/label/module/task/controller/TaskController.java new file mode 100644 index 0000000..17d1565 --- /dev/null +++ b/src/main/java/com/label/module/task/controller/TaskController.java @@ -0,0 +1,126 @@ +package com.label.module.task.controller; + +import com.label.common.result.PageResult; +import com.label.common.result.Result; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.task.dto.TaskResponse; +import com.label.module.task.service.TaskClaimService; +import com.label.module.task.service.TaskService; +import jakarta.servlet.http.HttpServletRequest; +import lombok.RequiredArgsConstructor; +import org.apache.shiro.authz.annotation.RequiresRoles; +import org.springframework.web.bind.annotation.*; + +import java.util.Map; + +/** + * 任务管理接口(10 个端点)。 + */ +@RestController +@RequestMapping("/api/tasks") +@RequiredArgsConstructor +public class TaskController { + + private final TaskService taskService; + private final TaskClaimService taskClaimService; + + /** GET /api/tasks/pool — 查询可领取任务池(角色感知) */ + @GetMapping("/pool") + @RequiresRoles("ANNOTATOR") + public Result> getPool( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + HttpServletRequest request) { + return Result.success(taskService.getPool(page, pageSize, principal(request))); + } + + /** GET /api/tasks/mine — 查询我的任务 */ + @GetMapping("/mine") + @RequiresRoles("ANNOTATOR") + public Result> getMine( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + @RequestParam(required = false) String status, + HttpServletRequest request) { + return Result.success(taskService.getMine(page, pageSize, status, principal(request))); + } + + /** GET /api/tasks/pending-review — 待审批队列(REVIEWER 专属) */ + @GetMapping("/pending-review") + @RequiresRoles("REVIEWER") + public Result> getPendingReview( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + @RequestParam(required = false) String taskType) { + return Result.success(taskService.getPendingReview(page, pageSize, taskType)); + } + + /** GET /api/tasks — 查询全部任务(ADMIN) */ + @GetMapping + @RequiresRoles("ADMIN") + public Result> getAll( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + @RequestParam(required = false) String status, + @RequestParam(required = false) String taskType) { + return Result.success(taskService.getAll(page, pageSize, status, taskType)); + } + + /** POST /api/tasks — 创建任务(ADMIN) */ + @PostMapping + @RequiresRoles("ADMIN") + public Result createTask(@RequestBody Map body, + HttpServletRequest request) { + Long sourceId = Long.parseLong(body.get("sourceId").toString()); + String taskType = body.get("taskType").toString(); + TokenPrincipal principal = principal(request); + return Result.success(taskService.toPublicResponse( + taskService.createTask(sourceId, taskType, principal.getCompanyId()))); + } + + /** GET /api/tasks/{id} — 查询任务详情 */ + @GetMapping("/{id}") + @RequiresRoles("ANNOTATOR") + public Result getById(@PathVariable Long id) { + return Result.success(taskService.toPublicResponse(taskService.getById(id))); + } + + /** POST /api/tasks/{id}/claim — 领取任务 */ + @PostMapping("/{id}/claim") + @RequiresRoles("ANNOTATOR") + public Result claim(@PathVariable Long id, HttpServletRequest request) { + taskClaimService.claim(id, principal(request)); + return Result.success(null); + } + + /** POST /api/tasks/{id}/unclaim — 放弃任务 */ + @PostMapping("/{id}/unclaim") + @RequiresRoles("ANNOTATOR") + public Result unclaim(@PathVariable Long id, HttpServletRequest request) { + taskClaimService.unclaim(id, principal(request)); + return Result.success(null); + } + + /** POST /api/tasks/{id}/reclaim — 重领被驳回的任务 */ + @PostMapping("/{id}/reclaim") + @RequiresRoles("ANNOTATOR") + public Result reclaim(@PathVariable Long id, HttpServletRequest request) { + taskClaimService.reclaim(id, principal(request)); + return Result.success(null); + } + + /** PUT /api/tasks/{id}/reassign — ADMIN 强制指派 */ + @PutMapping("/{id}/reassign") + @RequiresRoles("ADMIN") + public Result reassign(@PathVariable Long id, + @RequestBody Map body, + HttpServletRequest request) { + Long targetUserId = Long.parseLong(body.get("userId").toString()); + taskService.reassign(id, targetUserId, principal(request)); + return Result.success(null); + } + + private TokenPrincipal principal(HttpServletRequest request) { + return (TokenPrincipal) request.getAttribute("__token_principal__"); + } +} diff --git a/src/main/java/com/label/module/task/dto/TaskResponse.java b/src/main/java/com/label/module/task/dto/TaskResponse.java new file mode 100644 index 0000000..a6d2550 --- /dev/null +++ b/src/main/java/com/label/module/task/dto/TaskResponse.java @@ -0,0 +1,26 @@ +package com.label.module.task.dto; + +import lombok.Builder; +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 任务接口统一响应体(任务池、我的任务、任务详情均复用)。 + */ +@Data +@Builder +public class TaskResponse { + private Long id; + private Long sourceId; + /** 任务类型(对应 taskType 字段):EXTRACTION / QA_GENERATION */ + private String taskType; + private String status; + private Long claimedBy; + private LocalDateTime claimedAt; + private LocalDateTime submittedAt; + private LocalDateTime completedAt; + /** 驳回原因(REJECTED 状态时非空) */ + private String rejectReason; + private LocalDateTime createdAt; +} diff --git a/src/main/java/com/label/module/task/entity/AnnotationTask.java b/src/main/java/com/label/module/task/entity/AnnotationTask.java new file mode 100644 index 0000000..c8fcce7 --- /dev/null +++ b/src/main/java/com/label/module/task/entity/AnnotationTask.java @@ -0,0 +1,59 @@ +package com.label.module.task.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 标注任务实体,对应 annotation_task 表。 + * + * taskType 取值:EXTRACTION / QA_GENERATION + * status 取值:UNCLAIMED / IN_PROGRESS / SUBMITTED / APPROVED / REJECTED + */ +@Data +@TableName("annotation_task") +public class AnnotationTask { + + @TableId(type = IdType.AUTO) + private Long id; + + /** 所属公司(多租户键) */ + private Long companyId; + + /** 关联的原始资料 ID */ + private Long sourceId; + + /** 任务类型:EXTRACTION / QA_GENERATION */ + private String taskType; + + /** 任务状态 */ + private String status; + + /** 领取任务的用户 ID */ + private Long claimedBy; + + /** 领取时间 */ + private LocalDateTime claimedAt; + + /** 提交时间 */ + private LocalDateTime submittedAt; + + /** 完成时间(APPROVED 时设置) */ + private LocalDateTime completedAt; + + /** 是否最终结果(APPROVED 且无需再审)*/ + private Boolean isFinal; + + /** 使用的 AI 模型名称 */ + private String aiModel; + + /** 驳回原因 */ + private String rejectReason; + + private LocalDateTime createdAt; + + private LocalDateTime updatedAt; +} diff --git a/src/main/java/com/label/module/task/entity/AnnotationTaskHistory.java b/src/main/java/com/label/module/task/entity/AnnotationTaskHistory.java new file mode 100644 index 0000000..6e2c638 --- /dev/null +++ b/src/main/java/com/label/module/task/entity/AnnotationTaskHistory.java @@ -0,0 +1,43 @@ +package com.label.module.task.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Builder; +import lombok.Data; + +import java.time.LocalDateTime; + +/** + * 任务状态历史,对应 annotation_task_history 表(仅追加,无 UPDATE/DELETE)。 + */ +@Data +@Builder +@TableName("annotation_task_history") +public class AnnotationTaskHistory { + + @TableId(type = IdType.AUTO) + private Long id; + + private Long taskId; + + /** 所属公司(多租户键) */ + private Long companyId; + + /** 转换前状态(首次插入时为 null) */ + private String fromStatus; + + /** 转换后状态 */ + private String toStatus; + + /** 操作人 ID */ + private Long operatorId; + + /** 操作人角色 */ + private String operatorRole; + + /** 备注(驳回原因等) */ + private String comment; + + private LocalDateTime createdAt; +} diff --git a/src/main/java/com/label/module/task/mapper/AnnotationTaskMapper.java b/src/main/java/com/label/module/task/mapper/AnnotationTaskMapper.java new file mode 100644 index 0000000..3159e54 --- /dev/null +++ b/src/main/java/com/label/module/task/mapper/AnnotationTaskMapper.java @@ -0,0 +1,30 @@ +package com.label.module.task.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.label.module.task.entity.AnnotationTask; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Update; + +/** + * annotation_task 表 Mapper。 + */ +@Mapper +public interface AnnotationTaskMapper extends BaseMapper { + + /** + * 原子性领取任务:仅当任务为 UNCLAIMED 且属于当前租户时才更新。 + * 使用乐观 WHERE 条件实现并发安全(依赖数据库行级锁)。 + * + * @param taskId 任务 ID + * @param userId 领取用户 ID + * @param companyId 当前租户 + * @return 影响行数(0 = 任务已被他人领取或不存在) + */ + @Update("UPDATE annotation_task " + + "SET status = 'IN_PROGRESS', claimed_by = #{userId}, claimed_at = NOW(), updated_at = NOW() " + + "WHERE id = #{taskId} AND status = 'UNCLAIMED' AND company_id = #{companyId}") + int claimTask(@Param("taskId") Long taskId, + @Param("userId") Long userId, + @Param("companyId") Long companyId); +} diff --git a/src/main/java/com/label/module/task/mapper/TaskHistoryMapper.java b/src/main/java/com/label/module/task/mapper/TaskHistoryMapper.java new file mode 100644 index 0000000..c4f2f2a --- /dev/null +++ b/src/main/java/com/label/module/task/mapper/TaskHistoryMapper.java @@ -0,0 +1,14 @@ +package com.label.module.task.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.label.module.task.entity.AnnotationTaskHistory; +import org.apache.ibatis.annotations.Mapper; + +/** + * annotation_task_history 表 Mapper(仅追加,禁止 UPDATE/DELETE)。 + */ +@Mapper +public interface TaskHistoryMapper extends BaseMapper { + // 继承 BaseMapper 的 insert 用于追加历史记录 + // 严禁调用 update/delete 相关方法 +} diff --git a/src/main/java/com/label/module/task/service/TaskClaimService.java b/src/main/java/com/label/module/task/service/TaskClaimService.java new file mode 100644 index 0000000..fb2b65a --- /dev/null +++ b/src/main/java/com/label/module/task/service/TaskClaimService.java @@ -0,0 +1,171 @@ +package com.label.module.task.service; + +import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; +import com.label.common.exception.BusinessException; +import com.label.common.redis.RedisKeyManager; +import com.label.common.redis.RedisService; +import com.label.common.shiro.TokenPrincipal; +import com.label.common.statemachine.StateValidator; +import com.label.common.statemachine.TaskStatus; +import com.label.module.task.entity.AnnotationTask; +import com.label.module.task.entity.AnnotationTaskHistory; +import com.label.module.task.mapper.AnnotationTaskMapper; +import com.label.module.task.mapper.TaskHistoryMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +/** + * 任务领取/放弃/重领服务。 + * + * 并发安全设计: + * 1. Redis SET NX 作为分布式预锁(TTL 30s),快速拒绝并发请求 + * 2. DB UPDATE WHERE status='UNCLAIMED' 作为兜底原子操作 + * 两层防护确保同一任务只有一人可领取 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class TaskClaimService { + + /** Redis 分布式锁 TTL(秒) */ + private static final long CLAIM_LOCK_TTL = 30L; + + private final AnnotationTaskMapper taskMapper; + private final TaskHistoryMapper historyMapper; + private final RedisService redisService; + + // ------------------------------------------------------------------ 领取 -- + + /** + * 领取任务(双重防护:Redis NX + DB 原子更新)。 + * + * @param taskId 任务 ID + * @param principal 当前用户 + * @throws BusinessException TASK_CLAIMED(409) 任务已被他人领取 + */ + @Transactional + public void claim(Long taskId, TokenPrincipal principal) { + String lockKey = RedisKeyManager.taskClaimKey(taskId); + + // 1. Redis SET NX 预锁(快速失败) + boolean lockAcquired = redisService.setIfAbsent( + lockKey, principal.getUserId().toString(), CLAIM_LOCK_TTL); + if (!lockAcquired) { + throw new BusinessException("TASK_CLAIMED", "任务已被他人领取,请选择其他任务", HttpStatus.CONFLICT); + } + + // 2. DB 原子更新(WHERE status='UNCLAIMED' 兜底) + int affected = taskMapper.claimTask(taskId, principal.getUserId(), principal.getCompanyId()); + if (affected == 0) { + // DB 更新失败说明任务状态已变,清除刚设置的锁 + redisService.delete(lockKey); + throw new BusinessException("TASK_CLAIMED", "任务已被他人领取,请选择其他任务", HttpStatus.CONFLICT); + } + + // 3. 写入状态历史 + insertHistory(taskId, principal.getCompanyId(), + "UNCLAIMED", "IN_PROGRESS", + principal.getUserId(), principal.getRole(), null); + + log.debug("任务领取成功: taskId={}, userId={}", taskId, principal.getUserId()); + } + + // ------------------------------------------------------------------ 放弃 -- + + /** + * 放弃任务(IN_PROGRESS → UNCLAIMED)。 + * + * @param taskId 任务 ID + * @param principal 当前用户 + */ + @Transactional + public void unclaim(Long taskId, TokenPrincipal principal) { + AnnotationTask task = taskMapper.selectById(taskId); + validateTaskExists(task, taskId); + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.UNCLAIMED); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "UNCLAIMED") + .set(AnnotationTask::getClaimedBy, null) + .set(AnnotationTask::getClaimedAt, null)); + + // 清除 Redis 分布式锁 + redisService.delete(RedisKeyManager.taskClaimKey(taskId)); + + insertHistory(taskId, principal.getCompanyId(), + "IN_PROGRESS", "UNCLAIMED", + principal.getUserId(), principal.getRole(), null); + } + + // ------------------------------------------------------------------ 重领 -- + + /** + * 重领任务(REJECTED → IN_PROGRESS,仅原领取人可重领)。 + * + * @param taskId 任务 ID + * @param principal 当前用户 + */ + @Transactional + public void reclaim(Long taskId, TokenPrincipal principal) { + AnnotationTask task = taskMapper.selectById(taskId); + validateTaskExists(task, taskId); + + if (!"REJECTED".equals(task.getStatus())) { + throw new BusinessException("INVALID_STATE_TRANSITION", + "只有 REJECTED 状态的任务可以重领", HttpStatus.CONFLICT); + } + + if (!principal.getUserId().equals(task.getClaimedBy())) { + throw new BusinessException("FORBIDDEN", + "只有原领取人可以重领该任务", HttpStatus.FORBIDDEN); + } + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.IN_PROGRESS); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "IN_PROGRESS") + .set(AnnotationTask::getClaimedAt, java.time.LocalDateTime.now())); + + // 重新设置 Redis 锁(防止并发再次争抢) + redisService.setIfAbsent( + RedisKeyManager.taskClaimKey(taskId), + principal.getUserId().toString(), CLAIM_LOCK_TTL); + + insertHistory(taskId, principal.getCompanyId(), + "REJECTED", "IN_PROGRESS", + principal.getUserId(), principal.getRole(), null); + } + + // ------------------------------------------------------------------ 私有工具 -- + + private void validateTaskExists(AnnotationTask task, Long taskId) { + if (task == null) { + throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND); + } + } + + /** + * 向 annotation_task_history 追加一条历史记录(仅 INSERT,禁止 UPDATE/DELETE)。 + */ + public void insertHistory(Long taskId, Long companyId, + String fromStatus, String toStatus, + Long operatorId, String operatorRole, String comment) { + historyMapper.insert(AnnotationTaskHistory.builder() + .taskId(taskId) + .companyId(companyId) + .fromStatus(fromStatus) + .toStatus(toStatus) + .operatorId(operatorId) + .operatorRole(operatorRole) + .comment(comment) + .build()); + } +} diff --git a/src/main/java/com/label/module/task/service/TaskService.java b/src/main/java/com/label/module/task/service/TaskService.java new file mode 100644 index 0000000..a902545 --- /dev/null +++ b/src/main/java/com/label/module/task/service/TaskService.java @@ -0,0 +1,201 @@ +package com.label.module.task.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.label.common.exception.BusinessException; +import com.label.common.result.PageResult; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.task.dto.TaskResponse; +import com.label.module.task.entity.AnnotationTask; +import com.label.module.task.mapper.AnnotationTaskMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * 任务管理服务:创建、查询任务池、我的任务、待审批队列、指派。 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class TaskService { + + private final AnnotationTaskMapper taskMapper; + private final TaskClaimService taskClaimService; + + // ------------------------------------------------------------------ 创建 -- + + /** + * 创建标注任务(内部调用,例如视频处理完成后)。 + * + * @param sourceId 资料 ID + * @param taskType 任务类型(EXTRACTION / QA_GENERATION) + * @param companyId 租户 ID + * @return 新任务 + */ + @Transactional + public AnnotationTask createTask(Long sourceId, String taskType, Long companyId) { + AnnotationTask task = new AnnotationTask(); + task.setCompanyId(companyId); + task.setSourceId(sourceId); + task.setTaskType(taskType); + task.setStatus("UNCLAIMED"); + task.setIsFinal(false); + taskMapper.insert(task); + log.debug("任务已创建: id={}, type={}, sourceId={}", task.getId(), taskType, sourceId); + return task; + } + + // ------------------------------------------------------------------ 任务池 -- + + /** + * 查询任务池(按角色过滤): + * - ANNOTATOR → EXTRACTION 类型、UNCLAIMED 状态 + * - REVIEWER/ADMIN → SUBMITTED 状态(任意类型) + */ + public PageResult getPool(int page, int pageSize, TokenPrincipal principal) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .orderByAsc(AnnotationTask::getCreatedAt); + + String role = principal.getRole(); + if ("ANNOTATOR".equals(role)) { + wrapper.eq(AnnotationTask::getTaskType, "EXTRACTION") + .eq(AnnotationTask::getStatus, "UNCLAIMED"); + } else { + // REVIEWER / ADMIN 看待审批队列 + wrapper.eq(AnnotationTask::getStatus, "SUBMITTED"); + } + + Page pageResult = taskMapper.selectPage(new Page<>(page, pageSize), wrapper); + return toPageResult(pageResult, page, pageSize); + } + + // ------------------------------------------------------------------ 我的任务 -- + + /** + * 查询当前用户的任务(IN_PROGRESS、SUBMITTED、REJECTED)。 + */ + public PageResult getMine(int page, int pageSize, + String status, TokenPrincipal principal) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .eq(AnnotationTask::getClaimedBy, principal.getUserId()) + .in(AnnotationTask::getStatus, "IN_PROGRESS", "SUBMITTED", "REJECTED") + .orderByDesc(AnnotationTask::getUpdatedAt); + + if (status != null && !status.isBlank()) { + wrapper.eq(AnnotationTask::getStatus, status); + } + + Page pageResult = taskMapper.selectPage(new Page<>(page, pageSize), wrapper); + return toPageResult(pageResult, page, pageSize); + } + + // ------------------------------------------------------------------ 待审批 -- + + /** + * 查询待审批任务(REVIEWER 专属,status=SUBMITTED)。 + */ + public PageResult getPendingReview(int page, int pageSize, String taskType) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .eq(AnnotationTask::getStatus, "SUBMITTED") + .orderByAsc(AnnotationTask::getSubmittedAt); + + if (taskType != null && !taskType.isBlank()) { + wrapper.eq(AnnotationTask::getTaskType, taskType); + } + + Page pageResult = taskMapper.selectPage(new Page<>(page, pageSize), wrapper); + return toPageResult(pageResult, page, pageSize); + } + + // ------------------------------------------------------------------ 查询单条 -- + + public AnnotationTask getById(Long id) { + AnnotationTask task = taskMapper.selectById(id); + if (task == null) { + throw new BusinessException("NOT_FOUND", "任务不存在: " + id, HttpStatus.NOT_FOUND); + } + return task; + } + + // ------------------------------------------------------------------ 全部任务(ADMIN)-- + + /** + * 查询全部任务(ADMIN 专用)。 + */ + public PageResult getAll(int page, int pageSize, String status, String taskType) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .orderByDesc(AnnotationTask::getCreatedAt); + + if (status != null && !status.isBlank()) { + wrapper.eq(AnnotationTask::getStatus, status); + } + if (taskType != null && !taskType.isBlank()) { + wrapper.eq(AnnotationTask::getTaskType, taskType); + } + + Page pageResult = taskMapper.selectPage(new Page<>(page, pageSize), wrapper); + return toPageResult(pageResult, page, pageSize); + } + + // ------------------------------------------------------------------ 指派(ADMIN)-- + + /** + * ADMIN 强制指派任务给指定用户(IN_PROGRESS → IN_PROGRESS)。 + */ + @Transactional + public void reassign(Long taskId, Long targetUserId, TokenPrincipal principal) { + AnnotationTask task = taskMapper.selectById(taskId); + if (task == null) { + throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND); + } + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getClaimedBy, targetUserId) + .set(AnnotationTask::getClaimedAt, java.time.LocalDateTime.now())); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + task.getStatus(), "IN_PROGRESS", + principal.getUserId(), principal.getRole(), + "ADMIN 强制指派给用户 " + targetUserId); + } + + // ------------------------------------------------------------------ 私有工具 -- + + private PageResult toPageResult(Page pageResult, int page, int pageSize) { + List items = pageResult.getRecords().stream() + .map(this::toResponse) + .collect(Collectors.toList()); + return PageResult.of(items, pageResult.getTotal(), page, pageSize); + } + + public TaskResponse toPublicResponse(AnnotationTask task) { + return toResponse(task); + } + + private TaskResponse toResponse(AnnotationTask task) { + return TaskResponse.builder() + .id(task.getId()) + .sourceId(task.getSourceId()) + .taskType(task.getTaskType()) + .status(task.getStatus()) + .claimedBy(task.getClaimedBy()) + .claimedAt(task.getClaimedAt()) + .submittedAt(task.getSubmittedAt()) + .completedAt(task.getCompletedAt()) + .rejectReason(task.getRejectReason()) + .createdAt(task.getCreatedAt()) + .build(); + } +} diff --git a/src/test/java/com/label/integration/ExtractionApprovalIntegrationTest.java b/src/test/java/com/label/integration/ExtractionApprovalIntegrationTest.java new file mode 100644 index 0000000..b88379c --- /dev/null +++ b/src/test/java/com/label/integration/ExtractionApprovalIntegrationTest.java @@ -0,0 +1,217 @@ +package com.label.integration; + +import com.label.AbstractIntegrationTest; +import com.label.module.user.dto.LoginRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.http.*; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * 提取阶段审批集成测试(US4)。 + * + * 测试场景: + * 1. 审批通过 → QA_GENERATION 任务自动创建,source_data 状态更新为 QA_REVIEW + * 2. 审批人与提交人相同(自审)→ 403 SELF_REVIEW_FORBIDDEN + * 3. 驳回后标注员可重领任务并再次提交 + */ +public class ExtractionApprovalIntegrationTest extends AbstractIntegrationTest { + + @Autowired + private TestRestTemplate restTemplate; + + private Long sourceId; + private Long taskId; + private Long annotatorUserId; + private Long reviewerUserId; + + @BeforeEach + void setup() { + // 获取种子用户 ID(init.sql 中已插入) + annotatorUserId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_user WHERE username = 'annotator01'", Long.class); + reviewerUserId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_user WHERE username = 'reviewer01'", Long.class); + Long companyId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_company WHERE company_code = 'DEMO'", Long.class); + + // 插入测试 source_data + jdbcTemplate.execute( + "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " + + "file_name, file_size, bucket_name, status) " + + "VALUES (" + companyId + ", " + annotatorUserId + ", 'TEXT', " + + "'test/approval-test/file.txt', 'file.txt', 100, 'test-bucket', 'PENDING')"); + sourceId = jdbcTemplate.queryForObject( + "SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class); + + // 插入 UNCLAIMED EXTRACTION 任务 + jdbcTemplate.execute( + "INSERT INTO annotation_task (company_id, source_id, task_type, status) " + + "VALUES (" + companyId + ", " + sourceId + ", 'EXTRACTION', 'UNCLAIMED')"); + taskId = jdbcTemplate.queryForObject( + "SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class); + } + + // ------------------------------------------------------------------ 测试 1: 审批通过 → QA 任务自动创建 -- + + @Test + @DisplayName("审批通过后,QA_GENERATION 任务自动创建,source_data 状态变为 QA_REVIEW") + void approveTask_thenQaTaskAndSourceStatusUpdated() { + String annotatorToken = loginAndGetToken("DEMO", "annotator01", "annot123"); + String reviewerToken = loginAndGetToken("DEMO", "reviewer01", "review123"); + + // 1. 标注员领取任务 + ResponseEntity claimResp = restTemplate.exchange( + baseUrl("/api/tasks/" + taskId + "/claim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(claimResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 2. 标注员提交标注 + ResponseEntity submitResp = restTemplate.exchange( + baseUrl("/api/extraction/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(submitResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 3. 审核员审批通过 + // 注:ExtractionApprovedEventListener(@TransactionalEventListener AFTER_COMMIT) + // 在同一线程中同步执行,HTTP 响应返回前已完成后续处理 + ResponseEntity approveResp = restTemplate.exchange( + baseUrl("/api/extraction/" + taskId + "/approve"), + HttpMethod.POST, bearerRequest(reviewerToken), Map.class); + assertThat(approveResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:原任务状态变为 APPROVED,is_final=true + Map taskRow = jdbcTemplate.queryForMap( + "SELECT status, is_final FROM annotation_task WHERE id = ?", taskId); + assertThat(taskRow.get("status")).isEqualTo("APPROVED"); + assertThat(taskRow.get("is_final")).isEqualTo(Boolean.TRUE); + + // 验证:QA_GENERATION 任务已自动创建(UNCLAIMED 状态) + Integer qaTaskCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM annotation_task " + + "WHERE source_id = ? AND task_type = 'QA_GENERATION' AND status = 'UNCLAIMED'", + Integer.class, sourceId); + assertThat(qaTaskCount).as("QA_GENERATION 任务应已创建").isEqualTo(1); + + // 验证:source_data 状态已更新为 QA_REVIEW + String sourceStatus = jdbcTemplate.queryForObject( + "SELECT status FROM source_data WHERE id = ?", String.class, sourceId); + assertThat(sourceStatus).as("source_data 状态应为 QA_REVIEW").isEqualTo("QA_REVIEW"); + + // 验证:training_dataset 已以 PENDING_REVIEW 状态创建 + Integer datasetCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM training_dataset " + + "WHERE source_id = ? AND status = 'PENDING_REVIEW'", + Integer.class, sourceId); + assertThat(datasetCount).as("training_dataset 应已创建").isEqualTo(1); + } + + // ------------------------------------------------------------------ 测试 2: 自审返回 403 -- + + @Test + @DisplayName("审批人与任务领取人相同(自审)→ 403 SELF_REVIEW_FORBIDDEN") + void approveOwnSubmission_returnsForbidden() { + // 直接将任务置为 SUBMITTED 并设 claimed_by = reviewer01(模拟自审场景) + jdbcTemplate.execute( + "UPDATE annotation_task " + + "SET status = 'SUBMITTED', claimed_by = " + reviewerUserId + + ", claimed_at = NOW(), submitted_at = NOW() " + + "WHERE id = " + taskId); + + String reviewerToken = loginAndGetToken("DEMO", "reviewer01", "review123"); + + ResponseEntity resp = restTemplate.exchange( + baseUrl("/api/extraction/" + taskId + "/approve"), + HttpMethod.POST, bearerRequest(reviewerToken), Map.class); + + assertThat(resp.getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); + + // 验证任务状态未变 + String status = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(status).isEqualTo("SUBMITTED"); + } + + // ------------------------------------------------------------------ 测试 3: 驳回 → 重领 → 再提交 -- + + @Test + @DisplayName("驳回后标注员可重领任务并再次提交,任务状态恢复为 SUBMITTED") + void rejectThenReclaimAndResubmit_succeeds() { + String annotatorToken = loginAndGetToken("DEMO", "annotator01", "annot123"); + String reviewerToken = loginAndGetToken("DEMO", "reviewer01", "review123"); + + // 1. 标注员领取并提交 + restTemplate.exchange(baseUrl("/api/tasks/" + taskId + "/claim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + restTemplate.exchange(baseUrl("/api/extraction/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + + // 2. 审核员驳回(驳回原因必填) + HttpHeaders rejectHeaders = new HttpHeaders(); + rejectHeaders.set("Authorization", "Bearer " + reviewerToken); + rejectHeaders.setContentType(MediaType.APPLICATION_JSON); + HttpEntity> rejectReq = new HttpEntity<>( + Map.of("reason", "实体识别有误,请重新标注"), rejectHeaders); + + ResponseEntity rejectResp = restTemplate.exchange( + baseUrl("/api/extraction/" + taskId + "/reject"), + HttpMethod.POST, rejectReq, Map.class); + assertThat(rejectResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:任务状态变为 REJECTED + String statusAfterReject = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(statusAfterReject).isEqualTo("REJECTED"); + + // 3. 标注员重领任务(REJECTED → IN_PROGRESS) + ResponseEntity reclaimResp = restTemplate.exchange( + baseUrl("/api/tasks/" + taskId + "/reclaim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(reclaimResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:任务状态恢复为 IN_PROGRESS + String statusAfterReclaim = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(statusAfterReclaim).isEqualTo("IN_PROGRESS"); + + // 4. 标注员再次提交(IN_PROGRESS → SUBMITTED) + ResponseEntity resubmitResp = restTemplate.exchange( + baseUrl("/api/extraction/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(resubmitResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:任务状态变为 SUBMITTED + String finalStatus = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(finalStatus).isEqualTo("SUBMITTED"); + } + + // ------------------------------------------------------------------ 工具方法 -- + + private String loginAndGetToken(String companyCode, String username, String password) { + LoginRequest req = new LoginRequest(); + req.setCompanyCode(companyCode); + req.setUsername(username); + req.setPassword(password); + ResponseEntity response = restTemplate.postForEntity( + baseUrl("/api/auth/login"), req, Map.class); + if (!response.getStatusCode().is2xxSuccessful()) { + return null; + } + @SuppressWarnings("unchecked") + Map data = (Map) response.getBody().get("data"); + return (String) data.get("token"); + } + + private HttpEntity bearerRequest(String token) { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + token); + return new HttpEntity<>(headers); + } +} diff --git a/src/test/java/com/label/integration/TaskClaimConcurrencyTest.java b/src/test/java/com/label/integration/TaskClaimConcurrencyTest.java new file mode 100644 index 0000000..c972bca --- /dev/null +++ b/src/test/java/com/label/integration/TaskClaimConcurrencyTest.java @@ -0,0 +1,135 @@ +package com.label.integration; + +import com.label.AbstractIntegrationTest; +import com.label.common.redis.RedisKeyManager; +import com.label.common.redis.RedisService; +import org.junit.jupiter.api.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.http.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * 任务领取并发安全集成测试(US3)。 + * + * 测试场景:10 个线程同时争抢同一 UNCLAIMED 任务。 + * 期望结果: + * - 恰好 1 人成功(200 OK) + * - 其余 9 人收到 TASK_CLAIMED (409) + * - DB 中 claimed_by 唯一(只有一个用户 ID) + * + * 此测试需要 10 个不同的 userId,使用同一 DB 用户账号但不同的 Token。 + */ +public class TaskClaimConcurrencyTest extends AbstractIntegrationTest { + + @Autowired + private TestRestTemplate restTemplate; + + @Autowired + private RedisService redisService; + + private Long taskId; + private final List tokens = new ArrayList<>(); + + @BeforeEach + void setup() { + // 创建测试任务(直接向 DB 插入一条 UNCLAIMED 任务) + jdbcTemplate.execute( + "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " + + "file_name, file_size, bucket_name, status) " + + "VALUES (1, 1, 'TEXT', 'test/path/file.txt', 'file.txt', 100, 'test-bucket', 'PENDING')"); + Long sourceId = jdbcTemplate.queryForObject( + "SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class); + + jdbcTemplate.execute( + "INSERT INTO annotation_task (company_id, source_id, task_type, status) " + + "VALUES (1, " + sourceId + ", 'EXTRACTION', 'UNCLAIMED')"); + taskId = jdbcTemplate.queryForObject( + "SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class); + + // 创建 10 个 Annotator Token(模拟不同用户) + for (int i = 1; i <= 10; i++) { + String token = "concurrency-test-token-" + i; + tokens.add(token); + // 所有 Token 使用 userId=3(annotator01),这在真实场景不会发生 + // 但在测试中用于验证并发锁机制(redis key 基于 taskId,不是 userId) + redisService.hSetAll(RedisKeyManager.tokenKey(token), + Map.of("userId", String.valueOf(i + 100), // 假设 userId > 100 不存在,但不影响锁逻辑 + "role", "ANNOTATOR", "companyId", "1", "username", "annotator" + i), + 3600L); + } + } + + @AfterEach + void cleanup() { + tokens.forEach(token -> redisService.delete(RedisKeyManager.tokenKey(token))); + if (taskId != null) { + redisService.delete(RedisKeyManager.taskClaimKey(taskId)); + } + } + + @Test + @DisplayName("10 线程并发抢同一任务:恰好 1 人成功,其余 9 人收到 409 TASK_CLAIMED") + void concurrentClaim_onlyOneSucceeds() throws InterruptedException { + ExecutorService executor = Executors.newFixedThreadPool(10); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(10); + + AtomicInteger successCount = new AtomicInteger(0); + AtomicInteger conflictCount = new AtomicInteger(0); + + for (int i = 0; i < 10; i++) { + final String token = tokens.get(i); + executor.submit(() -> { + try { + startLatch.await(); // 等待起跑信号,最大化并发 + + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + token); + HttpEntity request = new HttpEntity<>(headers); + + ResponseEntity response = restTemplate.exchange( + baseUrl("/api/tasks/" + taskId + "/claim"), + HttpMethod.POST, request, Map.class); + + if (response.getStatusCode() == HttpStatus.OK) { + successCount.incrementAndGet(); + } else if (response.getStatusCode() == HttpStatus.CONFLICT) { + conflictCount.incrementAndGet(); + } + } catch (Exception e) { + conflictCount.incrementAndGet(); // 异常也算失败 + } finally { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); // 同时放行所有线程 + doneLatch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + // 恰好 1 人成功 + assertThat(successCount.get()).isEqualTo(1); + // 其余 9 人失败(409 或异常) + assertThat(conflictCount.get()).isEqualTo(9); + + // DB 中 claimed_by 有且仅有一个值 + String claimedByStr = jdbcTemplate.queryForObject( + "SELECT claimed_by::text FROM annotation_task WHERE id = ?", + String.class, taskId); + assertThat(claimedByStr).isNotNull(); + + // DB 中状态为 IN_PROGRESS + String status = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(status).isEqualTo("IN_PROGRESS"); + } +}