package com.label.service; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.label.common.exception.BusinessException; import com.label.common.result.PageResult; import com.label.common.shiro.TokenPrincipal; import com.label.common.storage.RustFsClient; import com.label.entity.TrainingDataset; import com.label.mapper.TrainingDatasetMapper; import com.label.entity.ExportBatch; import com.label.mapper.ExportBatchMapper; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; /** * 训练数据导出服务。 * * createBatch() 步骤: * 1. 校验 sampleIds 非空(EMPTY_SAMPLES 400) * 2. 查询 training_dataset,校验全部为 APPROVED(INVALID_SAMPLES 400) * 3. 生成 JSONL(每行一个 glm_format_json) * 4. 上传 RustFS(bucket: finetune-export, key: export/{batchUuid}.jsonl) * 5. 插入 export_batch 记录 * 6. 批量更新 training_dataset.export_batch_id + exported_at */ @Slf4j @Service @RequiredArgsConstructor public class ExportService { private static final String EXPORT_BUCKET = "finetune-export"; private final ExportBatchMapper exportBatchMapper; private final TrainingDatasetMapper datasetMapper; private final RustFsClient rustFsClient; // ------------------------------------------------------------------ 创建批次 -- /** * 创建导出批次。 * * @param sampleIds 待导出的 training_dataset ID 列表 * @param principal 当前用户 * @return 新建的 ExportBatch */ @Transactional public ExportBatch createBatch(List sampleIds, TokenPrincipal principal) { if (sampleIds == null || sampleIds.isEmpty()) { throw new BusinessException("EMPTY_SAMPLES", "导出样本 ID 列表不能为空", HttpStatus.BAD_REQUEST); } // 查询样本 List samples = datasetMapper.selectList( new LambdaQueryWrapper() .in(TrainingDataset::getId, sampleIds) .eq(TrainingDataset::getCompanyId, principal.getCompanyId())); // 校验全部已审批 boolean hasNonApproved = samples.stream() .anyMatch(s -> !"APPROVED".equals(s.getStatus())); if (hasNonApproved || samples.size() != sampleIds.size()) { throw new BusinessException("INVALID_SAMPLES", "部分样本不处于 APPROVED 状态或不属于当前租户", HttpStatus.BAD_REQUEST); } // 生成 JSONL(每行一个 JSON 对象) String jsonl = samples.stream() .map(TrainingDataset::getGlmFormatJson) .collect(Collectors.joining("\n")); byte[] jsonlBytes = jsonl.getBytes(StandardCharsets.UTF_8); // 生成唯一批次 UUID,上传 RustFS UUID batchUuid = UUID.randomUUID(); String filePath = "export/" + batchUuid + ".jsonl"; rustFsClient.upload(EXPORT_BUCKET, filePath, new ByteArrayInputStream(jsonlBytes), jsonlBytes.length, "application/jsonl"); // 插入 export_batch 记录(若 DB 写入失败,尝试清理 RustFS 孤儿文件) ExportBatch batch = new ExportBatch(); batch.setCompanyId(principal.getCompanyId()); batch.setBatchUuid(batchUuid); batch.setSampleCount(samples.size()); batch.setDatasetFilePath(filePath); batch.setFinetuneStatus("NOT_STARTED"); try { exportBatchMapper.insert(batch); } catch (Exception e) { // DB 插入失败:尝试删除已上传的 RustFS 文件,防止产生孤儿文件 try { rustFsClient.delete(EXPORT_BUCKET, filePath); } catch (Exception deleteEx) { log.error("DB 写入失败后清理 RustFS 文件亦失败,孤儿文件: {}/{}", EXPORT_BUCKET, filePath, deleteEx); } throw e; } // 批量更新 training_dataset.export_batch_id + exported_at datasetMapper.update(null, new LambdaUpdateWrapper() .in(TrainingDataset::getId, sampleIds) .set(TrainingDataset::getExportBatchId, batch.getId()) .set(TrainingDataset::getExportedAt, LocalDateTime.now()) .set(TrainingDataset::getUpdatedAt, LocalDateTime.now())); log.info("导出批次已创建: batchId={}, sampleCount={}, path={}", batch.getId(), samples.size(), filePath); return batch; } // ------------------------------------------------------------------ 查询样本 -- /** * 分页查询已审批、可导出的训练样本。 */ public PageResult listSamples(int page, int pageSize, String sampleType, Boolean exported, TokenPrincipal principal) { pageSize = Math.min(pageSize, 100); LambdaQueryWrapper wrapper = new LambdaQueryWrapper() .eq(TrainingDataset::getStatus, "APPROVED") .eq(TrainingDataset::getCompanyId, principal.getCompanyId()) .orderByDesc(TrainingDataset::getCreatedAt); if (sampleType != null && !sampleType.isBlank()) { wrapper.eq(TrainingDataset::getSampleType, sampleType); } if (exported != null) { if (exported) { wrapper.isNotNull(TrainingDataset::getExportBatchId); } else { wrapper.isNull(TrainingDataset::getExportBatchId); } } Page result = datasetMapper.selectPage(new Page<>(page, pageSize), wrapper); return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize); } // ------------------------------------------------------------------ 查询批次列表 -- /** * 分页查询导出批次。 */ public PageResult listBatches(int page, int pageSize, TokenPrincipal principal) { pageSize = Math.min(pageSize, 100); Page result = exportBatchMapper.selectPage( new Page<>(page, pageSize), new LambdaQueryWrapper() .eq(ExportBatch::getCompanyId, principal.getCompanyId()) .orderByDesc(ExportBatch::getCreatedAt)); return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize); } // ------------------------------------------------------------------ 查询批次 -- public ExportBatch getById(Long batchId, TokenPrincipal principal) { ExportBatch batch = exportBatchMapper.selectById(batchId); if (batch == null || !batch.getCompanyId().equals(principal.getCompanyId())) { throw new BusinessException("NOT_FOUND", "导出批次不存在: " + batchId, HttpStatus.NOT_FOUND); } return batch; } }