// Source-viewer metadata: 252 lines, 10 KiB, Java.
package com.label.service;

import java.time.LocalDateTime;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.label.common.auth.TokenPrincipal;
import com.label.common.exception.BusinessException;
import com.label.common.statemachine.StateValidator;
import com.label.common.statemachine.TaskStatus;
import com.label.entity.AnnotationResult;
import com.label.entity.AnnotationTask;
import com.label.entity.SourceData;
import com.label.event.ExtractionApprovedEvent;
import com.label.mapper.AnnotationResultMapper;
import com.label.mapper.AnnotationTaskMapper;
import com.label.mapper.SourceDataMapper;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

/**
|
||
* 提取阶段标注服务:AI 预标注、更新结果、提交、审批、驳回。
|
||
*
|
||
* 关键设计:
|
||
* - approve() 内禁止直接调用 AI,通过 ExtractionApprovedEvent 解耦(AFTER_COMMIT)
|
||
* - 所有写操作包裹在 @Transactional 中,确保任务状态和历史的一致性
|
||
*/
|
||
@Slf4j
|
||
@Service
|
||
@RequiredArgsConstructor
|
||
public class ExtractionService {
|
||
|
||
private final AnnotationTaskMapper taskMapper;
|
||
private final AnnotationResultMapper resultMapper;
|
||
// private final TrainingDatasetMapper datasetMapper;
|
||
private final SourceDataMapper sourceDataMapper;
|
||
private final TaskClaimService taskClaimService;
|
||
// private final AiServiceClient aiServiceClient;
|
||
private final ApplicationEventPublisher eventPublisher;
|
||
private final ObjectMapper objectMapper;
|
||
private final AiAnnotationAsyncService aiAnnotationAsyncService; // 注入异步服务
|
||
|
||
@Value("${rustfs.bucket:label-source-data}")
|
||
private String bucket;
|
||
|
||
// ------------------------------------------------------------------ AI 预标注 --
|
||
|
||
/**
|
||
* AI 辅助预标注:调用 AI 服务,将结果写入 annotation_result。
|
||
* 注:此方法在 @Transactional 外调用(AI 调用不应在事务内),由控制器直接调用。
|
||
*/
|
||
public void aiPreAnnotate(Long taskId, TokenPrincipal principal) {
|
||
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
|
||
|
||
SourceData source = sourceDataMapper.selectById(task.getSourceId());
|
||
if (source == null) {
|
||
throw new BusinessException("NOT_FOUND", "关联资料不存在", HttpStatus.NOT_FOUND);
|
||
}
|
||
|
||
if (source.getFilePath() == null || source.getFilePath().isEmpty()) {
|
||
throw new BusinessException("INVALID_SOURCE", "源文件路径不能为空", HttpStatus.BAD_REQUEST);
|
||
}
|
||
|
||
if (source.getDataType() == null || source.getDataType().isEmpty()) {
|
||
throw new BusinessException("INVALID_SOURCE", "数据类型不能为空", HttpStatus.BAD_REQUEST);
|
||
}
|
||
|
||
String dataType = source.getDataType().toUpperCase();
|
||
if (!"IMAGE".equals(dataType) && !"TEXT".equals(dataType)) {
|
||
log.warn("不支持的数据类型: {}, 任务ID: {}", dataType, taskId);
|
||
throw new BusinessException("UNSUPPORTED_TYPE",
|
||
"不支持的数据类型: " + dataType, HttpStatus.BAD_REQUEST);
|
||
}
|
||
|
||
// 更新任务状态为 PROCESSING
|
||
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
|
||
.eq(AnnotationTask::getId, taskId)
|
||
.set(AnnotationTask::getAiStatus, "PROCESSING"));
|
||
|
||
// 触发异步任务
|
||
aiAnnotationAsyncService.processAnnotation(taskId, principal.getCompanyId(), source);
|
||
// executeAiAnnotationAsync(taskId, principal.getCompanyId(), source);
|
||
}
|
||
|
||
/**
|
||
* 人工更新标注结果(整体覆盖,PUT 语义)。
|
||
*
|
||
* @param taskId 任务 ID
|
||
* @param resultJson 新的标注结果 JSON 字符串
|
||
* @param principal 当前用户
|
||
*/
|
||
@Transactional
|
||
public void updateResult(Long taskId, String resultJson, TokenPrincipal principal) {
|
||
validateAndGetTask(taskId, principal.getCompanyId());
|
||
|
||
// 校验 JSON 格式
|
||
try {
|
||
objectMapper.readTree(resultJson);
|
||
} catch (Exception e) {
|
||
throw new BusinessException("INVALID_JSON", "标注结果 JSON 格式不合法", HttpStatus.BAD_REQUEST);
|
||
}
|
||
|
||
int updated = resultMapper.updateResultJson(taskId, resultJson, principal.getCompanyId());
|
||
if (updated == 0) {
|
||
// 不存在则新建
|
||
AnnotationResult result = new AnnotationResult();
|
||
result.setTaskId(taskId);
|
||
result.setCompanyId(principal.getCompanyId());
|
||
result.setResultJson(resultJson);
|
||
resultMapper.insert(result);
|
||
}
|
||
}
|
||
|
||
// ------------------------------------------------------------------ 提交 --
|
||
|
||
/**
|
||
* 提交提取结果(IN_PROGRESS → SUBMITTED)。
|
||
*/
|
||
@Transactional
|
||
public void submit(Long taskId, TokenPrincipal principal) {
|
||
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
|
||
|
||
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
|
||
TaskStatus.valueOf(task.getStatus()), TaskStatus.SUBMITTED);
|
||
|
||
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
|
||
.eq(AnnotationTask::getId, taskId)
|
||
.set(AnnotationTask::getStatus, "SUBMITTED")
|
||
.set(AnnotationTask::getSubmittedAt, LocalDateTime.now()));
|
||
|
||
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
|
||
task.getStatus(), "SUBMITTED",
|
||
principal.getUserId(), principal.getRole(), null);
|
||
}
|
||
|
||
// ------------------------------------------------------------------ 审批通过 --
|
||
|
||
/**
|
||
* 审批通过(SUBMITTED → APPROVED)。
|
||
*
|
||
* 两阶段:
|
||
* 1. 同步事务:is_final=true,状态推进,写历史
|
||
* 2. 事务提交后(AFTER_COMMIT):AI 生成问答对 → training_dataset → QA 任务 → source_data 状态
|
||
*
|
||
* 注:AI 调用严禁在此事务内执行。
|
||
*/
|
||
@Transactional
|
||
public void approve(Long taskId, TokenPrincipal principal) {
|
||
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
|
||
|
||
// 自审校验
|
||
if (principal.getUserId().equals(task.getClaimedBy())) {
|
||
throw new BusinessException("SELF_REVIEW_FORBIDDEN",
|
||
"不允许审批自己提交的任务", HttpStatus.FORBIDDEN);
|
||
}
|
||
|
||
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
|
||
TaskStatus.valueOf(task.getStatus()), TaskStatus.APPROVED);
|
||
|
||
// 标记为最终结果
|
||
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
|
||
.eq(AnnotationTask::getId, taskId)
|
||
.set(AnnotationTask::getStatus, "APPROVED")
|
||
.set(AnnotationTask::getIsFinal, true)
|
||
.set(AnnotationTask::getCompletedAt, LocalDateTime.now()));
|
||
|
||
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
|
||
"SUBMITTED", "APPROVED",
|
||
principal.getUserId(), principal.getRole(), null);
|
||
|
||
// 获取资料信息,用于事件
|
||
SourceData source = sourceDataMapper.selectById(task.getSourceId());
|
||
String sourceType = source != null ? source.getDataType() : "TEXT";
|
||
|
||
// 发布事件(@TransactionalEventListener(AFTER_COMMIT) 处理 AI 调用)
|
||
eventPublisher.publishEvent(new ExtractionApprovedEvent(
|
||
this, taskId, task.getSourceId(), sourceType,
|
||
principal.getCompanyId(), principal.getUserId()));
|
||
}
|
||
|
||
// ------------------------------------------------------------------ 驳回 --
|
||
|
||
/**
|
||
* 驳回提取结果(SUBMITTED → REJECTED)。
|
||
*/
|
||
@Transactional
|
||
public void reject(Long taskId, String reason, TokenPrincipal principal) {
|
||
if (reason == null || reason.isBlank()) {
|
||
throw new BusinessException("REASON_REQUIRED", "驳回原因不能为空", HttpStatus.BAD_REQUEST);
|
||
}
|
||
|
||
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
|
||
|
||
// 自审校验
|
||
if (principal.getUserId().equals(task.getClaimedBy())) {
|
||
throw new BusinessException("SELF_REVIEW_FORBIDDEN",
|
||
"不允许驳回自己提交的任务", HttpStatus.FORBIDDEN);
|
||
}
|
||
|
||
StateValidator.assertTransition(TaskStatus.TRANSITIONS,
|
||
TaskStatus.valueOf(task.getStatus()), TaskStatus.REJECTED);
|
||
|
||
taskMapper.update(null, new LambdaUpdateWrapper<AnnotationTask>()
|
||
.eq(AnnotationTask::getId, taskId)
|
||
.set(AnnotationTask::getStatus, "REJECTED")
|
||
.set(AnnotationTask::getRejectReason, reason));
|
||
|
||
taskClaimService.insertHistory(taskId, principal.getCompanyId(),
|
||
"SUBMITTED", "REJECTED",
|
||
principal.getUserId(), principal.getRole(), reason);
|
||
}
|
||
|
||
// ------------------------------------------------------------------ 查询 --
|
||
|
||
/**
|
||
* 获取当前标注结果。
|
||
*/
|
||
public Map<String, Object> getResult(Long taskId, TokenPrincipal principal) {
|
||
AnnotationTask task = validateAndGetTask(taskId, principal.getCompanyId());
|
||
AnnotationResult result = resultMapper.selectByTaskId(taskId);
|
||
SourceData source = sourceDataMapper.selectById(task.getSourceId());
|
||
|
||
return Map.of(
|
||
"taskId", taskId,
|
||
"sourceType", source != null ? source.getDataType() : "",
|
||
"sourceFilePath", source != null && source.getFilePath() != null ? source.getFilePath() : "",
|
||
"isFinal", task.getIsFinal() != null && task.getIsFinal(),
|
||
"resultJson", result != null ? result.getResultJson() : "[]");
|
||
}
|
||
|
||
// ------------------------------------------------------------------ 私有工具 --
|
||
|
||
/**
|
||
* 校验任务存在性(多租户自动过滤)。
|
||
*/
|
||
private AnnotationTask validateAndGetTask(Long taskId, Long companyId) {
|
||
AnnotationTask task = taskMapper.selectById(taskId);
|
||
if (task == null || !companyId.equals(task.getCompanyId())) {
|
||
throw new BusinessException("NOT_FOUND", "任务不存在: " + taskId, HttpStatus.NOT_FOUND);
|
||
}
|
||
return task;
|
||
}
|
||
}
|