Files
label_backend/src/main/java/com/label/service/QaService.java
2026-04-14 13:45:15 +08:00

253 lines
10 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.label.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.label.common.exception.BusinessException;
import com.label.common.shiro.TokenPrincipal;
import com.label.common.statemachine.StateValidator;
import com.label.common.statemachine.TaskStatus;
import com.label.entity.AnnotationTask;
import com.label.entity.SourceData;
import com.label.entity.TrainingDataset;
import com.label.mapper.AnnotationTaskMapper;
import com.label.mapper.SourceDataMapper;
import com.label.mapper.TrainingDatasetMapper;
import com.label.service.TaskClaimService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* 问答生成阶段标注服务:查询候选问答对、更新、提交、审批、驳回。
*
* 关键设计:
* - QA 阶段无 AI 调用(候选问答对已由 ExtractionApprovedEventListener 生成)
* - approve() 同一事务内完成training_dataset → APPROVED、task → APPROVED、source_data → APPROVED
* - reject() 清除候选问答对deleteByTaskIdsource_data 保持 QA_REVIEW 状态
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class QaService {
private final AnnotationTaskMapper taskMapper;
private final TrainingDatasetMapper datasetMapper;
private final SourceDataMapper sourceDataMapper;
private final TaskClaimService taskClaimService;
private final ObjectMapper objectMapper;
// ------------------------------------------------------------------ 查询 --
/**
* 获取候选问答对(从 training_dataset.glm_format_json 解析)。
*/
public Map<String, Object> getResult(Long taskId, TokenPrincipal principal) {
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
TrainingDataset dataset = getDataset(taskId);
SourceData source = sourceDataMapper.selectById(task.getSourceId());
String sourceType = source != null ? source.getDataType() : "TEXT";
List<?> items = Collections.emptyList();
if (dataset != null && dataset.getGlmFormatJson() != null) {
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = objectMapper.readValue(dataset.getGlmFormatJson(), Map.class);
Object conversations = parsed.get("conversations");
if (conversations instanceof List) {
items = (List<?>) conversations;
}
} catch (Exception e) {
log.warn("解析 QA JSON 失败taskId={}{}", taskId, e.getMessage());
}
}
return Map.of(
"taskId", taskId,
"sourceType", sourceType,
"items", items
);
}
// ------------------------------------------------------------------ 更新 --
/**
* 整体覆盖问答对PUT 语义)。
*
* @param taskId 任务 ID
* @param body 包含 items 数组的 JSON格式{"items": [...]}
* @param principal 当前用户
*/
@Transactional
public void updateResult(Long taskId, String body, TokenPrincipal principal) {
validateAndGetTask(taskId, principal.getCompanyId());
// 校验 JSON 格式
try {
objectMapper.readTree(body);
} catch (Exception e) {
throw new BusinessException("INVALID_JSON", "请求体 JSON 格式不合法", HttpStatus.BAD_REQUEST);
}
// 将 items 格式包装为 GLM 格式:{"conversations": items}
String glmJson;
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = objectMapper.readValue(body, Map.class);
Object items = parsed.getOrDefault("items", Collections.emptyList());
glmJson = objectMapper.writeValueAsString(Map.of("conversations", items));
} catch (Exception e) {
glmJson = "{\"conversations\":[]}";
}
TrainingDataset dataset = getDataset(taskId);
if (dataset != null) {
datasetMapper.update(null, new LambdaUpdateWrapper<TrainingDataset>()
.eq(TrainingDataset::getTaskId, taskId)
.set(TrainingDataset::getGlmFormatJson, glmJson)
.set(TrainingDataset::getUpdatedAt, LocalDateTime.now()));
} else {
// 若 training_dataset 不存在(异常情况),自动创建
TrainingDataset newDataset = new TrainingDataset();
newDataset.setCompanyId(principal.getCompanyId());
newDataset.setTaskId(taskId);
AnnotationTask task = taskMapper.selectById(taskId);
newDataset.setSourceId(task.getSourceId());
newDataset.setSampleType("TEXT");
newDataset.setGlmFormatJson(glmJson);
newDataset.setStatus("PENDING_REVIEW");
datasetMapper.insert(newDataset);
}
}
// ------------------------------------------------------------------ 提交 --
/**
* 提交 QA 结果IN_PROGRESS → SUBMITTED
*/
@Transactional
public void submit(Long taskId, TokenPrincipal principal) {
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
TaskStatus.valueOf(task.getStatus()), TaskStatus.SUBMITTED);
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
.eq(AnnotationTask::getId, taskId)
.set(AnnotationTask::getStatus, "SUBMITTED")
.set(AnnotationTask::getSubmittedAt, LocalDateTime.now()));
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
task.getStatus(), "SUBMITTED",
principal.getUserId(), principal.getRole(), null);
}
// ------------------------------------------------------------------ 审批通过 --
/**
* 审批通过SUBMITTED → APPROVED
*
* 同一事务:
* 1. 校验任务(先于一切 DB 写入)
* 2. 自审校验
* 3. StateValidator
* 4. training_dataset → APPROVED
* 5. annotation_task → APPROVED + is_final=true + completedAt
* 6. source_data → APPROVED整条流水线完成
* 7. 写任务历史
*/
@Transactional
public void approve(Long taskId, TokenPrincipal principal) {
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
// 自审校验
if (principal.getUserId().equals(task.getClaimedBy())) {
throw new BusinessException("SELF_REVIEW_FORBIDDEN",
"不允许审批自己提交的任务", HttpStatus.FORBIDDEN);
}
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
TaskStatus.valueOf(task.getStatus()), TaskStatus.APPROVED);
// training_dataset → APPROVED
datasetMapper.approveByTaskId(taskId, principal.getCompanyId());
// annotation_task → APPROVED + is_final=true
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
.eq(AnnotationTask::getId, taskId)
.set(AnnotationTask::getStatus, "APPROVED")
.set(AnnotationTask::getIsFinal, true)
.set(AnnotationTask::getCompletedAt, LocalDateTime.now()));
// source_data → APPROVED整条流水线终态
sourceDataMapper.updateStatus(task.getSourceId(), "APPROVED", principal.getCompanyId());
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
"SUBMITTED", "APPROVED",
principal.getUserId(), principal.getRole(), null);
log.info("QA 审批通过,整条流水线完成: taskId={}, sourceId={}", taskId, task.getSourceId());
}
// ------------------------------------------------------------------ 驳回 --
/**
* 驳回 QA 结果SUBMITTED → REJECTED
*
* 清除候选问答对deleteByTaskIdsource_data 保持 QA_REVIEW 状态不变。
*/
@Transactional
public void reject(Long taskId, String reason, TokenPrincipal principal) {
if (reason == null || reason.isBlank()) {
throw new BusinessException("REASON_REQUIRED", "驳回原因不能为空", HttpStatus.BAD_REQUEST);
}
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
// 自审校验
if (principal.getUserId().equals(task.getClaimedBy())) {
throw new BusinessException("SELF_REVIEW_FORBIDDEN",
"不允许驳回自己提交的任务", HttpStatus.FORBIDDEN);
}
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
TaskStatus.valueOf(task.getStatus()), TaskStatus.REJECTED);
// 清除候选问答对
datasetMapper.deleteByTaskId(taskId, principal.getCompanyId());
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
.eq(AnnotationTask::getId, taskId)
.set(AnnotationTask::getStatus, "REJECTED")
.set(AnnotationTask::getRejectReason, reason));
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
"SUBMITTED", "REJECTED",
principal.getUserId(), principal.getRole(), reason);
}
// ------------------------------------------------------------------ 私有工具 --
private AnnotationTask validateAndGetTask(Long taskId, Long companyId) {
AnnotationTask task = taskMapper.selectById(taskId);
if (task == null || !companyId.equals(task.getCompanyId())) {
throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND);
}
return task;
}
private TrainingDataset getDataset(Long taskId) {
return datasetMapper.selectOne(
new com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper<TrainingDataset>()
.eq(TrainingDataset::getTaskId, taskId)
.last("LIMIT 1"));
}
}