diff --git a/src/main/java/com/label/module/annotation/controller/QaController.java b/src/main/java/com/label/module/annotation/controller/QaController.java new file mode 100644 index 0000000..546a050 --- /dev/null +++ b/src/main/java/com/label/module/annotation/controller/QaController.java @@ -0,0 +1,73 @@ +package com.label.module.annotation.controller; + +import com.label.common.result.Result; +import com.label.common.shiro.TokenPrincipal; +import com.label.module.annotation.service.QaService; +import jakarta.servlet.http.HttpServletRequest; +import lombok.RequiredArgsConstructor; +import org.apache.shiro.authz.annotation.RequiresRoles; +import org.springframework.web.bind.annotation.*; + +import java.util.Map; + +/** + * 问答生成阶段标注工作台接口(5 个端点)。 + */ +@RestController +@RequestMapping("/api/qa") +@RequiredArgsConstructor +public class QaController { + + private final QaService qaService; + + /** GET /api/qa/{taskId} — 获取候选问答对 */ + @GetMapping("/{taskId}") + @RequiresRoles("ANNOTATOR") + public Result> getResult(@PathVariable Long taskId, + HttpServletRequest request) { + return Result.success(qaService.getResult(taskId, principal(request))); + } + + /** PUT /api/qa/{taskId} — 整体覆盖问答对 */ + @PutMapping("/{taskId}") + @RequiresRoles("ANNOTATOR") + public Result updateResult(@PathVariable Long taskId, + @RequestBody String body, + HttpServletRequest request) { + qaService.updateResult(taskId, body, principal(request)); + return Result.success(null); + } + + /** POST /api/qa/{taskId}/submit — 提交问答对 */ + @PostMapping("/{taskId}/submit") + @RequiresRoles("ANNOTATOR") + public Result submit(@PathVariable Long taskId, + HttpServletRequest request) { + qaService.submit(taskId, principal(request)); + return Result.success(null); + } + + /** POST /api/qa/{taskId}/approve — 审批通过(REVIEWER) */ + @PostMapping("/{taskId}/approve") + @RequiresRoles("REVIEWER") + public Result approve(@PathVariable Long taskId, + HttpServletRequest request) { + qaService.approve(taskId, principal(request)); + return Result.success(null); + } + + /** POST /api/qa/{taskId}/reject — 驳回(REVIEWER) */ + @PostMapping("/{taskId}/reject") + @RequiresRoles("REVIEWER") + public Result reject(@PathVariable Long taskId, + @RequestBody Map body, + HttpServletRequest request) { + String reason = body != null ? body.get("reason") : null; + qaService.reject(taskId, reason, principal(request)); + return Result.success(null); + } + + private TokenPrincipal principal(HttpServletRequest request) { + return (TokenPrincipal) request.getAttribute("__token_principal__"); + } +} diff --git a/src/main/java/com/label/module/annotation/service/QaService.java b/src/main/java/com/label/module/annotation/service/QaService.java new file mode 100644 index 0000000..efd2097 --- /dev/null +++ b/src/main/java/com/label/module/annotation/service/QaService.java @@ -0,0 +1,252 @@ +package com.label.module.annotation.service; + +import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.common.exception.BusinessException; +import com.label.common.shiro.TokenPrincipal; +import com.label.common.statemachine.StateValidator; +import com.label.common.statemachine.TaskStatus; +import com.label.module.annotation.entity.TrainingDataset; +import com.label.module.annotation.mapper.TrainingDatasetMapper; +import com.label.module.source.entity.SourceData; +import com.label.module.source.mapper.SourceDataMapper; +import com.label.module.task.entity.AnnotationTask; +import com.label.module.task.mapper.AnnotationTaskMapper; +import com.label.module.task.service.TaskClaimService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.time.LocalDateTime; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * 问答生成阶段标注服务:查询候选问答对、更新、提交、审批、驳回。 + * + * 关键设计: + * - QA 阶段无 AI 调用(候选问答对已由 ExtractionApprovedEventListener 生成) + * - approve() 同一事务内完成:training_dataset → APPROVED、task → APPROVED、source_data → APPROVED + * - reject() 清除候选问答对(deleteByTaskId),source_data 保持 QA_REVIEW 状态 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class QaService { + + private final AnnotationTaskMapper taskMapper; + private final TrainingDatasetMapper datasetMapper; + private final SourceDataMapper sourceDataMapper; + private final TaskClaimService taskClaimService; + private final ObjectMapper objectMapper; + + // ------------------------------------------------------------------ 查询 -- + + /** + * 获取候选问答对(从 training_dataset.glm_format_json 解析)。 + */ + public Map getResult(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + TrainingDataset dataset = getDataset(taskId); + + SourceData source = sourceDataMapper.selectById(task.getSourceId()); + String sourceType = source != null ? source.getDataType() : "TEXT"; + + List items = Collections.emptyList(); + if (dataset != null && dataset.getGlmFormatJson() != null) { + try { + @SuppressWarnings("unchecked") + Map parsed = objectMapper.readValue(dataset.getGlmFormatJson(), Map.class); + Object conversations = parsed.get("conversations"); + if (conversations instanceof List) { + items = (List) conversations; + } + } catch (Exception e) { + log.warn("解析 QA JSON 失败(taskId={}):{}", taskId, e.getMessage()); + } + } + + return Map.of( + "taskId", taskId, + "sourceType", sourceType, + "items", items + ); + } + + // ------------------------------------------------------------------ 更新 -- + + /** + * 整体覆盖问答对(PUT 语义)。 + * + * @param taskId 任务 ID + * @param body 包含 items 数组的 JSON,格式:{"items": [...]} + * @param principal 当前用户 + */ + @Transactional + public void updateResult(Long taskId, String body, TokenPrincipal principal) { + validateAndGetTask(taskId, principal.getCompanyId()); + + // 校验 JSON 格式 + try { + objectMapper.readTree(body); + } catch (Exception e) { + throw new BusinessException("INVALID_JSON", "请求体 JSON 格式不合法", HttpStatus.BAD_REQUEST); + } + + // 将 items 格式包装为 GLM 格式:{"conversations": items} + String glmJson; + try { + @SuppressWarnings("unchecked") + Map parsed = objectMapper.readValue(body, Map.class); + Object items = parsed.getOrDefault("items", Collections.emptyList()); + glmJson = objectMapper.writeValueAsString(Map.of("conversations", items)); + } catch (Exception e) { + glmJson = "{\"conversations\":[]}"; + } + + TrainingDataset dataset = getDataset(taskId); + if (dataset != null) { + datasetMapper.update(null, new LambdaUpdateWrapper() + .eq(TrainingDataset::getTaskId, taskId) + .set(TrainingDataset::getGlmFormatJson, glmJson) + .set(TrainingDataset::getUpdatedAt, LocalDateTime.now())); + } else { + // 若 training_dataset 不存在(异常情况),自动创建 + TrainingDataset newDataset = new TrainingDataset(); + newDataset.setCompanyId(principal.getCompanyId()); + newDataset.setTaskId(taskId); + AnnotationTask task = taskMapper.selectById(taskId); + newDataset.setSourceId(task.getSourceId()); + newDataset.setSampleType("TEXT"); + newDataset.setGlmFormatJson(glmJson); + newDataset.setStatus("PENDING_REVIEW"); + datasetMapper.insert(newDataset); + } + } + + // ------------------------------------------------------------------ 提交 -- + + /** + * 提交 QA 结果(IN_PROGRESS → SUBMITTED)。 + */ + @Transactional + public void submit(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.SUBMITTED); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "SUBMITTED") + .set(AnnotationTask::getSubmittedAt, LocalDateTime.now())); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + task.getStatus(), "SUBMITTED", + principal.getUserId(), principal.getRole(), null); + } + + // ------------------------------------------------------------------ 审批通过 -- + + /** + * 审批通过(SUBMITTED → APPROVED)。 + * + * 同一事务: + * 1. 校验任务(先于一切 DB 写入) + * 2. 自审校验 + * 3. StateValidator + * 4. training_dataset → APPROVED + * 5. annotation_task → APPROVED + is_final=true + completedAt + * 6. source_data → APPROVED(整条流水线完成) + * 7. 写任务历史 + */ + @Transactional + public void approve(Long taskId, TokenPrincipal principal) { + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + // 自审校验 + if (principal.getUserId().equals(task.getClaimedBy())) { + throw new BusinessException("SELF_REVIEW_FORBIDDEN", + "不允许审批自己提交的任务", HttpStatus.FORBIDDEN); + } + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.APPROVED); + + // training_dataset → APPROVED + datasetMapper.approveByTaskId(taskId, principal.getCompanyId()); + + // annotation_task → APPROVED + is_final=true + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "APPROVED") + .set(AnnotationTask::getIsFinal, true) + .set(AnnotationTask::getCompletedAt, LocalDateTime.now())); + + // source_data → APPROVED(整条流水线终态) + sourceDataMapper.updateStatus(task.getSourceId(), "APPROVED", principal.getCompanyId()); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + "SUBMITTED", "APPROVED", + principal.getUserId(), principal.getRole(), null); + + log.debug("QA 审批通过,整条流水线完成: taskId={}, sourceId={}", taskId, task.getSourceId()); + } + + // ------------------------------------------------------------------ 驳回 -- + + /** + * 驳回 QA 结果(SUBMITTED → REJECTED)。 + * + * 清除候选问答对(deleteByTaskId),source_data 保持 QA_REVIEW 状态不变。 + */ + @Transactional + public void reject(Long taskId, String reason, TokenPrincipal principal) { + if (reason == null || reason.isBlank()) { + throw new BusinessException("REASON_REQUIRED", "驳回原因不能为空", HttpStatus.BAD_REQUEST); + } + + AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId()); + + // 自审校验 + if (principal.getUserId().equals(task.getClaimedBy())) { + throw new BusinessException("SELF_REVIEW_FORBIDDEN", + "不允许驳回自己提交的任务", HttpStatus.FORBIDDEN); + } + + StateValidator.assertTransition(TaskStatus.TRANSITIONS, + TaskStatus.valueOf(task.getStatus()), TaskStatus.REJECTED); + + // 清除候选问答对 + datasetMapper.deleteByTaskId(taskId, principal.getCompanyId()); + + taskMapper.update(null, new LambdaUpdateWrapper() + .eq(AnnotationTask::getId, taskId) + .set(AnnotationTask::getStatus, "REJECTED") + .set(AnnotationTask::getRejectReason, reason)); + + taskClaimService.insertHistory(taskId, principal.getCompanyId(), + "SUBMITTED", "REJECTED", + principal.getUserId(), principal.getRole(), reason); + } + + // ------------------------------------------------------------------ 私有工具 -- + + private AnnotationTask validateAndGetTask(Long taskId, Long companyId) { + AnnotationTask task = taskMapper.selectById(taskId); + if (task == null) { + throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND); + } + return task; + } + + private TrainingDataset getDataset(Long taskId) { + return datasetMapper.selectOne( + new com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper() + .eq(TrainingDataset::getTaskId, taskId) + .last("LIMIT 1")); + } +} diff --git a/src/test/java/com/label/integration/QaApprovalIntegrationTest.java b/src/test/java/com/label/integration/QaApprovalIntegrationTest.java new file mode 100644 index 0000000..4c00970 --- /dev/null +++ b/src/test/java/com/label/integration/QaApprovalIntegrationTest.java @@ -0,0 +1,196 @@ +package com.label.integration; + +import com.label.AbstractIntegrationTest; +import com.label.module.user.dto.LoginRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.http.*; + +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * QA 问答生成阶段审批集成测试(US5)。 + * + * 测试场景: + * 1. QA 审批通过 → training_dataset.status = APPROVED,source_data.status = APPROVED + * 2. QA 驳回 → 候选问答对被删除,标注员可重领 + */ +public class QaApprovalIntegrationTest extends AbstractIntegrationTest { + + @Autowired + private TestRestTemplate restTemplate; + + private Long sourceId; + private Long taskId; + private Long datasetId; + private Long annotatorUserId; + private Long reviewerUserId; + + @BeforeEach + void setup() { + annotatorUserId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_user WHERE username = 'annotator01'", Long.class); + reviewerUserId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_user WHERE username = 'reviewer01'", Long.class); + Long companyId = jdbcTemplate.queryForObject( + "SELECT id FROM sys_company WHERE company_code = 'DEMO'", Long.class); + + // 插入 source_data(QA_REVIEW 状态,模拟提取审批已完成) + jdbcTemplate.execute( + "INSERT INTO source_data (company_id, uploader_id, data_type, file_path, " + + "file_name, file_size, bucket_name, status) " + + "VALUES (" + companyId + ", " + annotatorUserId + ", 'TEXT', " + + "'test/qa-test/file.txt', 'file.txt', 100, 'test-bucket', 'QA_REVIEW')"); + sourceId = jdbcTemplate.queryForObject( + "SELECT id FROM source_data ORDER BY id DESC LIMIT 1", Long.class); + + // 插入 QA_GENERATION 任务(UNCLAIMED 状态,模拟提取审批通过后自动创建的 QA 任务) + jdbcTemplate.execute( + "INSERT INTO annotation_task (company_id, source_id, task_type, status) " + + "VALUES (" + companyId + ", " + sourceId + ", 'QA_GENERATION', 'UNCLAIMED')"); + taskId = jdbcTemplate.queryForObject( + "SELECT id FROM annotation_task ORDER BY id DESC LIMIT 1", Long.class); + + // 插入候选问答对(模拟 ExtractionApprovedEventListener 创建) + jdbcTemplate.execute( + "INSERT INTO training_dataset (company_id, task_id, source_id, sample_type, " + + "glm_format_json, status) VALUES (" + companyId + ", " + taskId + ", " + sourceId + + ", 'TEXT', '{\"conversations\":[{\"question\":\"北京是哪个国家的首都?\",\"answer\":\"中国\"}]}'::jsonb, " + + "'PENDING_REVIEW')"); + datasetId = jdbcTemplate.queryForObject( + "SELECT id FROM training_dataset ORDER BY id DESC LIMIT 1", Long.class); + } + + // ------------------------------------------------------------------ 测试 1: 审批通过 → 终态 -- + + @Test + @DisplayName("QA 审批通过 → training_dataset.status=APPROVED,source_data.status=APPROVED") + void approveQaTask_thenDatasetAndSourceApproved() { + String annotatorToken = loginAndGetToken("DEMO", "annotator01", "annot123"); + String reviewerToken = loginAndGetToken("DEMO", "reviewer01", "review123"); + + // 注意:QA 任务 claim 端点为 POST /api/tasks/{id}/claim(ANNOTATOR 角色) + // 但 TaskController.getPool 只给 ANNOTATOR 显示 EXTRACTION/UNCLAIMED + // QA 任务由 ANNOTATOR 直接领取(不经过任务池) + ResponseEntity claimResp = restTemplate.exchange( + baseUrl("/api/tasks/" + taskId + "/claim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(claimResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 提交 QA 结果 + ResponseEntity submitResp = restTemplate.exchange( + baseUrl("/api/qa/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(submitResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 审批通过 + ResponseEntity approveResp = restTemplate.exchange( + baseUrl("/api/qa/" + taskId + "/approve"), + HttpMethod.POST, bearerRequest(reviewerToken), Map.class); + assertThat(approveResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:training_dataset → APPROVED + String datasetStatus = jdbcTemplate.queryForObject( + "SELECT status FROM training_dataset WHERE id = ?", String.class, datasetId); + assertThat(datasetStatus).as("training_dataset 状态应为 APPROVED").isEqualTo("APPROVED"); + + // 验证:annotation_task → APPROVED,is_final=true + Map taskRow = jdbcTemplate.queryForMap( + "SELECT status, is_final FROM annotation_task WHERE id = ?", taskId); + assertThat(taskRow.get("status")).isEqualTo("APPROVED"); + assertThat(taskRow.get("is_final")).isEqualTo(Boolean.TRUE); + + // 验证:source_data → APPROVED(整条流水线完成) + String sourceStatus = jdbcTemplate.queryForObject( + "SELECT status FROM source_data WHERE id = ?", String.class, sourceId); + assertThat(sourceStatus).as("source_data 状态应为 APPROVED(流水线终态)").isEqualTo("APPROVED"); + } + + // ------------------------------------------------------------------ 测试 2: 驳回 → 候选记录删除 → 可重领 -- + + @Test + @DisplayName("QA 驳回 → 候选问答对被删除,标注员可重领并再次提交") + void rejectQaTask_thenDatasetDeletedAndReclaimable() { + String annotatorToken = loginAndGetToken("DEMO", "annotator01", "annot123"); + String reviewerToken = loginAndGetToken("DEMO", "reviewer01", "review123"); + + // 领取并提交 + restTemplate.exchange(baseUrl("/api/tasks/" + taskId + "/claim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + restTemplate.exchange(baseUrl("/api/qa/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + + // 驳回(驳回原因必填) + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + reviewerToken); + headers.setContentType(MediaType.APPLICATION_JSON); + HttpEntity> rejectReq = new HttpEntity<>( + Map.of("reason", "问题描述不准确,请修改"), headers); + + ResponseEntity rejectResp = restTemplate.exchange( + baseUrl("/api/qa/" + taskId + "/reject"), + HttpMethod.POST, rejectReq, Map.class); + assertThat(rejectResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:任务状态变为 REJECTED + String statusAfterReject = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(statusAfterReject).isEqualTo("REJECTED"); + + // 验证:候选问答对已被删除 + Integer datasetCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM training_dataset WHERE task_id = ?", + Integer.class, taskId); + assertThat(datasetCount).as("驳回后候选问答对应被删除").isEqualTo(0); + + // 验证:source_data 保持 QA_REVIEW(不变) + String sourceStatus = jdbcTemplate.queryForObject( + "SELECT status FROM source_data WHERE id = ?", String.class, sourceId); + assertThat(sourceStatus).as("驳回后 source_data 应保持 QA_REVIEW").isEqualTo("QA_REVIEW"); + + // 标注员重领任务 + ResponseEntity reclaimResp = restTemplate.exchange( + baseUrl("/api/tasks/" + taskId + "/reclaim"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(reclaimResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 再次提交 + ResponseEntity resubmitResp = restTemplate.exchange( + baseUrl("/api/qa/" + taskId + "/submit"), + HttpMethod.POST, bearerRequest(annotatorToken), Map.class); + assertThat(resubmitResp.getStatusCode()).isEqualTo(HttpStatus.OK); + + // 验证:任务状态变为 SUBMITTED + String finalStatus = jdbcTemplate.queryForObject( + "SELECT status FROM annotation_task WHERE id = ?", String.class, taskId); + assertThat(finalStatus).isEqualTo("SUBMITTED"); + } + + // ------------------------------------------------------------------ 工具方法 -- + + private String loginAndGetToken(String companyCode, String username, String password) { + LoginRequest req = new LoginRequest(); + req.setCompanyCode(companyCode); + req.setUsername(username); + req.setPassword(password); + ResponseEntity response = restTemplate.postForEntity( + baseUrl("/api/auth/login"), req, Map.class); + if (!response.getStatusCode().is2xxSuccessful()) { + return null; + } + @SuppressWarnings("unchecked") + Map data = (Map) response.getBody().get("data"); + return (String) data.get("token"); + } + + private HttpEntity bearerRequest(String token) { + HttpHeaders headers = new HttpHeaders(); + headers.set("Authorization", "Bearer " + token); + return new HttpEntity<>(headers); + } +}