2026-04-14 13:45:15 +08:00
|
|
|
|
package com.label.service;
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
|
|
|
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
|
|
|
|
|
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
|
|
|
|
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|
|
|
|
|
import com.label.common.exception.BusinessException;
|
|
|
|
|
|
import com.label.common.result.PageResult;
|
|
|
|
|
|
import com.label.common.shiro.TokenPrincipal;
|
|
|
|
|
|
import com.label.common.storage.RustFsClient;
|
2026-04-14 13:39:24 +08:00
|
|
|
|
import com.label.entity.TrainingDataset;
|
|
|
|
|
|
import com.label.mapper.TrainingDatasetMapper;
|
|
|
|
|
|
import com.label.entity.ExportBatch;
|
|
|
|
|
|
import com.label.mapper.ExportBatchMapper;
|
2026-04-09 15:43:45 +08:00
|
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
|
import org.springframework.http.HttpStatus;
|
|
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
|
|
|
|
|
|
|
import java.io.ByteArrayInputStream;
|
|
|
|
|
|
import java.nio.charset.StandardCharsets;
|
|
|
|
|
|
import java.time.LocalDateTime;
|
|
|
|
|
|
import java.util.List;
|
|
|
|
|
|
import java.util.Map;
|
|
|
|
|
|
import java.util.UUID;
|
|
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
2026-04-14 13:31:50 +08:00
|
|
|
|
* 训练数据导出服务。
|
|
|
|
|
|
*
|
|
|
|
|
|
* createBatch() 步骤:
|
|
|
|
|
|
* 1. 校验 sampleIds 非空(EMPTY_SAMPLES 400)
|
|
|
|
|
|
* 2. 查询 training_dataset,校验全部为 APPROVED(INVALID_SAMPLES 400)
|
|
|
|
|
|
* 3. 生成 JSONL(每行一个 glm_format_json)
|
|
|
|
|
|
* 4. 上传 RustFS(bucket: finetune-export, key: export/{batchUuid}.jsonl)
|
|
|
|
|
|
* 5. 插入 export_batch 记录
|
|
|
|
|
|
* 6. 批量更新 training_dataset.export_batch_id + exported_at
|
2026-04-09 15:43:45 +08:00
|
|
|
|
*/
|
|
|
|
|
|
@Slf4j
|
|
|
|
|
|
@Service
|
|
|
|
|
|
@RequiredArgsConstructor
|
|
|
|
|
|
public class ExportService {
|
|
|
|
|
|
|
|
|
|
|
|
private static final String EXPORT_BUCKET = "finetune-export";
|
|
|
|
|
|
|
|
|
|
|
|
private final ExportBatchMapper exportBatchMapper;
|
|
|
|
|
|
private final TrainingDatasetMapper datasetMapper;
|
|
|
|
|
|
private final RustFsClient rustFsClient;
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// ------------------------------------------------------------------ 创建批次 --
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
|
|
|
|
|
/**
|
2026-04-14 13:31:50 +08:00
|
|
|
|
* 创建导出批次。
|
|
|
|
|
|
*
|
|
|
|
|
|
* @param sampleIds 待导出的 training_dataset ID 列表
|
|
|
|
|
|
* @param principal 当前用户
|
|
|
|
|
|
* @return 新建的 ExportBatch
|
2026-04-09 15:43:45 +08:00
|
|
|
|
*/
|
|
|
|
|
|
@Transactional
|
|
|
|
|
|
public ExportBatch createBatch(List<Long> sampleIds, TokenPrincipal principal) {
|
|
|
|
|
|
if (sampleIds == null || sampleIds.isEmpty()) {
|
2026-04-14 13:31:50 +08:00
|
|
|
|
throw new BusinessException("EMPTY_SAMPLES", "导出样本 ID 列表不能为空", HttpStatus.BAD_REQUEST);
|
2026-04-09 15:43:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 查询样本
|
2026-04-09 15:43:45 +08:00
|
|
|
|
List<TrainingDataset> samples = datasetMapper.selectList(
|
|
|
|
|
|
new LambdaQueryWrapper<TrainingDataset>()
|
|
|
|
|
|
.in(TrainingDataset::getId, sampleIds)
|
|
|
|
|
|
.eq(TrainingDataset::getCompanyId, principal.getCompanyId()));
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 校验全部已审批
|
|
|
|
|
|
boolean hasNonApproved = samples.stream()
|
2026-04-09 15:43:45 +08:00
|
|
|
|
.anyMatch(s -> !"APPROVED".equals(s.getStatus()));
|
|
|
|
|
|
if (hasNonApproved || samples.size() != sampleIds.size()) {
|
|
|
|
|
|
throw new BusinessException("INVALID_SAMPLES",
|
2026-04-14 13:31:50 +08:00
|
|
|
|
"部分样本不处于 APPROVED 状态或不属于当前租户", HttpStatus.BAD_REQUEST);
|
2026-04-09 15:43:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 生成 JSONL(每行一个 JSON 对象)
|
|
|
|
|
|
String jsonl = samples.stream()
|
2026-04-09 15:43:45 +08:00
|
|
|
|
.map(TrainingDataset::getGlmFormatJson)
|
|
|
|
|
|
.collect(Collectors.joining("\n"));
|
|
|
|
|
|
byte[] jsonlBytes = jsonl.getBytes(StandardCharsets.UTF_8);
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 生成唯一批次 UUID,上传 RustFS
|
2026-04-09 15:43:45 +08:00
|
|
|
|
UUID batchUuid = UUID.randomUUID();
|
|
|
|
|
|
String filePath = "export/" + batchUuid + ".jsonl";
|
|
|
|
|
|
|
|
|
|
|
|
rustFsClient.upload(EXPORT_BUCKET, filePath,
|
|
|
|
|
|
new ByteArrayInputStream(jsonlBytes), jsonlBytes.length,
|
|
|
|
|
|
"application/jsonl");
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 插入 export_batch 记录(若 DB 写入失败,尝试清理 RustFS 孤儿文件)
|
|
|
|
|
|
ExportBatch batch = new ExportBatch();
|
2026-04-09 15:43:45 +08:00
|
|
|
|
batch.setCompanyId(principal.getCompanyId());
|
|
|
|
|
|
batch.setBatchUuid(batchUuid);
|
|
|
|
|
|
batch.setSampleCount(samples.size());
|
|
|
|
|
|
batch.setDatasetFilePath(filePath);
|
|
|
|
|
|
batch.setFinetuneStatus("NOT_STARTED");
|
2026-04-09 19:42:20 +08:00
|
|
|
|
try {
|
|
|
|
|
|
exportBatchMapper.insert(batch);
|
|
|
|
|
|
} catch (Exception e) {
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// DB 插入失败:尝试删除已上传的 RustFS 文件,防止产生孤儿文件
|
|
|
|
|
|
try {
|
2026-04-09 19:42:20 +08:00
|
|
|
|
rustFsClient.delete(EXPORT_BUCKET, filePath);
|
|
|
|
|
|
} catch (Exception deleteEx) {
|
2026-04-14 13:31:50 +08:00
|
|
|
|
log.error("DB 写入失败后清理 RustFS 文件亦失败,孤儿文件: {}/{}", EXPORT_BUCKET, filePath, deleteEx);
|
2026-04-09 19:42:20 +08:00
|
|
|
|
}
|
|
|
|
|
|
throw e;
|
|
|
|
|
|
}
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// 批量更新 training_dataset.export_batch_id + exported_at
|
2026-04-09 15:43:45 +08:00
|
|
|
|
datasetMapper.update(null, new LambdaUpdateWrapper<TrainingDataset>()
|
|
|
|
|
|
.in(TrainingDataset::getId, sampleIds)
|
|
|
|
|
|
.set(TrainingDataset::getExportBatchId, batch.getId())
|
|
|
|
|
|
.set(TrainingDataset::getExportedAt, LocalDateTime.now())
|
|
|
|
|
|
.set(TrainingDataset::getUpdatedAt, LocalDateTime.now()));
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
log.info("导出批次已创建: batchId={}, sampleCount={}, path={}",
|
2026-04-09 15:43:45 +08:00
|
|
|
|
batch.getId(), samples.size(), filePath);
|
|
|
|
|
|
return batch;
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// ------------------------------------------------------------------ 查询样本 --
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
|
|
|
|
|
/**
|
2026-04-14 13:31:50 +08:00
|
|
|
|
* 分页查询已审批、可导出的训练样本。
|
|
|
|
|
|
*/
|
2026-04-09 15:43:45 +08:00
|
|
|
|
public PageResult<TrainingDataset> listSamples(int page, int pageSize,
|
|
|
|
|
|
String sampleType, Boolean exported,
|
|
|
|
|
|
TokenPrincipal principal) {
|
|
|
|
|
|
pageSize = Math.min(pageSize, 100);
|
|
|
|
|
|
LambdaQueryWrapper<TrainingDataset> wrapper = new LambdaQueryWrapper<TrainingDataset>()
|
|
|
|
|
|
.eq(TrainingDataset::getStatus, "APPROVED")
|
|
|
|
|
|
.eq(TrainingDataset::getCompanyId, principal.getCompanyId())
|
|
|
|
|
|
.orderByDesc(TrainingDataset::getCreatedAt);
|
|
|
|
|
|
|
|
|
|
|
|
if (sampleType != null && !sampleType.isBlank()) {
|
|
|
|
|
|
wrapper.eq(TrainingDataset::getSampleType, sampleType);
|
|
|
|
|
|
}
|
|
|
|
|
|
if (exported != null) {
|
|
|
|
|
|
if (exported) {
|
|
|
|
|
|
wrapper.isNotNull(TrainingDataset::getExportBatchId);
|
|
|
|
|
|
} else {
|
|
|
|
|
|
wrapper.isNull(TrainingDataset::getExportBatchId);
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Page<TrainingDataset> result = datasetMapper.selectPage(new Page<>(page, pageSize), wrapper);
|
|
|
|
|
|
return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// ------------------------------------------------------------------ 查询批次列表 --
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
|
|
|
|
|
/**
|
2026-04-14 13:31:50 +08:00
|
|
|
|
* 分页查询导出批次。
|
|
|
|
|
|
*/
|
2026-04-09 15:43:45 +08:00
|
|
|
|
public PageResult<ExportBatch> listBatches(int page, int pageSize, TokenPrincipal principal) {
|
|
|
|
|
|
pageSize = Math.min(pageSize, 100);
|
|
|
|
|
|
Page<ExportBatch> result = exportBatchMapper.selectPage(
|
|
|
|
|
|
new Page<>(page, pageSize),
|
|
|
|
|
|
new LambdaQueryWrapper<ExportBatch>()
|
|
|
|
|
|
.eq(ExportBatch::getCompanyId, principal.getCompanyId())
|
|
|
|
|
|
.orderByDesc(ExportBatch::getCreatedAt));
|
|
|
|
|
|
return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize);
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-14 13:31:50 +08:00
|
|
|
|
// ------------------------------------------------------------------ 查询批次 --
|
2026-04-09 15:43:45 +08:00
|
|
|
|
|
|
|
|
|
|
public ExportBatch getById(Long batchId, TokenPrincipal principal) {
|
|
|
|
|
|
ExportBatch batch = exportBatchMapper.selectById(batchId);
|
|
|
|
|
|
if (batch == null || !batch.getCompanyId().equals(principal.getCompanyId())) {
|
2026-04-14 13:31:50 +08:00
|
|
|
|
throw new BusinessException("NOT_FOUND", "导出批次不存在: " + batchId, HttpStatus.NOT_FOUND);
|
2026-04-09 15:43:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
return batch;
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|