Files
label_backend/src/main/java/com/label/service/ExportService.java
2026-04-14 13:45:15 +08:00

178 lines
7.4 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.label.service;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.label.common.exception.BusinessException;
import com.label.common.result.PageResult;
import com.label.common.shiro.TokenPrincipal;
import com.label.common.storage.RustFsClient;
import com.label.entity.TrainingDataset;
import com.label.mapper.TrainingDatasetMapper;
import com.label.entity.ExportBatch;
import com.label.mapper.ExportBatchMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
/**
 * Training-data export service.
 *
 * <p>{@code createBatch()} performs these steps:
 * <ol>
 *   <li>Validate that {@code sampleIds} is non-empty; otherwise EMPTY_SAMPLES (400).</li>
 *   <li>Load the matching {@code training_dataset} rows (tenant-scoped) and verify
 *       every requested sample exists and is APPROVED; otherwise INVALID_SAMPLES (400).</li>
 *   <li>Build a JSONL payload, one {@code glm_format_json} value per line.</li>
 *   <li>Upload to RustFS (bucket: {@code finetune-export},
 *       key: {@code export/{batchUuid}.jsonl}).</li>
 *   <li>Insert the {@code export_batch} record.</li>
 *   <li>Bulk-update {@code training_dataset.export_batch_id} + {@code exported_at}.</li>
 * </ol>
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ExportService {

    /** RustFS bucket that holds all fine-tune export files. */
    private static final String EXPORT_BUCKET = "finetune-export";

    private final ExportBatchMapper exportBatchMapper;
    private final TrainingDatasetMapper datasetMapper;
    private final RustFsClient rustFsClient;

    // ------------------------------------------------------------------ create batch --

    /**
     * Creates an export batch from a list of approved training samples.
     *
     * <p>Samples already linked to an earlier batch are re-linked to the new one:
     * the lookup does not filter on {@code export_batch_id}, so re-export is allowed.
     *
     * @param sampleIds IDs of {@code training_dataset} rows to export;
     *                  duplicate IDs are tolerated and deduplicated
     * @param principal current user (supplies the tenant / company scope)
     * @return the newly inserted {@link ExportBatch}
     * @throws BusinessException EMPTY_SAMPLES (400) when {@code sampleIds} is null or
     *         empty; INVALID_SAMPLES (400) when any requested sample is missing,
     *         belongs to another tenant, or is not in APPROVED status
     */
    @Transactional
    public ExportBatch createBatch(List<Long> sampleIds, TokenPrincipal principal) {
        if (sampleIds == null || sampleIds.isEmpty()) {
            throw new BusinessException("EMPTY_SAMPLES", "导出样本 ID 列表不能为空", HttpStatus.BAD_REQUEST);
        }
        // Deduplicate first: the size comparison below would otherwise reject a
        // request that merely repeats a valid ID (IN returns each row only once).
        List<Long> distinctIds = sampleIds.stream().distinct().collect(Collectors.toList());
        // Load the requested samples, scoped to the caller's tenant.
        List<TrainingDataset> samples = datasetMapper.selectList(
                new LambdaQueryWrapper<TrainingDataset>()
                        .in(TrainingDataset::getId, distinctIds)
                        .eq(TrainingDataset::getCompanyId, principal.getCompanyId()));
        // Every requested sample must exist, belong to this tenant, and be APPROVED.
        boolean hasNonApproved = samples.stream()
                .anyMatch(s -> !"APPROVED".equals(s.getStatus()));
        if (hasNonApproved || samples.size() != distinctIds.size()) {
            throw new BusinessException("INVALID_SAMPLES",
                    "部分样本不处于 APPROVED 状态或不属于当前租户", HttpStatus.BAD_REQUEST);
        }
        // Build JSONL (one JSON object per line).
        // NOTE(review): assumes glm_format_json contains no raw newlines — a
        // multi-line value would corrupt the JSONL framing. Confirm upstream.
        String jsonl = samples.stream()
                .map(TrainingDataset::getGlmFormatJson)
                .collect(Collectors.joining("\n"));
        byte[] jsonlBytes = jsonl.getBytes(StandardCharsets.UTF_8);
        // Upload under a freshly generated batch UUID.
        UUID batchUuid = UUID.randomUUID();
        String filePath = "export/" + batchUuid + ".jsonl";
        rustFsClient.upload(EXPORT_BUCKET, filePath,
                new ByteArrayInputStream(jsonlBytes), jsonlBytes.length,
                "application/jsonl");
        // Insert the export_batch record; on DB failure try to remove the
        // just-uploaded RustFS object so no orphan file is left behind.
        ExportBatch batch = new ExportBatch();
        batch.setCompanyId(principal.getCompanyId());
        batch.setBatchUuid(batchUuid);
        batch.setSampleCount(samples.size());
        batch.setDatasetFilePath(filePath);
        batch.setFinetuneStatus("NOT_STARTED");
        try {
            exportBatchMapper.insert(batch);
        } catch (Exception e) {
            // Best-effort cleanup; if it also fails, log the orphan's location.
            try {
                rustFsClient.delete(EXPORT_BUCKET, filePath);
            } catch (Exception deleteEx) {
                log.error("DB 写入失败后清理 RustFS 文件亦失败,孤儿文件: {}/{}", EXPORT_BUCKET, filePath, deleteEx);
            }
            throw e;
        }
        // Link the exported samples to this batch and stamp the export time.
        // NOTE(review): if this update throws, @Transactional rolls back the DB
        // rows but the uploaded file remains — only the insert path is cleaned up.
        datasetMapper.update(null, new LambdaUpdateWrapper<TrainingDataset>()
                .in(TrainingDataset::getId, distinctIds)
                .set(TrainingDataset::getExportBatchId, batch.getId())
                .set(TrainingDataset::getExportedAt, LocalDateTime.now())
                .set(TrainingDataset::getUpdatedAt, LocalDateTime.now()));
        log.info("导出批次已创建: batchId={}, sampleCount={}, path={}",
                batch.getId(), samples.size(), filePath);
        return batch;
    }

    // ------------------------------------------------------------------ list samples --

    /**
     * Pages through APPROVED (exportable) training samples for the caller's tenant.
     *
     * @param page       1-based page number (passed straight to MyBatis-Plus Page)
     * @param pageSize   page size, capped at 100
     * @param sampleType optional filter on sample_type; ignored when null/blank
     * @param exported   optional filter: true = already linked to a batch,
     *                   false = never exported, null = no filter
     * @param principal  current user (tenant scope)
     * @return one page of samples, newest first
     */
    public PageResult<TrainingDataset> listSamples(int page, int pageSize,
                                                   String sampleType, Boolean exported,
                                                   TokenPrincipal principal) {
        pageSize = Math.min(pageSize, 100);
        LambdaQueryWrapper<TrainingDataset> wrapper = new LambdaQueryWrapper<TrainingDataset>()
                .eq(TrainingDataset::getStatus, "APPROVED")
                .eq(TrainingDataset::getCompanyId, principal.getCompanyId())
                .orderByDesc(TrainingDataset::getCreatedAt);
        if (sampleType != null && !sampleType.isBlank()) {
            wrapper.eq(TrainingDataset::getSampleType, sampleType);
        }
        if (exported != null) {
            // Exported status is inferred from export_batch_id being set.
            if (exported) {
                wrapper.isNotNull(TrainingDataset::getExportBatchId);
            } else {
                wrapper.isNull(TrainingDataset::getExportBatchId);
            }
        }
        Page<TrainingDataset> result = datasetMapper.selectPage(new Page<>(page, pageSize), wrapper);
        return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
    }

    // ------------------------------------------------------------------ list batches --

    /**
     * Pages through the caller's export batches, newest first.
     *
     * @param page      1-based page number
     * @param pageSize  page size, capped at 100
     * @param principal current user (tenant scope)
     * @return one page of export batches
     */
    public PageResult<ExportBatch> listBatches(int page, int pageSize, TokenPrincipal principal) {
        pageSize = Math.min(pageSize, 100);
        Page<ExportBatch> result = exportBatchMapper.selectPage(
                new Page<>(page, pageSize),
                new LambdaQueryWrapper<ExportBatch>()
                        .eq(ExportBatch::getCompanyId, principal.getCompanyId())
                        .orderByDesc(ExportBatch::getCreatedAt));
        return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
    }

    // ------------------------------------------------------------------ get batch --

    /**
     * Loads a single export batch, enforcing tenant ownership.
     *
     * @param batchId   batch primary key
     * @param principal current user (tenant scope)
     * @return the batch
     * @throws BusinessException NOT_FOUND (404) when the batch does not exist or
     *         belongs to a different tenant (deliberately indistinguishable, so
     *         cross-tenant IDs are not revealed as existing)
     */
    public ExportBatch getById(Long batchId, TokenPrincipal principal) {
        ExportBatch batch = exportBatchMapper.selectById(batchId);
        if (batch == null || !batch.getCompanyId().equals(principal.getCompanyId())) {
            throw new BusinessException("NOT_FOUND", "导出批次不存在: " + batchId, HttpStatus.NOT_FOUND);
        }
        return batch;
    }
}