diff --git a/.gitignore b/.gitignore index 0f47a07..cec8c00 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ target/ *.iml *.ipr *.iws +.agents/ +logs/ # ========================================== # 3. 项目特定工具目录 (根据你的文件列表) diff --git a/pom.xml b/pom.xml index 9332c6b..a81d3b5 100644 --- a/pom.xml +++ b/pom.xml @@ -1,7 +1,6 @@ 4.0.0 @@ -78,15 +77,21 @@ com.baomidou - mybatis-plus-boot-starter - 3.5.9 + mybatis-plus-spring-boot3-starter + 3.5.10 com.baomidou mybatis-plus-jsqlparser - 3.5.9 + 3.5.10 + + + + org.springdoc + springdoc-openapi-starter-webmvc-ui + 2.5.0 @@ -169,7 +174,9 @@ copy-dependencies package - copy-dependencies + + copy-dependencies + ${project.build.directory}/libs runtime @@ -186,7 +193,9 @@ create-distribution package - single + + single + src/main/assembly/distribution.xml @@ -201,4 +210,4 @@ - \ No newline at end of file + diff --git a/src/main/java/com/label/LabelBackendApplication.java b/src/main/java/com/label/LabelBackendApplication.java index 856844e..056e5bc 100644 --- a/src/main/java/com/label/LabelBackendApplication.java +++ b/src/main/java/com/label/LabelBackendApplication.java @@ -1,3 +1,4 @@ + package com.label; import org.springframework.boot.SpringApplication; @@ -11,11 +12,13 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; * Spring Boot 3. 的 jakarta.servlet 命名空间冲突。 认证/ 授权逻辑改由 * TokenFilter(OncePerRequestFilter)+ ShiroConfig 手动装配。 */ -@SpringBootApplication(excludeName = { - "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebAutoConfiguration", - "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebFilterConfiguration", - "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebMvcAutoConfiguration" }) +// (excludeName = { + +// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebAutoConfiguration", +// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebFilterConfiguration", +// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebMvcAutoConfiguration" }) +@SpringBootApplication public class LabelBackendApplication { public static void main(String[] args) { diff --git a/src/main/java/com/label/common/ai/AiServiceClient.java b/src/main/java/com/label/common/ai/AiServiceClient.java index 93da8f9..56e78d2 100644 --- a/src/main/java/com/label/common/ai/AiServiceClient.java +++ b/src/main/java/com/label/common/ai/AiServiceClient.java @@ -23,9 +23,7 @@ public class AiServiceClient { @PostConstruct public void init() { - restClient = RestClient.builder() - .baseUrl(baseUrl) - .build(); + restClient = RestClient.builder().baseUrl(baseUrl).build(); } // DTO classes @@ -42,7 +40,7 @@ public class AiServiceClient { @Data public static class ExtractionResponse { - private List> items; // triple/quadruple items + private List> items; // triple/quadruple items private String rawOutput; } @@ -52,7 +50,7 @@ public class AiServiceClient { private Long sourceId; private String filePath; private String bucket; - private Map params; // frameInterval, mode etc. + private Map params; // frameInterval, mode etc. } @Data @@ -63,7 +61,7 @@ public class AiServiceClient { @Data @Builder public static class FinetuneRequest { - private String datasetPath; // RustFS path to JSONL file + private String datasetPath; // RustFS path to JSONL file private String model; private Long batchId; } @@ -77,73 +75,42 @@ public class AiServiceClient { @Data public static class FinetuneStatusResponse { private String jobId; - private String status; // PENDING/RUNNING/COMPLETED/FAILED - private Integer progress; // 0-100 + private String status; // PENDING/RUNNING/COMPLETED/FAILED + private Integer progress; // 0-100 private String errorMessage; } // The 8 endpoints: public ExtractionResponse extractText(ExtractionRequest request) { - return restClient.post() - .uri("/extract/text") - .body(request) - .retrieve() - .body(ExtractionResponse.class); + return restClient.post().uri("/extract/text").body(request).retrieve().body(ExtractionResponse.class); } public ExtractionResponse extractImage(ExtractionRequest request) { - return restClient.post() - .uri("/extract/image") - .body(request) - .retrieve() - .body(ExtractionResponse.class); + return restClient.post().uri("/extract/image").body(request).retrieve().body(ExtractionResponse.class); } public void extractFrames(VideoProcessRequest request) { - restClient.post() - .uri("/video/extract-frames") - .body(request) - .retrieve() - .toBodilessEntity(); + restClient.post().uri("/video/extract-frames").body(request).retrieve().toBodilessEntity(); } public void videoToText(VideoProcessRequest request) { - restClient.post() - .uri("/video/to-text") - .body(request) - .retrieve() - .toBodilessEntity(); + restClient.post().uri("/video/to-text").body(request).retrieve().toBodilessEntity(); } public QaGenResponse genTextQa(ExtractionRequest request) { - return restClient.post() - .uri("/qa/gen-text") - .body(request) - .retrieve() - .body(QaGenResponse.class); + return restClient.post().uri("/qa/gen-text").body(request).retrieve().body(QaGenResponse.class); } public QaGenResponse genImageQa(ExtractionRequest request) { - return restClient.post() - .uri("/qa/gen-image") - .body(request) - .retrieve() - .body(QaGenResponse.class); + return restClient.post().uri("/qa/gen-image").body(request).retrieve().body(QaGenResponse.class); } public FinetuneResponse startFinetune(FinetuneRequest request) { - return restClient.post() - .uri("/finetune/start") - .body(request) - .retrieve() - .body(FinetuneResponse.class); + return restClient.post().uri("/finetune/start").body(request).retrieve().body(FinetuneResponse.class); } public FinetuneStatusResponse getFinetuneStatus(String jobId) { - return restClient.get() - .uri("/finetune/status/{jobId}", jobId) - .retrieve() - .body(FinetuneStatusResponse.class); + return restClient.get().uri("/finetune/status/{jobId}", jobId).retrieve().body(FinetuneStatusResponse.class); } } diff --git a/src/main/java/com/label/common/context/CompanyContext.java b/src/main/java/com/label/common/context/CompanyContext.java index 1606633..f46bf69 100644 --- a/src/main/java/com/label/common/context/CompanyContext.java +++ b/src/main/java/com/label/common/context/CompanyContext.java @@ -12,10 +12,10 @@ public class CompanyContext { } public static void clear() { - COMPANY_ID.remove(); // Use remove() not set(null) to prevent memory leaks + COMPANY_ID.remove(); // Use remove() not set(null) to prevent memory leaks } - private CompanyContext() { // Prevent instantiation + private CompanyContext() { // Prevent instantiation throw new UnsupportedOperationException("Utility class"); } } diff --git a/src/main/java/com/label/common/shiro/TokenFilter.java b/src/main/java/com/label/common/shiro/TokenFilter.java index 7eb4b80..dfaa48d 100644 --- a/src/main/java/com/label/common/shiro/TokenFilter.java +++ b/src/main/java/com/label/common/shiro/TokenFilter.java @@ -13,6 +13,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.shiro.SecurityUtils; import org.apache.shiro.util.ThreadContext; +import org.springframework.beans.factory.annotation.Value; import org.springframework.http.MediaType; import org.springframework.web.filter.OncePerRequestFilter; @@ -38,6 +39,24 @@ public class TokenFilter extends OncePerRequestFilter { private final RedisService redisService; private final ObjectMapper objectMapper; + @Value("${shiro.auth.enabled:true}") + private boolean authEnabled; + + @Value("${shiro.auth.mock-company-id:1}") + private Long mockCompanyId; + + @Value("${shiro.auth.mock-user-id:1}") + private Long mockUserId; + + @Value("${shiro.auth.mock-role:ADMIN}") + private String mockRole; + + @Value("${shiro.auth.mock-username:mock}") + private String mockUsername; + + @Value("${token.ttl-seconds:7200}") + private long tokenTtlSeconds; + /** * 公开端点跳过过滤:非 /api/ 前缀路径,以及登录接口本身。 */ @@ -46,13 +65,25 @@ public class TokenFilter extends OncePerRequestFilter { String path = request.getServletPath(); return !path.startsWith("/api/") || path.equals("/api/auth/login") - || path.equals("/api/video/callback"); // AI 服务内部回调,不走用户 Token 认证 + || path.equals("/api/video/callback") + || path.startsWith("/swagger-ui") + || path.startsWith("/v3/api-docs"); // AI 服务内部回调,不走用户 Token 认证 } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { try { + if (!authEnabled) { + TokenPrincipal principal = new TokenPrincipal( + mockUserId, mockRole, mockCompanyId, mockUsername, "mock-token"); + CompanyContext.set(mockCompanyId); + SecurityUtils.getSubject().login(new BearerToken("mock-token", principal)); + request.setAttribute("__token_principal__", principal); + filterChain.doFilter(request, response); + return; + } + String authHeader = request.getHeader("Authorization"); if (authHeader == null || !authHeader.startsWith("Bearer ")) { writeUnauthorized(response, "缺少或无效的认证令牌"); @@ -79,6 +110,8 @@ public class TokenFilter extends OncePerRequestFilter { TokenPrincipal principal = new TokenPrincipal(userId, role, companyId, username, token); SecurityUtils.getSubject().login(new BearerToken(token, principal)); request.setAttribute("__token_principal__", principal); + redisService.expire(RedisKeyManager.tokenKey(token), tokenTtlSeconds); + redisService.expire(RedisKeyManager.userSessionsKey(userId), tokenTtlSeconds); filterChain.doFilter(request, response); } catch (Exception e) { diff --git a/src/main/java/com/label/module/annotation/controller/ExtractionController.java b/src/main/java/com/label/module/annotation/controller/ExtractionController.java index f202dbe..6dcef99 100644 --- a/src/main/java/com/label/module/annotation/controller/ExtractionController.java +++ b/src/main/java/com/label/module/annotation/controller/ExtractionController.java @@ -3,6 +3,8 @@ package com.label.module.annotation.controller; import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.annotation.service.ExtractionService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -13,6 +15,7 @@ import java.util.Map; /** * 提取阶段标注工作台接口(5 个端点)。 */ +@Tag(name = "提取标注", description = "提取阶段的查看、编辑、提交和审批") @RestController @RequestMapping("/api/extraction") @RequiredArgsConstructor @@ -21,6 +24,7 @@ public class ExtractionController { private final ExtractionService extractionService; /** GET /api/extraction/{taskId} — 获取当前标注结果 */ + @Operation(summary = "获取提取标注结果") @GetMapping("/{taskId}") @RequiresRoles("ANNOTATOR") public Result> getResult(@PathVariable Long taskId, @@ -29,6 +33,7 @@ public class ExtractionController { } /** PUT /api/extraction/{taskId} — 更新标注结果(整体覆盖) */ + @Operation(summary = "更新提取标注结果") @PutMapping("/{taskId}") @RequiresRoles("ANNOTATOR") public Result updateResult(@PathVariable Long taskId, @@ -39,6 +44,7 @@ public class ExtractionController { } /** POST /api/extraction/{taskId}/submit — 提交标注结果 */ + @Operation(summary = "提交提取标注结果") @PostMapping("/{taskId}/submit") @RequiresRoles("ANNOTATOR") public Result submit(@PathVariable Long taskId, @@ -48,6 +54,7 @@ public class ExtractionController { } /** POST /api/extraction/{taskId}/approve — 审批通过(REVIEWER) */ + @Operation(summary = "审批通过提取结果") @PostMapping("/{taskId}/approve") @RequiresRoles("REVIEWER") public Result approve(@PathVariable Long taskId, @@ -57,6 +64,7 @@ public class ExtractionController { } /** POST /api/extraction/{taskId}/reject — 驳回(REVIEWER) */ + @Operation(summary = "驳回提取结果") @PostMapping("/{taskId}/reject") @RequiresRoles("REVIEWER") public Result reject(@PathVariable Long taskId, diff --git a/src/main/java/com/label/module/annotation/controller/QaController.java b/src/main/java/com/label/module/annotation/controller/QaController.java index 546a050..f87686d 100644 --- a/src/main/java/com/label/module/annotation/controller/QaController.java +++ b/src/main/java/com/label/module/annotation/controller/QaController.java @@ -3,6 +3,8 @@ package com.label.module.annotation.controller; import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.annotation.service.QaService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -13,6 +15,7 @@ import java.util.Map; /** * 问答生成阶段标注工作台接口(5 个端点)。 */ +@Tag(name = "问答生成", description = "问答生成阶段的查看、编辑、提交和审批") @RestController @RequestMapping("/api/qa") @RequiredArgsConstructor @@ -21,6 +24,7 @@ public class QaController { private final QaService qaService; /** GET /api/qa/{taskId} — 获取候选问答对 */ + @Operation(summary = "获取候选问答对") @GetMapping("/{taskId}") @RequiresRoles("ANNOTATOR") public Result> getResult(@PathVariable Long taskId, @@ -29,6 +33,7 @@ public class QaController { } /** PUT /api/qa/{taskId} — 整体覆盖问答对 */ + @Operation(summary = "更新候选问答对") @PutMapping("/{taskId}") @RequiresRoles("ANNOTATOR") public Result updateResult(@PathVariable Long taskId, @@ -39,6 +44,7 @@ public class QaController { } /** POST /api/qa/{taskId}/submit — 提交问答对 */ + @Operation(summary = "提交问答对") @PostMapping("/{taskId}/submit") @RequiresRoles("ANNOTATOR") public Result submit(@PathVariable Long taskId, @@ -48,6 +54,7 @@ public class QaController { } /** POST /api/qa/{taskId}/approve — 审批通过(REVIEWER) */ + @Operation(summary = "审批通过问答对") @PostMapping("/{taskId}/approve") @RequiresRoles("REVIEWER") public Result approve(@PathVariable Long taskId, @@ -57,6 +64,7 @@ public class QaController { } /** POST /api/qa/{taskId}/reject — 驳回(REVIEWER) */ + @Operation(summary = "驳回答案对") @PostMapping("/{taskId}/reject") @RequiresRoles("REVIEWER") public Result reject(@PathVariable Long taskId, diff --git a/src/main/java/com/label/module/config/controller/SysConfigController.java b/src/main/java/com/label/module/config/controller/SysConfigController.java index 695d3a0..e53aebd 100644 --- a/src/main/java/com/label/module/config/controller/SysConfigController.java +++ b/src/main/java/com/label/module/config/controller/SysConfigController.java @@ -4,6 +4,8 @@ import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.config.entity.SysConfig; import com.label.module.config.service.SysConfigService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -18,6 +20,7 @@ import java.util.Map; * GET /api/config — 查询当前公司所有可见配置(公司专属 + 全局默认合并) * PUT /api/config/{key} — 更新/创建公司专属配置(UPSERT) */ +@Tag(name = "系统配置", description = "全局和公司级系统配置管理") @RestController @RequiredArgsConstructor public class SysConfigController { @@ -31,6 +34,7 @@ public class SysConfigController { * - "COMPANY":当前公司专属配置(优先生效) * - "GLOBAL":全局默认配置(公司未覆盖时生效) */ + @Operation(summary = "查询合并后的系统配置") @GetMapping("/api/config") @RequiresRoles("ADMIN") public Result>> listConfig(HttpServletRequest request) { @@ -43,6 +47,7 @@ public class SysConfigController { * * Body: { "value": "...", "description": "..." } */ + @Operation(summary = "更新或创建公司专属配置") @PutMapping("/api/config/{key}") @RequiresRoles("ADMIN") public Result updateConfig(@PathVariable String key, diff --git a/src/main/java/com/label/module/export/controller/ExportController.java b/src/main/java/com/label/module/export/controller/ExportController.java index f612334..b60e7a5 100644 --- a/src/main/java/com/label/module/export/controller/ExportController.java +++ b/src/main/java/com/label/module/export/controller/ExportController.java @@ -7,6 +7,8 @@ import com.label.module.annotation.entity.TrainingDataset; import com.label.module.export.entity.ExportBatch; import com.label.module.export.service.ExportService; import com.label.module.export.service.FinetuneService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -19,6 +21,7 @@ import java.util.Map; /** * 训练数据导出与微调接口(5 个端点,全部 ADMIN 权限)。 */ +@Tag(name = "导出管理", description = "训练样本查询、导出批次和微调任务") @RestController @RequiredArgsConstructor public class ExportController { @@ -27,6 +30,7 @@ public class ExportController { private final FinetuneService finetuneService; /** GET /api/training/samples — 分页查询已审批可导出样本 */ + @Operation(summary = "分页查询可导出训练样本") @GetMapping("/api/training/samples") @RequiresRoles("ADMIN") public Result> listSamples( @@ -39,6 +43,7 @@ public class ExportController { } /** POST /api/export/batch — 创建导出批次 */ + @Operation(summary = "创建导出批次") @PostMapping("/api/export/batch") @RequiresRoles("ADMIN") @ResponseStatus(HttpStatus.CREATED) @@ -53,6 +58,7 @@ public class ExportController { } /** POST /api/export/{batchId}/finetune — 提交微调任务 */ + @Operation(summary = "提交微调任务") @PostMapping("/api/export/{batchId}/finetune") @RequiresRoles("ADMIN") public Result> triggerFinetune(@PathVariable Long batchId, @@ -61,6 +67,7 @@ public class ExportController { } /** GET /api/export/{batchId}/status — 查询微调状态 */ + @Operation(summary = "查询微调状态") @GetMapping("/api/export/{batchId}/status") @RequiresRoles("ADMIN") public Result> getFinetuneStatus(@PathVariable Long batchId, @@ -69,6 +76,7 @@ public class ExportController { } /** GET /api/export/list — 分页查询导出批次列表 */ + @Operation(summary = "分页查询导出批次") @GetMapping("/api/export/list") @RequiresRoles("ADMIN") public Result> listBatches( diff --git a/src/main/java/com/label/module/source/controller/SourceController.java b/src/main/java/com/label/module/source/controller/SourceController.java index 5ba0d8e..57323d0 100644 --- a/src/main/java/com/label/module/source/controller/SourceController.java +++ b/src/main/java/com/label/module/source/controller/SourceController.java @@ -5,6 +5,8 @@ import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.source.dto.SourceResponse; import com.label.module.source.service.SourceService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -19,6 +21,7 @@ import org.springframework.web.multipart.MultipartFile; * - 上传 / 列表 / 详情:UPLOADER 及以上角色(含 ANNOTATOR、REVIEWER、ADMIN) * - 删除:仅 ADMIN */ +@Tag(name = "资料管理", description = "原始资料上传、查询和删除") @RestController @RequestMapping("/api/source") @RequiredArgsConstructor @@ -30,6 +33,7 @@ public class SourceController { * 上传文件(multipart/form-data)。 * 返回 201 Created + 资料摘要。 */ + @Operation(summary = "上传原始资料") @PostMapping("/upload") @RequiresRoles("UPLOADER") @ResponseStatus(HttpStatus.CREATED) @@ -45,6 +49,7 @@ public class SourceController { * 分页查询资料列表。 * UPLOADER 只见自己的资料;ADMIN 见全公司资料。 */ + @Operation(summary = "分页查询资料列表") @GetMapping("/list") @RequiresRoles("UPLOADER") public Result> list( @@ -60,6 +65,7 @@ public class SourceController { /** * 查询资料详情(含 15 分钟预签名下载链接)。 */ + @Operation(summary = "查询资料详情") @GetMapping("/{id}") @RequiresRoles("UPLOADER") public Result findById(@PathVariable Long id) { @@ -70,6 +76,7 @@ public class SourceController { * 删除资料(仅 PENDING 状态可删)。 * 同步删除 RustFS 文件及 DB 记录。 */ + @Operation(summary = "删除资料") @DeleteMapping("/{id}") @RequiresRoles("ADMIN") public Result delete(@PathVariable Long id, HttpServletRequest request) { diff --git a/src/main/java/com/label/module/source/dto/SourceResponse.java b/src/main/java/com/label/module/source/dto/SourceResponse.java index 418afda..fa3a088 100644 --- a/src/main/java/com/label/module/source/dto/SourceResponse.java +++ b/src/main/java/com/label/module/source/dto/SourceResponse.java @@ -1,5 +1,6 @@ package com.label.module.source.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Builder; import lombok.Data; @@ -11,17 +12,27 @@ import java.time.LocalDateTime; */ @Data @Builder +@Schema(description = "原始资料响应") public class SourceResponse { + @Schema(description = "资料主键") private Long id; + @Schema(description = "文件名") private String fileName; + @Schema(description = "资料类型", example = "TEXT") private String dataType; + @Schema(description = "文件大小(字节)") private Long fileSize; + @Schema(description = "资料状态", example = "PENDING") private String status; /** 上传用户 ID(列表端点返回) */ + @Schema(description = "上传用户 ID") private Long uploaderId; /** 15 分钟预签名下载链接(详情端点返回) */ + @Schema(description = "预签名下载链接") private String presignedUrl; /** 父资料 ID(视频帧 / 文本片段;详情端点返回) */ + @Schema(description = "父资料 ID") private Long parentSourceId; + @Schema(description = "创建时间") private LocalDateTime createdAt; } diff --git a/src/main/java/com/label/module/task/controller/TaskController.java b/src/main/java/com/label/module/task/controller/TaskController.java index 17d1565..a9563c6 100644 --- a/src/main/java/com/label/module/task/controller/TaskController.java +++ b/src/main/java/com/label/module/task/controller/TaskController.java @@ -6,6 +6,8 @@ import com.label.common.shiro.TokenPrincipal; import com.label.module.task.dto.TaskResponse; import com.label.module.task.service.TaskClaimService; import com.label.module.task.service.TaskService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -16,6 +18,7 @@ import java.util.Map; /** * 任务管理接口(10 个端点)。 */ +@Tag(name = "任务管理", description = "任务池、我的任务、审批队列和管理操作") @RestController @RequestMapping("/api/tasks") @RequiredArgsConstructor @@ -25,6 +28,7 @@ public class TaskController { private final TaskClaimService taskClaimService; /** GET /api/tasks/pool — 查询可领取任务池(角色感知) */ + @Operation(summary = "查询可领取任务池") @GetMapping("/pool") @RequiresRoles("ANNOTATOR") public Result> getPool( @@ -35,6 +39,7 @@ public class TaskController { } /** GET /api/tasks/mine — 查询我的任务 */ + @Operation(summary = "查询我的任务") @GetMapping("/mine") @RequiresRoles("ANNOTATOR") public Result> getMine( @@ -46,6 +51,7 @@ public class TaskController { } /** GET /api/tasks/pending-review — 待审批队列(REVIEWER 专属) */ + @Operation(summary = "查询待审批任务") @GetMapping("/pending-review") @RequiresRoles("REVIEWER") public Result> getPendingReview( @@ -56,6 +62,7 @@ public class TaskController { } /** GET /api/tasks — 查询全部任务(ADMIN) */ + @Operation(summary = "管理员查询全部任务") @GetMapping @RequiresRoles("ADMIN") public Result> getAll( @@ -67,6 +74,7 @@ public class TaskController { } /** POST /api/tasks — 创建任务(ADMIN) */ + @Operation(summary = "管理员创建任务") @PostMapping @RequiresRoles("ADMIN") public Result createTask(@RequestBody Map body, @@ -79,6 +87,7 @@ public class TaskController { } /** GET /api/tasks/{id} — 查询任务详情 */ + @Operation(summary = "查询任务详情") @GetMapping("/{id}") @RequiresRoles("ANNOTATOR") public Result getById(@PathVariable Long id) { @@ -86,6 +95,7 @@ public class TaskController { } /** POST /api/tasks/{id}/claim — 领取任务 */ + @Operation(summary = "领取任务") @PostMapping("/{id}/claim") @RequiresRoles("ANNOTATOR") public Result claim(@PathVariable Long id, HttpServletRequest request) { @@ -94,6 +104,7 @@ public class TaskController { } /** POST /api/tasks/{id}/unclaim — 放弃任务 */ + @Operation(summary = "放弃任务") @PostMapping("/{id}/unclaim") @RequiresRoles("ANNOTATOR") public Result unclaim(@PathVariable Long id, HttpServletRequest request) { @@ -102,6 +113,7 @@ public class TaskController { } /** POST /api/tasks/{id}/reclaim — 重领被驳回的任务 */ + @Operation(summary = "重领被驳回的任务") @PostMapping("/{id}/reclaim") @RequiresRoles("ANNOTATOR") public Result reclaim(@PathVariable Long id, HttpServletRequest request) { @@ -110,6 +122,7 @@ public class TaskController { } /** PUT /api/tasks/{id}/reassign — ADMIN 强制指派 */ + @Operation(summary = "管理员强制指派任务") @PutMapping("/{id}/reassign") @RequiresRoles("ADMIN") public Result reassign(@PathVariable Long id, diff --git a/src/main/java/com/label/module/task/dto/TaskResponse.java b/src/main/java/com/label/module/task/dto/TaskResponse.java index a6d2550..caa7397 100644 --- a/src/main/java/com/label/module/task/dto/TaskResponse.java +++ b/src/main/java/com/label/module/task/dto/TaskResponse.java @@ -1,5 +1,6 @@ package com.label.module.task.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Builder; import lombok.Data; @@ -10,17 +11,28 @@ import java.time.LocalDateTime; */ @Data @Builder +@Schema(description = "标注任务响应") public class TaskResponse { + @Schema(description = "任务主键") private Long id; + @Schema(description = "关联资料 ID") private Long sourceId; /** 任务类型(对应 taskType 字段):EXTRACTION / QA_GENERATION */ + @Schema(description = "任务类型", example = "EXTRACTION") private String taskType; + @Schema(description = "任务状态", example = "UNCLAIMED") private String status; + @Schema(description = "领取人用户 ID") private Long claimedBy; + @Schema(description = "领取时间") private LocalDateTime claimedAt; + @Schema(description = "提交时间") private LocalDateTime submittedAt; + @Schema(description = "完成时间") private LocalDateTime completedAt; /** 驳回原因(REJECTED 状态时非空) */ + @Schema(description = "驳回原因") private String rejectReason; + @Schema(description = "创建时间") private LocalDateTime createdAt; } diff --git a/src/main/java/com/label/module/user/controller/AuthController.java b/src/main/java/com/label/module/user/controller/AuthController.java index ab06cee..c3e8f24 100644 --- a/src/main/java/com/label/module/user/controller/AuthController.java +++ b/src/main/java/com/label/module/user/controller/AuthController.java @@ -6,6 +6,8 @@ import com.label.module.user.dto.LoginRequest; import com.label.module.user.dto.LoginResponse; import com.label.module.user.dto.UserInfoResponse; import com.label.module.user.service.AuthService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.*; @@ -18,6 +20,7 @@ import org.springframework.web.bind.annotation.*; * - POST /api/auth/logout → 需要有效 Token(TokenFilter 校验) * - GET /api/auth/me → 需要有效 Token(TokenFilter 校验) */ +@Tag(name = "认证管理", description = "登录、退出和当前用户信息") @RestController @RequestMapping("/api/auth") @RequiredArgsConstructor @@ -28,6 +31,7 @@ public class AuthController { /** * 登录接口(匿名,无需 Token)。 */ + @Operation(summary = "用户登录,返回 Bearer Token") @PostMapping("/login") public Result login(@RequestBody LoginRequest request) { return Result.success(authService.login(request)); @@ -36,6 +40,7 @@ public class AuthController { /** * 退出登录,立即删除 Redis Token。 */ + @Operation(summary = "退出登录并立即失效当前 Token") @PostMapping("/logout") public Result logout(HttpServletRequest request) { String token = extractToken(request); @@ -47,6 +52,7 @@ public class AuthController { * 获取当前登录用户信息。 * TokenPrincipal 由 TokenFilter 写入请求属性 "__token_principal__"。 */ + @Operation(summary = "获取当前登录用户信息") @GetMapping("/me") public Result me(HttpServletRequest request) { TokenPrincipal principal = (TokenPrincipal) request.getAttribute("__token_principal__"); diff --git a/src/main/java/com/label/module/user/controller/UserController.java b/src/main/java/com/label/module/user/controller/UserController.java index db044c3..afed456 100644 --- a/src/main/java/com/label/module/user/controller/UserController.java +++ b/src/main/java/com/label/module/user/controller/UserController.java @@ -5,6 +5,8 @@ import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.user.entity.SysUser; import com.label.module.user.service.UserService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import org.apache.shiro.authz.annotation.RequiresRoles; @@ -15,6 +17,7 @@ import java.util.Map; /** * 用户管理接口(5 个端点,全部 ADMIN 权限)。 */ +@Tag(name = "用户管理", description = "管理员维护公司用户") @RestController @RequestMapping("/api/users") @RequiredArgsConstructor @@ -23,6 +26,7 @@ public class UserController { private final UserService userService; /** GET /api/users — 分页查询用户列表 */ + @Operation(summary = "分页查询用户列表") @GetMapping @RequiresRoles("ADMIN") public Result> listUsers( @@ -33,6 +37,7 @@ public class UserController { } /** POST /api/users — 创建用户 */ + @Operation(summary = "创建用户") @PostMapping @RequiresRoles("ADMIN") public Result createUser(@RequestBody Map body, @@ -46,6 +51,7 @@ public class UserController { } /** PUT /api/users/{id} — 更新用户基本信息 */ + @Operation(summary = "更新用户基本信息") @PutMapping("/{id}") @RequiresRoles("ADMIN") public Result updateUser(@PathVariable Long id, @@ -59,6 +65,7 @@ public class UserController { } /** PUT /api/users/{id}/status — 变更用户状态 */ + @Operation(summary = "变更用户状态") @PutMapping("/{id}/status") @RequiresRoles("ADMIN") public Result updateStatus(@PathVariable Long id, @@ -69,6 +76,7 @@ public class UserController { } /** PUT /api/users/{id}/role — 变更用户角色 */ + @Operation(summary = "变更用户角色") @PutMapping("/{id}/role") @RequiresRoles("ADMIN") public Result updateRole(@PathVariable Long id, diff --git a/src/main/java/com/label/module/user/dto/LoginRequest.java b/src/main/java/com/label/module/user/dto/LoginRequest.java index 9bbcc44..e71d4b2 100644 --- a/src/main/java/com/label/module/user/dto/LoginRequest.java +++ b/src/main/java/com/label/module/user/dto/LoginRequest.java @@ -1,16 +1,21 @@ package com.label.module.user.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; /** * 登录请求体。 */ @Data +@Schema(description = "登录请求") public class LoginRequest { /** 公司代码(英文简写),用于确定租户 */ + @Schema(description = "公司代码(英文简写)", example = "DEMO") private String companyCode; /** 登录用户名 */ + @Schema(description = "登录用户名", example = "admin") private String username; /** 明文密码(传输层应使用 HTTPS 保护) */ + @Schema(description = "明文密码", example = "admin123") private String password; } diff --git a/src/main/java/com/label/module/user/dto/LoginResponse.java b/src/main/java/com/label/module/user/dto/LoginResponse.java index 6c9ccff..6ed6a5c 100644 --- a/src/main/java/com/label/module/user/dto/LoginResponse.java +++ b/src/main/java/com/label/module/user/dto/LoginResponse.java @@ -1,5 +1,6 @@ package com.label.module.user.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.AllArgsConstructor; import lombok.Data; @@ -8,15 +9,21 @@ import lombok.Data; */ @Data @AllArgsConstructor +@Schema(description = "登录响应") public class LoginResponse { /** Bearer Token(UUID v4),后续请求放入 Authorization 头 */ + @Schema(description = "Bearer Token", example = "550e8400-e29b-41d4-a716-446655440000") private String token; /** 用户主键 */ + @Schema(description = "用户主键") private Long userId; /** 登录用户名 */ + @Schema(description = "登录用户名") private String username; /** 角色:UPLOADER / ANNOTATOR / REVIEWER / ADMIN */ + @Schema(description = "角色", example = "ADMIN") private String role; /** Token 有效期(秒) */ + @Schema(description = "Token 有效期(秒)", example = "7200") private Long expiresIn; } diff --git a/src/main/java/com/label/module/user/dto/UserInfoResponse.java b/src/main/java/com/label/module/user/dto/UserInfoResponse.java index 7173c1b..bf60ae0 100644 --- a/src/main/java/com/label/module/user/dto/UserInfoResponse.java +++ b/src/main/java/com/label/module/user/dto/UserInfoResponse.java @@ -1,5 +1,6 @@ package com.label.module.user.dto; +import io.swagger.v3.oas.annotations.media.Schema; import lombok.AllArgsConstructor; import lombok.Data; @@ -8,11 +9,18 @@ import lombok.Data; */ @Data @AllArgsConstructor +@Schema(description = "当前登录用户信息") public class UserInfoResponse { + @Schema(description = "用户主键") private Long id; + @Schema(description = "用户名") private String username; + @Schema(description = "真实姓名") private String realName; + @Schema(description = "角色", example = "ADMIN") private String role; + @Schema(description = "所属公司 ID") private Long companyId; + @Schema(description = "所属公司名称") private String companyName; } diff --git a/src/main/java/com/label/module/video/controller/VideoController.java b/src/main/java/com/label/module/video/controller/VideoController.java index 271f265..6848e5d 100644 --- a/src/main/java/com/label/module/video/controller/VideoController.java +++ b/src/main/java/com/label/module/video/controller/VideoController.java @@ -4,6 +4,8 @@ import com.label.common.result.Result; import com.label.common.shiro.TokenPrincipal; import com.label.module.video.entity.VideoProcessJob; import com.label.module.video.service.VideoProcessService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -21,6 +23,7 @@ import java.util.Map; * POST /api/video/jobs/{jobId}/reset — 重置失败任务(ADMIN) * POST /api/video/callback — AI 回调接口(无需认证,已在 TokenFilter 中排除) */ +@Tag(name = "视频处理", description = "视频处理任务创建、查询、重置和回调") @Slf4j @RestController @RequiredArgsConstructor @@ -32,6 +35,7 @@ public class VideoController { private String callbackSecret; /** POST /api/video/process — 触发视频处理任务 */ + @Operation(summary = "触发视频处理任务") @PostMapping("/api/video/process") @RequiresRoles("ADMIN") public Result createJob(@RequestBody Map body, @@ -51,6 +55,7 @@ public class VideoController { } /** GET /api/video/jobs/{jobId} — 查询视频处理任务 */ + @Operation(summary = "查询视频处理任务状态") @GetMapping("/api/video/jobs/{jobId}") @RequiresRoles("ADMIN") public Result getJob(@PathVariable Long jobId, @@ -59,6 +64,7 @@ public class VideoController { } /** POST /api/video/jobs/{jobId}/reset — 管理员重置失败任务 */ + @Operation(summary = "重置失败的视频处理任务") @PostMapping("/api/video/jobs/{jobId}/reset") @RequiresRoles("ADMIN") public Result resetJob(@PathVariable Long jobId, @@ -76,9 +82,10 @@ public class VideoController { * { "jobId": 123, "status": "SUCCESS", "outputPath": "processed/123/frames.zip" } * { "jobId": 123, "status": "FAILED", "errorMessage": "ffmpeg error: ..." } */ + @Operation(summary = "接收 AI 服务视频处理回调") @PostMapping("/api/video/callback") public Result handleCallback(@RequestBody Map body, - HttpServletRequest request) { + HttpServletRequest request) { // 共享密钥校验(配置了 VIDEO_CALLBACK_SECRET 时强制校验) if (callbackSecret != null && !callbackSecret.isBlank()) { String provided = request.getHeader("X-Callback-Secret"); diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 0daa0e0..b915ddd 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -2,6 +2,8 @@ server: port: 8080 spring: + application: + name: label-backend datasource: url: ${SPRING_DATASOURCE_URL:jdbc:postgresql://localhost:5432/label_db} username: ${SPRING_DATASOURCE_USERNAME:label} @@ -33,6 +35,14 @@ spring: pathmatch: matching-strategy: ant_path_matcher # Shiro 与 Spring Boot 3 兼容性需要 +springdoc: + api-docs: + enabled: true + path: /v3/api-docs + swagger-ui: + enabled: true + path: /swagger-ui.html + mybatis-plus: mapper-locations: classpath*:mapper/**/*.xml type-aliases-package: com.label.module @@ -53,6 +63,14 @@ ai-service: base-url: ${AI_SERVICE_BASE_URL:http://localhost:8000} timeout: 30000 # milliseconds +shiro: + auth: + enabled: true + mock-company-id: 1 + mock-user-id: 1 + mock-role: ADMIN + mock-username: mock + token: ttl-seconds: 7200 # Token 默认有效期(秒),与 sys_config token_ttl_seconds 保持一致 @@ -61,6 +79,6 @@ video: logging: level: - com.label: DEBUG + com.label: INFO org.apache.shiro: INFO com.baomidou.mybatisplus: INFO diff --git a/src/test/java/com/label/LabelBackendApplicationTests.java b/src/test/java/com/label/LabelBackendApplicationTests.java index f83f216..fc7223f 100644 --- a/src/test/java/com/label/LabelBackendApplicationTests.java +++ b/src/test/java/com/label/LabelBackendApplicationTests.java @@ -1,7 +1,12 @@ package com.label; +import org.junit.jupiter.api.Test; import org.springframework.boot.test.context.SpringBootTest; @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.NONE) class LabelBackendApplicationTests { + + @Test + void contextLoads() { + } } diff --git a/src/test/java/com/label/unit/ApplicationConfigTest.java b/src/test/java/com/label/unit/ApplicationConfigTest.java new file mode 100644 index 0000000..fa64482 --- /dev/null +++ b/src/test/java/com/label/unit/ApplicationConfigTest.java @@ -0,0 +1,55 @@ +package com.label.unit; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.boot.env.YamlPropertySourceLoader; +import org.springframework.core.env.PropertySource; +import org.springframework.core.io.ClassPathResource; + +import java.nio.charset.StandardCharsets; + +import static org.assertj.core.api.Assertions.assertThat; + +@DisplayName("应用配置单元测试") +class ApplicationConfigTest { + + @Test + @DisplayName("application.yml 提供 Swagger 和 shiro.auth 测试开关配置") + void applicationYaml_containsSwaggerAndShiroAuthToggle() throws Exception { + PropertySource source = new YamlPropertySourceLoader() + .load("application", new ClassPathResource("application.yml")) + .get(0); + + assertThat(source.getProperty("springdoc.api-docs.enabled")).isEqualTo(true); + assertThat(source.getProperty("springdoc.api-docs.path")).isEqualTo("/v3/api-docs"); + assertThat(source.getProperty("springdoc.swagger-ui.enabled")).isEqualTo(true); + assertThat(source.getProperty("springdoc.swagger-ui.path")).isEqualTo("/swagger-ui.html"); + assertThat(source.getProperty("shiro.auth.enabled")).isEqualTo(true); + assertThat(source.getProperty("shiro.auth.mock-company-id")).isEqualTo(1); + assertThat(source.getProperty("shiro.auth.mock-user-id")).isEqualTo(1); + assertThat(source.getProperty("shiro.auth.mock-role")).isEqualTo("ADMIN"); + assertThat(source.getProperty("logging.level.com.label")).isEqualTo("INFO"); + } + + @Test + @DisplayName("application.yml 默认值不指向公网服务或携带真实默认密码") + void applicationYaml_doesNotShipPublicInfrastructureDefaults() throws Exception { + String yaml = new ClassPathResource("application.yml") + .getContentAsString(StandardCharsets.UTF_8); + + assertThat(yaml).doesNotContain("39.107.112.174"); + assertThat(yaml).doesNotContain("postgres!Pw"); + assertThat(yaml).doesNotContain("jsti@2024"); + } + + @Test + @DisplayName("logback.xml 启用 60 MB 滚动文件日志") + void logback_enablesRollingFileAppender() throws Exception { + String xml = new ClassPathResource("logback.xml") + .getContentAsString(StandardCharsets.UTF_8); + + assertThat(xml).contains("60MB"); + assertThat(xml).contains(""); + assertThat(xml).doesNotContain(""); + } +} diff --git a/src/test/java/com/label/unit/OpenApiAnnotationTest.java b/src/test/java/com/label/unit/OpenApiAnnotationTest.java new file mode 100644 index 0000000..5d26dcb --- /dev/null +++ b/src/test/java/com/label/unit/OpenApiAnnotationTest.java @@ -0,0 +1,98 @@ +package com.label.unit; + +import com.label.module.annotation.controller.ExtractionController; +import com.label.module.annotation.controller.QaController; +import com.label.module.config.controller.SysConfigController; +import com.label.module.export.controller.ExportController; +import com.label.module.source.controller.SourceController; +import com.label.module.source.dto.SourceResponse; +import com.label.module.task.controller.TaskController; +import com.label.module.task.dto.TaskResponse; +import com.label.module.user.controller.AuthController; +import com.label.module.user.controller.UserController; +import com.label.module.user.dto.LoginRequest; +import com.label.module.user.dto.LoginResponse; +import com.label.module.user.dto.UserInfoResponse; +import com.label.module.video.controller.VideoController; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.tags.Tag; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestMapping; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +@DisplayName("OpenAPI 注解覆盖测试") +class OpenApiAnnotationTest { + + private static final List> CONTROLLERS = List.of( + AuthController.class, + UserController.class, + SourceController.class, + TaskController.class, + ExtractionController.class, + QaController.class, + ExportController.class, + SysConfigController.class, + VideoController.class + ); + + private static final List> DTOS = List.of( + LoginRequest.class, + LoginResponse.class, + UserInfoResponse.class, + TaskResponse.class, + SourceResponse.class + ); + + @Test + @DisplayName("所有 REST Controller 都声明 @Tag") + void allControllersHaveTag() { + assertThat(CONTROLLERS) + .allSatisfy(controller -> + assertThat(controller.getAnnotation(Tag.class)) + .as(controller.getSimpleName() + " should have @Tag") + .isNotNull()); + } + + @Test + @DisplayName("所有 REST endpoint 方法都声明 @Operation") + void allEndpointMethodsHaveOperation() { + for (Class controller : CONTROLLERS) { + Arrays.stream(controller.getDeclaredMethods()) + .filter(method -> !Modifier.isPrivate(method.getModifiers())) + .filter(OpenApiAnnotationTest::isEndpointMethod) + .forEach(method -> assertThat(method.getAnnotation(Operation.class)) + .as(controller.getSimpleName() + "." + method.getName() + " should have @Operation") + .isNotNull()); + } + } + + @Test + @DisplayName("核心 DTO 都声明 @Schema") + void coreDtosHaveSchema() { + assertThat(DTOS) + .allSatisfy(dto -> + assertThat(dto.getAnnotation(Schema.class)) + .as(dto.getSimpleName() + " should have @Schema") + .isNotNull()); + } + + private static boolean isEndpointMethod(Method method) { + return method.isAnnotationPresent(GetMapping.class) + || method.isAnnotationPresent(PostMapping.class) + || method.isAnnotationPresent(PutMapping.class) + || method.isAnnotationPresent(DeleteMapping.class) + || method.isAnnotationPresent(RequestMapping.class); + } +} diff --git a/src/test/java/com/label/unit/TokenFilterTest.java b/src/test/java/com/label/unit/TokenFilterTest.java new file mode 100644 index 0000000..ccdf502 --- /dev/null +++ b/src/test/java/com/label/unit/TokenFilterTest.java @@ -0,0 +1,143 @@ +package com.label.unit; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.common.context.CompanyContext; +import com.label.common.redis.RedisKeyManager; +import com.label.common.redis.RedisService; +import com.label.common.shiro.BearerToken; +import com.label.common.shiro.TokenFilter; +import com.label.common.shiro.TokenPrincipal; +import org.apache.shiro.SecurityUtils; +import org.apache.shiro.authc.AuthenticationInfo; +import org.apache.shiro.authc.AuthenticationToken; +import org.apache.shiro.authc.SimpleAuthenticationInfo; +import org.apache.shiro.authz.AuthorizationInfo; +import org.apache.shiro.authz.SimpleAuthorizationInfo; +import org.apache.shiro.mgt.DefaultSecurityManager; +import org.apache.shiro.realm.AuthorizingRealm; +import org.apache.shiro.subject.PrincipalCollection; +import org.apache.shiro.util.ThreadContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.util.ReflectionTestUtils; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; + +@DisplayName("TokenFilter 单元测试") +class TokenFilterTest { + + private RedisService redisService; + private TestableTokenFilter filter; + + @BeforeEach + void setUp() { + redisService = mock(RedisService.class); + filter = new TestableTokenFilter(redisService, new ObjectMapper()); + SecurityUtils.setSecurityManager(new DefaultSecurityManager(new BearerTokenRealm())); + } + + @AfterEach + void tearDown() { + CompanyContext.clear(); + ThreadContext.remove(); + } + + @Test + @DisplayName("有效 Token 请求会刷新 token TTL,实现滑动过期") + void validToken_refreshesTokenTtl() throws Exception { + ReflectionTestUtils.setField(filter, "authEnabled", true); + ReflectionTestUtils.setField(filter, "tokenTtlSeconds", 7200L); + String token = "valid-token"; + when(redisService.hGetAll(RedisKeyManager.tokenKey(token))).thenReturn(Map.of( + "userId", "10", + "role", "ADMIN", + "companyId", "20", + "username", "admin" + )); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/tasks"); + request.addHeader("Authorization", "Bearer " + token); + MockHttpServletResponse response = new MockHttpServletResponse(); + RecordingChain chain = new RecordingChain(); + + filter.invoke(request, response, chain); + + assertThat(response.getStatus()).isEqualTo(200); + assertThat(chain.principal).isInstanceOf(TokenPrincipal.class); + verify(redisService).expire(RedisKeyManager.tokenKey(token), 7200L); + } + + @Test + @DisplayName("shiro.auth.enabled=false 时注入 mock Principal 并跳过 Redis 校验") + void authDisabled_injectsMockPrincipalWithoutRedisLookup() throws Exception { + ReflectionTestUtils.setField(filter, "authEnabled", false); + ReflectionTestUtils.setField(filter, "mockCompanyId", 3L); + ReflectionTestUtils.setField(filter, "mockUserId", 4L); + ReflectionTestUtils.setField(filter, "mockRole", "ADMIN"); + ReflectionTestUtils.setField(filter, "mockUsername", "mock-admin"); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/tasks"); + MockHttpServletResponse response = new MockHttpServletResponse(); + RecordingChain chain = new RecordingChain(); + + filter.invoke(request, response, chain); + + assertThat(response.getStatus()).isEqualTo(200); + TokenPrincipal principal = chain.principal; + assertThat(principal.getCompanyId()).isEqualTo(3L); + assertThat(principal.getUserId()).isEqualTo(4L); + assertThat(principal.getRole()).isEqualTo("ADMIN"); + assertThat(principal.getUsername()).isEqualTo("mock-admin"); + verify(redisService, never()).hGetAll(anyString()); + } + + private static final class BearerTokenRealm extends AuthorizingRealm { + @Override + public boolean supports(AuthenticationToken token) { + return token instanceof BearerToken; + } + + @Override + protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) { + return new SimpleAuthenticationInfo(token.getPrincipal(), token.getCredentials(), getName()); + } + + @Override + protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) { + TokenPrincipal principal = (TokenPrincipal) principals.getPrimaryPrincipal(); + SimpleAuthorizationInfo info = new SimpleAuthorizationInfo(); + info.addRole(principal.getRole()); + return info; + } + } + + private static final class RecordingChain implements FilterChain { + private TokenPrincipal principal; + + @Override + public void doFilter(ServletRequest request, ServletResponse response) { + principal = (TokenPrincipal) request.getAttribute("__token_principal__"); + } + } + + private static final class TestableTokenFilter extends TokenFilter { + private TestableTokenFilter(RedisService redisService, ObjectMapper objectMapper) { + super(redisService, objectMapper); + } + + private void invoke(MockHttpServletRequest request, MockHttpServletResponse response, FilterChain chain) + throws Exception { + super.doFilterInternal(request, response, chain); + } + } +}