From a30b648d30db887024aa0d6d113aabf9df6adb89 Mon Sep 17 00:00:00 2001 From: wh Date: Tue, 14 Apr 2026 16:33:34 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8E=BB=E6=8E=89shiro=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../2026-04-14-auth-company-optimization.md | 66 +++++++ pom.xml | 72 +------ .../com/label/LabelBackendApplication.java | 12 -- .../com/label/annotation/RequireAuth.java | 11 ++ .../com/label/annotation/RequireRole.java | 13 ++ .../{shiro => auth}/TokenPrincipal.java | 6 +- .../com/label/common/context/UserContext.java | 23 +++ .../exception/GlobalExceptionHandler.java | 21 +- .../com/label/common/shiro/BearerToken.java | 26 --- .../com/label/common/shiro/TokenFilter.java | 139 ------------- .../com/label/common/shiro/UserRealm.java | 88 --------- .../java/com/label/config/AuthConfig.java | 20 ++ .../java/com/label/config/ShiroConfig.java | 66 ------- .../com/label/controller/AuthController.java | 13 +- .../label/controller/CompanyController.java | 73 +++++++ .../label/controller/ExportController.java | 14 +- .../controller/ExtractionController.java | 14 +- .../com/label/controller/QaController.java | 14 +- .../label/controller/SourceController.java | 12 +- .../label/controller/SysConfigController.java | 8 +- .../com/label/controller/TaskController.java | 24 +-- .../com/label/controller/UserController.java | 14 +- .../com/label/controller/VideoController.java | 14 +- .../label/interceptor/AuthInterceptor.java | 182 ++++++++++++++++++ .../ExtractionApprovedEventListener.java | 35 ++-- .../java/com/label/mapper/SysUserMapper.java | 4 + .../java/com/label/service/AuthService.java | 4 +- .../com/label/service/CompanyService.java | 122 ++++++++++++ .../java/com/label/service/ExportService.java | 2 +- .../com/label/service/ExtractionService.java | 2 +- .../com/label/service/FinetuneService.java | 2 +- .../java/com/label/service/QaService.java | 2 +- .../java/com/label/service/SourceService.java | 2 +- .../com/label/service/TaskClaimService.java | 2 +- .../java/com/label/service/TaskService.java | 2 +- .../java/com/label/service/UserService.java | 2 +- src/main/resources/application.yml | 36 ++-- .../ShiroFilterIntegrationTest.java | 160 --------------- .../com/label/unit/ApplicationConfigTest.java | 12 +- .../com/label/unit/AuthInterceptorTest.java | 151 +++++++++++++++ .../com/label/unit/CompanyServiceTest.java | 73 +++++++ .../com/label/unit/OpenApiAnnotationTest.java | 2 + .../java/com/label/unit/ShiroConfigTest.java | 40 ---- .../java/com/label/unit/TokenFilterTest.java | 127 ------------ 44 files changed, 868 insertions(+), 859 deletions(-) create mode 100644 docs/superpowers/plans/2026-04-14-auth-company-optimization.md create mode 100644 src/main/java/com/label/annotation/RequireAuth.java create mode 100644 src/main/java/com/label/annotation/RequireRole.java rename src/main/java/com/label/common/{shiro => auth}/TokenPrincipal.java (75%) create mode 100644 src/main/java/com/label/common/context/UserContext.java delete mode 100644 src/main/java/com/label/common/shiro/BearerToken.java delete mode 100644 src/main/java/com/label/common/shiro/TokenFilter.java delete mode 100644 src/main/java/com/label/common/shiro/UserRealm.java create mode 100644 src/main/java/com/label/config/AuthConfig.java delete mode 100644 src/main/java/com/label/config/ShiroConfig.java create mode 100644 src/main/java/com/label/controller/CompanyController.java create mode 100644 src/main/java/com/label/interceptor/AuthInterceptor.java create mode 100644 src/main/java/com/label/service/CompanyService.java delete mode 100644 src/test/java/com/label/integration/ShiroFilterIntegrationTest.java create mode 100644 src/test/java/com/label/unit/AuthInterceptorTest.java create mode 100644 src/test/java/com/label/unit/CompanyServiceTest.java delete mode 100644 src/test/java/com/label/unit/ShiroConfigTest.java delete mode 100644 src/test/java/com/label/unit/TokenFilterTest.java diff --git a/docs/superpowers/plans/2026-04-14-auth-company-optimization.md b/docs/superpowers/plans/2026-04-14-auth-company-optimization.md new file mode 100644 index 0000000..9b68b0e --- /dev/null +++ b/docs/superpowers/plans/2026-04-14-auth-company-optimization.md @@ -0,0 +1,66 @@ +# Auth And Company Optimization Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace the remaining Shiro authorization layer with project-owned Redis token authentication and add company CRUD APIs. + +**Architecture:** Keep the existing UUID token, Redis session storage, and `CompanyContext` tenant injection. Add project-owned `@RequireAuth` and `@RequireRole` annotations plus a Spring MVC `AuthInterceptor`, then remove Shiro config/classes/dependencies. Add `CompanyService` and `CompanyController` for `sys_company` management. + +**Tech Stack:** Java 21, Spring Boot 3.1.5, Spring MVC HandlerInterceptor, RedisTemplate, MyBatis-Plus, JUnit 5, Mockito, AssertJ. + +--- + +### Task 1: Replace Shiro With Custom Auth Interceptor + +**Files:** +- Create: `src/main/java/com/label/annotation/RequireAuth.java` +- Create: `src/main/java/com/label/annotation/RequireRole.java` +- Create: `src/main/java/com/label/interceptor/AuthInterceptor.java` +- Create: `src/main/java/com/label/common/auth/TokenPrincipal.java` +- Create: `src/main/java/com/label/common/context/UserContext.java` +- Modify: `src/main/java/com/label/config/ShiroConfig.java` +- Modify: `src/main/java/com/label/common/shiro/TokenFilter.java` +- Modify: `src/main/java/com/label/common/shiro/BearerToken.java` +- Modify: `src/main/java/com/label/common/shiro/UserRealm.java` +- Modify: `src/main/java/com/label/controller/*.java` +- Modify: `src/main/java/com/label/service/*.java` +- Modify: `pom.xml` +- Test: `src/test/java/com/label/unit/AuthInterceptorTest.java` + +- [x] Write failing tests for token loading, TTL refresh, role hierarchy, and context cleanup. +- [x] Implement annotations, principal, context, and interceptor. +- [x] Register the interceptor via Spring MVC config. +- [x] Replace controller `@RequiresRoles` usage with `@RequireRole`. +- [x] Remove Shiro-only classes, tests, dependencies, and exception handling. +- [x] Run `mvn -q "-Dtest=AuthInterceptorTest,OpenApiAnnotationTest" test` and `mvn -q -DskipTests compile`. + +### Task 2: Add Company Management + +**Files:** +- Create: `src/main/java/com/label/service/CompanyService.java` +- Create: `src/main/java/com/label/controller/CompanyController.java` +- Modify: `src/main/java/com/label/mapper/SysUserMapper.java` +- Test: `src/test/java/com/label/unit/CompanyServiceTest.java` +- Test: `src/test/java/com/label/unit/OpenApiAnnotationTest.java` + +- [x] Write failing tests for create/list/update/status/delete behavior. +- [x] Implement service validation and duplicate checks. +- [x] Implement admin-only controller endpoints under `/api/companies`. +- [x] Run `mvn -q "-Dtest=CompanyServiceTest,OpenApiAnnotationTest" test` and `mvn -q -DskipTests compile`. + +### Task 3: Configuration And Verification + +**Files:** +- Modify: `src/main/resources/application.yml` +- Modify: `src/test/java/com/label/unit/ApplicationConfigTest.java` + +- [x] Rename `shiro.auth.*` config to `auth.*`. +- [x] Update safe defaults and type-aliases package. +- [x] Run targeted unit tests and compile. +- [x] Run `mvn clean test` once and record any external environment blockers. + +### Verification Notes + +- `mvn -q "-Dtest=LabelBackendApplicationTests,ApplicationConfigTest,AuthInterceptorTest,CompanyServiceTest,OpenApiAnnotationTest" test` passed. +- `mvn -q -DskipTests compile` passed. +- `mvn clean test` compiled main/test sources and passed unit tests, then failed only because 10 Testcontainers integration tests could not find a valid Docker environment. diff --git a/pom.xml b/pom.xml index 6ecd4d6..5180e89 100644 --- a/pom.xml +++ b/pom.xml @@ -3,19 +3,16 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - org.springframework.boot spring-boot-starter-parent 3.1.5 - com.label label-backend 1.0.0-SNAPSHOT jar - 21 UTF-8 @@ -24,7 +21,6 @@ 2.3.0 UTF-8 - @@ -45,32 +41,27 @@ - org.springframework.boot spring-boot-starter-web - org.springframework.boot spring-boot-starter-actuator - org.springframework.boot spring-boot-starter-data-redis - org.springframework.boot spring-boot-starter-aop - org.postgresql @@ -78,106 +69,61 @@ ${postgrescp.version} runtime - - com.baomidou mybatis-plus-boot-starter ${mybatis-plus.version} - - + - com.baomidou - mybatis-plus-jsqlparser - 3.5.10 + com.github.jsqlparser + jsqlparser + 4.4 - org.springdoc springdoc-openapi-starter-webmvc-ui - 2.3.0 + 2.3.0 - - - - - - org.apache.shiro - shiro-core - jakarta - 2.0.0 - - - - org.apache.shiro - shiro-web - jakarta - 2.0.0 - - - - org.apache.shiro - shiro-spring - jakarta - 2.0.0 - - - org.apache.shiro - shiro-web - - - - software.amazon.awssdk s3 - software.amazon.awssdk sts - org.springframework.security spring-security-crypto - org.projectlombok lombok true - org.springframework.boot spring-boot-starter-test test - org.testcontainers postgresql test - org.testcontainers @@ -185,10 +131,8 @@ test - - org.apache.maven.plugins @@ -203,7 +147,6 @@ - org.apache.maven.plugins @@ -222,7 +165,6 @@ - org.apache.maven.plugins @@ -244,8 +186,6 @@ - - - + \ No newline at end of file diff --git a/src/main/java/com/label/LabelBackendApplication.java b/src/main/java/com/label/LabelBackendApplication.java index 056e5bc..533e0da 100644 --- a/src/main/java/com/label/LabelBackendApplication.java +++ b/src/main/java/com/label/LabelBackendApplication.java @@ -1,4 +1,3 @@ - package com.label; import org.springframework.boot.SpringApplication; @@ -6,18 +5,7 @@ import org.springframework.boot.autoconfigure.SpringBootApplication; /** * 应用入口。 - * - * 排除 Shiro Web 自动配置(ShiroWebAutoConfiguration、ShiroWebFilterConfiguration、 - * ShiroWebMvcAutoConfiguration),避免其依赖的 ShiroFilter(javax.servlet.Filter) 与 - * Spring Boot 3. 的 jakarta.servlet 命名空间冲突。 认证/ 授权逻辑改由 - * TokenFilter(OncePerRequestFilter)+ ShiroConfig 手动装配。 */ - -// (excludeName = { - -// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebAutoConfiguration", -// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebFilterConfiguration", -// "org.apache.shiro.spring.config.web.autoconfigure.ShiroWebMvcAutoConfiguration" }) @SpringBootApplication public class LabelBackendApplication { diff --git a/src/main/java/com/label/annotation/RequireAuth.java b/src/main/java/com/label/annotation/RequireAuth.java new file mode 100644 index 0000000..9ccb677 --- /dev/null +++ b/src/main/java/com/label/annotation/RequireAuth.java @@ -0,0 +1,11 @@ +package com.label.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.METHOD, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface RequireAuth { +} diff --git a/src/main/java/com/label/annotation/RequireRole.java b/src/main/java/com/label/annotation/RequireRole.java new file mode 100644 index 0000000..aded7a0 --- /dev/null +++ b/src/main/java/com/label/annotation/RequireRole.java @@ -0,0 +1,13 @@ +package com.label.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@RequireAuth +@Target({ElementType.METHOD, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface RequireRole { + String value(); +} diff --git a/src/main/java/com/label/common/shiro/TokenPrincipal.java b/src/main/java/com/label/common/auth/TokenPrincipal.java similarity index 75% rename from src/main/java/com/label/common/shiro/TokenPrincipal.java rename to src/main/java/com/label/common/auth/TokenPrincipal.java index 39aa63e..219e80f 100644 --- a/src/main/java/com/label/common/shiro/TokenPrincipal.java +++ b/src/main/java/com/label/common/auth/TokenPrincipal.java @@ -1,12 +1,10 @@ -package com.label.common.shiro; +package com.label.common.auth; import lombok.AllArgsConstructor; import lombok.Getter; + import java.io.Serializable; -/** - * Shiro principal carrying the authenticated user's session data. - */ @Getter @AllArgsConstructor public class TokenPrincipal implements Serializable { diff --git a/src/main/java/com/label/common/context/UserContext.java b/src/main/java/com/label/common/context/UserContext.java new file mode 100644 index 0000000..466a1ef --- /dev/null +++ b/src/main/java/com/label/common/context/UserContext.java @@ -0,0 +1,23 @@ +package com.label.common.context; + +import com.label.common.auth.TokenPrincipal; + +public final class UserContext { + private static final ThreadLocal PRINCIPAL = new ThreadLocal<>(); + + public static void set(TokenPrincipal principal) { + PRINCIPAL.set(principal); + } + + public static TokenPrincipal get() { + return PRINCIPAL.get(); + } + + public static void clear() { + PRINCIPAL.remove(); + } + + private UserContext() { + throw new UnsupportedOperationException("Utility class"); + } +} diff --git a/src/main/java/com/label/common/exception/GlobalExceptionHandler.java b/src/main/java/com/label/common/exception/GlobalExceptionHandler.java index 676fc89..07896f6 100644 --- a/src/main/java/com/label/common/exception/GlobalExceptionHandler.java +++ b/src/main/java/com/label/common/exception/GlobalExceptionHandler.java @@ -2,8 +2,6 @@ package com.label.common.exception; import com.label.common.result.Result; import lombok.extern.slf4j.Slf4j; -import org.apache.shiro.authz.AuthorizationException; -import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; @@ -16,26 +14,15 @@ public class GlobalExceptionHandler { public ResponseEntity> handleBusinessException(BusinessException e) { log.warn("业务异常: code={}, message={}", e.getCode(), e.getMessage()); return ResponseEntity - .status(e.getHttpStatus()) - .body(Result.failure(e.getCode(), e.getMessage())); - } - - /** - * 处理 Shiro 权限不足异常(@RequiresRoles / subject.checkRole() 抛出)→ 403 - */ - @ExceptionHandler(AuthorizationException.class) - public ResponseEntity> handleAuthorizationException(AuthorizationException e) { - log.warn("权限不足: {}", e.getMessage()); - return ResponseEntity - .status(HttpStatus.FORBIDDEN) - .body(Result.failure("FORBIDDEN", "权限不足")); + .status(e.getHttpStatus()) + .body(Result.failure(e.getCode(), e.getMessage())); } @ExceptionHandler(Exception.class) public ResponseEntity> handleException(Exception e) { log.error("系统异常", e); return ResponseEntity - .internalServerError() - .body(Result.failure("INTERNAL_ERROR", "系统内部错误")); + .internalServerError() + .body(Result.failure("INTERNAL_ERROR", "系统内部错误")); } } diff --git a/src/main/java/com/label/common/shiro/BearerToken.java b/src/main/java/com/label/common/shiro/BearerToken.java deleted file mode 100644 index 5febfc9..0000000 --- a/src/main/java/com/label/common/shiro/BearerToken.java +++ /dev/null @@ -1,26 +0,0 @@ -package com.label.common.shiro; - -import org.apache.shiro.authc.AuthenticationToken; - -/** - * Shiro AuthenticationToken wrapper for Bearer token strings. - */ -public class BearerToken implements AuthenticationToken { - private final String token; - private final TokenPrincipal principal; - - public BearerToken(String token, TokenPrincipal principal) { - this.token = token; - this.principal = principal; - } - - @Override - public Object getPrincipal() { - return principal; - } - - @Override - public Object getCredentials() { - return token; - } -} diff --git a/src/main/java/com/label/common/shiro/TokenFilter.java b/src/main/java/com/label/common/shiro/TokenFilter.java deleted file mode 100644 index 2f893f0..0000000 --- a/src/main/java/com/label/common/shiro/TokenFilter.java +++ /dev/null @@ -1,139 +0,0 @@ -package com.label.common.shiro; - -import java.io.IOException; -import java.util.Map; - -import org.apache.shiro.SecurityUtils; -import org.apache.shiro.util.ThreadContext; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.http.MediaType; -import org.springframework.web.filter.OncePerRequestFilter; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.label.common.context.CompanyContext; -import com.label.common.result.Result; -import com.label.service.RedisService; -import com.label.util.RedisUtil; - -import jakarta.servlet.FilterChain; -import jakarta.servlet.ServletException; -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; - -/** - * JWT-style Bearer Token 过滤器。 - * 继承 Spring 的 OncePerRequestFilter(jakarta.servlet),避免与 Shiro 1.x - * 的 PathMatchingFilter(javax.servlet)产生命名空间冲突。 - * - * 过滤逻辑: - * - 跳过非 /api/ 路径和 /api/auth/login(公开端点) - * - 解析 "Authorization: Bearer {uuid}",查询 Redis Hash token:{uuid} - * - Token 存在 → 注入 CompanyContext,登录 Shiro Subject,继续请求链路 - * - Token 缺失或过期 → 直接返回 401 - * - finally 块中清除 CompanyContext 和 ThreadContext Subject,防止线程池串漏 - */ -@Slf4j -@RequiredArgsConstructor -public class TokenFilter extends OncePerRequestFilter { - - private final RedisService redisService; - private final ObjectMapper objectMapper; - - @Value("${shiro.auth.enabled:true}") - private boolean authEnabled; - - @Value("${shiro.auth.mock-company-id:1}") - private Long mockCompanyId; - - @Value("${shiro.auth.mock-user-id:1}") - private Long mockUserId; - - @Value("${shiro.auth.mock-role:ADMIN}") - private String mockRole; - - @Value("${shiro.auth.mock-username:mock}") - private String mockUsername; - - @Value("${token.ttl-seconds:7200}") - private long tokenTtlSeconds; - - /** - * 公开端点跳过过滤:非 /api/ 前缀路径,以及登录接口本身。 - */ - @Override - protected boolean shouldNotFilter(HttpServletRequest request) { - String path = request.getServletPath(); - return !path.startsWith("/api/") - || path.equals("/api/auth/login") - || path.equals("/api/video/callback") - || path.startsWith("/swagger-ui") - || path.startsWith("/v3/api-docs"); // AI 服务内部回调,不走用户 Token 认证 - } - - @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, - FilterChain filterChain) throws ServletException, IOException { - try { - if (!authEnabled) { - TokenPrincipal principal = new TokenPrincipal( - mockUserId, mockRole, mockCompanyId, mockUsername, "mock-token"); - CompanyContext.set(mockCompanyId); - SecurityUtils.getSubject().login(new BearerToken("mock-token", principal)); - request.setAttribute("__token_principal__", principal); - filterChain.doFilter(request, response); - return; - } - - String authHeader = request.getHeader("Authorization"); - if (authHeader == null || !authHeader.toLowerCase().startsWith("bearer ")) { - writeUnauthorized(response, "缺少或无效的认证令牌"); - return; - } - String[] parts = authHeader.split("\\s+"); - if (parts.length != 2 || !"Bearer".equalsIgnoreCase(parts[0])) { - writeUnauthorized(response, "无效的认证格式"); - return; - } - String token = parts[1]; - // String token = authHeader.substring(7).trim(); - Map tokenData = redisService.hGetAll(RedisUtil.tokenKey(token)); - - if (tokenData == null || tokenData.isEmpty()) { - writeUnauthorized(response, "令牌已过期或不存在"); - return; - } - - Long userId = Long.parseLong(tokenData.get("userId").toString()); - String role = tokenData.get("role").toString(); - Long companyId = Long.parseLong(tokenData.get("companyId").toString()); - String username = tokenData.get("username").toString(); - - // 注入多租户上下文(finally 中清除,防止线程池串漏) - CompanyContext.set(companyId); - - // 创建 TokenPrincipal 并登录 Shiro Subject,使 @RequiresRoles 等注解生效 - TokenPrincipal principal = new TokenPrincipal(userId, role, companyId, username, token); - SecurityUtils.getSubject().login(new BearerToken(token, principal)); - request.setAttribute("__token_principal__", principal); - redisService.expire(RedisUtil.tokenKey(token), tokenTtlSeconds); - redisService.expire(RedisUtil.userSessionsKey(userId), tokenTtlSeconds); - - filterChain.doFilter(request, response); - } catch (Exception e) { - log.error("解析 Token 数据失败: {}", e.getMessage()); - writeUnauthorized(response, "令牌数据格式错误"); - } finally { - // 关键:必须清除 ThreadLocal,防止线程池复用时数据串漏 - CompanyContext.clear(); - ThreadContext.unbindSubject(); - } - } - - private void writeUnauthorized(HttpServletResponse resp, String message) throws IOException { - resp.setStatus(HttpServletResponse.SC_UNAUTHORIZED); - resp.setContentType(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); - resp.getWriter().write(objectMapper.writeValueAsString(Result.failure("UNAUTHORIZED", message))); - } -} diff --git a/src/main/java/com/label/common/shiro/UserRealm.java b/src/main/java/com/label/common/shiro/UserRealm.java deleted file mode 100644 index fb9af18..0000000 --- a/src/main/java/com/label/common/shiro/UserRealm.java +++ /dev/null @@ -1,88 +0,0 @@ -package com.label.common.shiro; - -import com.label.service.RedisService; -import com.label.util.RedisUtil; - -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.apache.shiro.authc.*; -import org.apache.shiro.authz.AuthorizationInfo; -import org.apache.shiro.authz.SimpleAuthorizationInfo; -import org.apache.shiro.realm.AuthorizingRealm; -import org.apache.shiro.subject.PrincipalCollection; - -/** - * Shiro Realm for role-based authorization using token-based authentication. - * - * Role hierarchy (addInheritedRoles): - * ADMIN ⊃ REVIEWER ⊃ ANNOTATOR ⊃ UPLOADER - * - * Permission lookup order: - * 1. Redis user:perm:{userId} (TTL 5 min) - * 2. If miss: use role from TokenPrincipal - */ -@Slf4j -@RequiredArgsConstructor -public class UserRealm extends AuthorizingRealm { - - private static final long PERM_CACHE_TTL = 300L; // 5 minutes - - private final RedisService redisService; - - @Override - public boolean supports(AuthenticationToken token) { - return token instanceof BearerToken; - } - - @Override - protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException { - // Token validation is done in TokenFilter; this realm only handles authorization - // For authentication, we trust the token that was validated by TokenFilter - return new SimpleAuthenticationInfo(token.getPrincipal(), token.getCredentials(), getName()); - } - - @Override - protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) { - TokenPrincipal principal = (TokenPrincipal) principals.getPrimaryPrincipal(); - if (principal == null) { - return new SimpleAuthorizationInfo(); - } - - String role = getRoleFromCacheOrPrincipal(principal); - SimpleAuthorizationInfo info = new SimpleAuthorizationInfo(); - info.addRole(role); - addInheritedRoles(info, role); - return info; - } - - private String getRoleFromCacheOrPrincipal(TokenPrincipal principal) { - String permKey = RedisUtil.userPermKey(principal.getUserId()); - String cachedRole = redisService.get(permKey); - if (cachedRole != null && !cachedRole.isEmpty()) { - return cachedRole; - } - // Cache miss: use role from token, then refresh cache - String role = principal.getRole(); - redisService.set(permKey, role, PERM_CACHE_TTL); - return role; - } - - /** - * ADMIN inherits all roles: ADMIN ⊃ REVIEWER ⊃ ANNOTATOR ⊃ UPLOADER - */ - private void addInheritedRoles(SimpleAuthorizationInfo info, String role) { - switch (role) { - case "ADMIN": - info.addRole("REVIEWER"); - // fall through - case "REVIEWER": - info.addRole("ANNOTATOR"); - // fall through - case "ANNOTATOR": - info.addRole("UPLOADER"); - break; - default: - break; - } - } -} diff --git a/src/main/java/com/label/config/AuthConfig.java b/src/main/java/com/label/config/AuthConfig.java new file mode 100644 index 0000000..e933f17 --- /dev/null +++ b/src/main/java/com/label/config/AuthConfig.java @@ -0,0 +1,20 @@ +package com.label.config; + +import com.label.interceptor.AuthInterceptor; +import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.InterceptorRegistry; +import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; + +@Configuration +@RequiredArgsConstructor +public class AuthConfig implements WebMvcConfigurer { + + private final AuthInterceptor authInterceptor; + + @Override + public void addInterceptors(InterceptorRegistry registry) { + registry.addInterceptor(authInterceptor) + .addPathPatterns("/**"); + } +} diff --git a/src/main/java/com/label/config/ShiroConfig.java b/src/main/java/com/label/config/ShiroConfig.java deleted file mode 100644 index 5d5d8bc..0000000 --- a/src/main/java/com/label/config/ShiroConfig.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.label.config; - -import java.util.List; - -import org.apache.shiro.SecurityUtils; -import org.apache.shiro.mgt.SecurityManager; -import org.apache.shiro.web.mgt.DefaultWebSecurityManager; -import org.springframework.boot.web.servlet.FilterRegistrationBean; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.label.common.shiro.TokenFilter; -import com.label.common.shiro.UserRealm; -import com.label.service.RedisService; - -/** - * Shiro 安全配置。 - * - * 设计说明: - * - 使用 Spring 的 FilterRegistrationBean 注册 TokenFilter(jakarta.servlet), - * 替代 Shiro 的 ShiroFilterFactoryBean(javax.servlet),避免 Shiro 1.x 与 - * Spring Boot 3.x 之间的 javax/jakarta 命名空间冲突。 - * - URL 路由逻辑内聚于 TokenFilter.shouldNotFilter(): - * /api/auth/login → 跳过(公开) - * 非 /api/ 路径 → 跳过(公开) - * /api/** → 强制校验 Bearer Token - * - SecurityUtils.setSecurityManager() 必须在此处调用, - * 以便 @RequiresRoles 等 AOP 注解和 SecurityUtils.getSubject() 可正常工作。 - */ -@Configuration -public class ShiroConfig { - - @Bean - public UserRealm userRealm(RedisService redisService) { - return new UserRealm(redisService); - } - - @Bean - public SecurityManager securityManager(UserRealm userRealm) { - DefaultWebSecurityManager manager = new DefaultWebSecurityManager(); - manager.setRealms(List.of(userRealm)); - // 设置全局 SecurityManager,使 SecurityUtils.getSubject() 及 AOP 注解可用 - SecurityUtils.setSecurityManager(manager); - return manager; - } - - @Bean - public TokenFilter tokenFilter(RedisService redisService, ObjectMapper objectMapper) { - return new TokenFilter(redisService, objectMapper); - } - - /** - * 将 TokenFilter 注册为 Servlet 过滤器,覆盖所有路径。 - * 实际的路径过滤逻辑由 TokenFilter.shouldNotFilter() 控制。 - */ - @Bean - public FilterRegistrationBean tokenFilterRegistration(TokenFilter tokenFilter) { - FilterRegistrationBean registration = new FilterRegistrationBean<>(); - registration.setFilter(tokenFilter); - registration.addUrlPatterns("/*"); - registration.setOrder(1); - registration.setName("tokenFilter"); - return registration; - } -} diff --git a/src/main/java/com/label/controller/AuthController.java b/src/main/java/com/label/controller/AuthController.java index 521eb8d..42aa24f 100644 --- a/src/main/java/com/label/controller/AuthController.java +++ b/src/main/java/com/label/controller/AuthController.java @@ -1,7 +1,8 @@ package com.label.controller; +import com.label.annotation.RequireAuth; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.dto.LoginRequest; import com.label.dto.LoginResponse; import com.label.dto.UserInfoResponse; @@ -16,9 +17,9 @@ import org.springframework.web.bind.annotation.*; * 认证接口:登录、退出、获取当前用户。 * * 路由设计: - * - POST /api/auth/login → 匿名(TokenFilter.shouldNotFilter 跳过) - * - POST /api/auth/logout → 需要有效 Token(TokenFilter 校验) - * - GET /api/auth/me → 需要有效 Token(TokenFilter 校验) + * - POST /api/auth/login → 匿名(AuthInterceptor 跳过) + * - POST /api/auth/logout → 需要有效 Token(AuthInterceptor 校验) + * - GET /api/auth/me → 需要有效 Token(AuthInterceptor 校验) */ @Tag(name = "认证管理", description = "登录、退出和当前用户信息") @RestController @@ -42,6 +43,7 @@ public class AuthController { */ @Operation(summary = "退出登录并立即失效当前 Token") @PostMapping("/logout") + @RequireAuth public Result logout(HttpServletRequest request) { String token = extractToken(request); authService.logout(token); @@ -50,10 +52,11 @@ public class AuthController { /** * 获取当前登录用户信息。 - * TokenPrincipal 由 TokenFilter 写入请求属性 "__token_principal__"。 + * TokenPrincipal 由 AuthInterceptor 写入请求属性 "__token_principal__"。 */ @Operation(summary = "获取当前登录用户信息") @GetMapping("/me") + @RequireAuth public Result me(HttpServletRequest request) { TokenPrincipal principal = (TokenPrincipal) request.getAttribute("__token_principal__"); return Result.success(authService.me(principal)); diff --git a/src/main/java/com/label/controller/CompanyController.java b/src/main/java/com/label/controller/CompanyController.java new file mode 100644 index 0000000..7bdfc90 --- /dev/null +++ b/src/main/java/com/label/controller/CompanyController.java @@ -0,0 +1,73 @@ +package com.label.controller; + +import com.label.annotation.RequireRole; +import com.label.common.result.PageResult; +import com.label.common.result.Result; +import com.label.entity.SysCompany; +import com.label.service.CompanyService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import lombok.RequiredArgsConstructor; +import org.springframework.http.HttpStatus; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.PutMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.ResponseStatus; +import org.springframework.web.bind.annotation.RestController; + +import java.util.Map; + +@Tag(name = "公司管理", description = "租户公司增删改查") +@RestController +@RequestMapping("/api/companies") +@RequiredArgsConstructor +public class CompanyController { + + private final CompanyService companyService; + + @Operation(summary = "分页查询公司列表") + @GetMapping + @RequireRole("ADMIN") + public Result> list( + @RequestParam(defaultValue = "1") int page, + @RequestParam(defaultValue = "20") int pageSize, + @RequestParam(required = false) String status) { + return Result.success(companyService.list(page, pageSize, status)); + } + + @Operation(summary = "创建公司") + @PostMapping + @RequireRole("ADMIN") + @ResponseStatus(HttpStatus.CREATED) + public Result create(@RequestBody Map body) { + return Result.success(companyService.create(body.get("companyName"), body.get("companyCode"))); + } + + @Operation(summary = "更新公司信息") + @PutMapping("/{id}") + @RequireRole("ADMIN") + public Result update(@PathVariable Long id, @RequestBody Map body) { + return Result.success(companyService.update(id, body.get("companyName"), body.get("companyCode"))); + } + + @Operation(summary = "更新公司状态") + @PutMapping("/{id}/status") + @RequireRole("ADMIN") + public Result updateStatus(@PathVariable Long id, @RequestBody Map body) { + companyService.updateStatus(id, body.get("status")); + return Result.success(null); + } + + @Operation(summary = "删除公司") + @DeleteMapping("/{id}") + @RequireRole("ADMIN") + public Result delete(@PathVariable Long id) { + companyService.delete(id); + return Result.success(null); + } +} diff --git a/src/main/java/com/label/controller/ExportController.java b/src/main/java/com/label/controller/ExportController.java index 5ae1c92..6e88d7e 100644 --- a/src/main/java/com/label/controller/ExportController.java +++ b/src/main/java/com/label/controller/ExportController.java @@ -1,8 +1,9 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.PageResult; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.entity.TrainingDataset; import com.label.entity.ExportBatch; import com.label.service.ExportService; @@ -11,7 +12,6 @@ import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.http.HttpStatus; import org.springframework.web.bind.annotation.*; @@ -32,7 +32,7 @@ public class ExportController { /** GET /api/training/samples — 分页查询已审批可导出样本 */ @Operation(summary = "分页查询可导出训练样本") @GetMapping("/api/training/samples") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> listSamples( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -45,7 +45,7 @@ public class ExportController { /** POST /api/export/batch — 创建导出批次 */ @Operation(summary = "创建导出批次") @PostMapping("/api/export/batch") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") @ResponseStatus(HttpStatus.CREATED) public Result createBatch(@RequestBody Map body, HttpServletRequest request) { @@ -60,7 +60,7 @@ public class ExportController { /** POST /api/export/{batchId}/finetune — 提交微调任务 */ @Operation(summary = "提交微调任务") @PostMapping("/api/export/{batchId}/finetune") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> triggerFinetune(@PathVariable Long batchId, HttpServletRequest request) { return Result.success(finetuneService.trigger(batchId, principal(request))); @@ -69,7 +69,7 @@ public class ExportController { /** GET /api/export/{batchId}/status — 查询微调状态 */ @Operation(summary = "查询微调状态") @GetMapping("/api/export/{batchId}/status") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> getFinetuneStatus(@PathVariable Long batchId, HttpServletRequest request) { return Result.success(finetuneService.getStatus(batchId, principal(request))); @@ -78,7 +78,7 @@ public class ExportController { /** GET /api/export/list — 分页查询导出批次列表 */ @Operation(summary = "分页查询导出批次") @GetMapping("/api/export/list") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> listBatches( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, diff --git a/src/main/java/com/label/controller/ExtractionController.java b/src/main/java/com/label/controller/ExtractionController.java index 0ef3400..65abec5 100644 --- a/src/main/java/com/label/controller/ExtractionController.java +++ b/src/main/java/com/label/controller/ExtractionController.java @@ -1,13 +1,13 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.service.ExtractionService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.web.bind.annotation.*; import java.util.Map; @@ -26,7 +26,7 @@ public class ExtractionController { /** GET /api/extraction/{taskId} — 获取当前标注结果 */ @Operation(summary = "获取提取标注结果") @GetMapping("/{taskId}") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result> getResult(@PathVariable Long taskId, HttpServletRequest request) { return Result.success(extractionService.getResult(taskId, principal(request))); @@ -35,7 +35,7 @@ public class ExtractionController { /** PUT /api/extraction/{taskId} — 更新标注结果(整体覆盖) */ @Operation(summary = "更新提取标注结果") @PutMapping("/{taskId}") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result updateResult(@PathVariable Long taskId, @RequestBody String resultJson, HttpServletRequest request) { @@ -46,7 +46,7 @@ public class ExtractionController { /** POST /api/extraction/{taskId}/submit — 提交标注结果 */ @Operation(summary = "提交提取标注结果") @PostMapping("/{taskId}/submit") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result submit(@PathVariable Long taskId, HttpServletRequest request) { extractionService.submit(taskId, principal(request)); @@ -56,7 +56,7 @@ public class ExtractionController { /** POST /api/extraction/{taskId}/approve — 审批通过(REVIEWER) */ @Operation(summary = "审批通过提取结果") @PostMapping("/{taskId}/approve") - @RequiresRoles("REVIEWER") + @RequireRole("REVIEWER") public Result approve(@PathVariable Long taskId, HttpServletRequest request) { extractionService.approve(taskId, principal(request)); @@ -66,7 +66,7 @@ public class ExtractionController { /** POST /api/extraction/{taskId}/reject — 驳回(REVIEWER) */ @Operation(summary = "驳回提取结果") @PostMapping("/{taskId}/reject") - @RequiresRoles("REVIEWER") + @RequireRole("REVIEWER") public Result reject(@PathVariable Long taskId, @RequestBody Map body, HttpServletRequest request) { diff --git a/src/main/java/com/label/controller/QaController.java b/src/main/java/com/label/controller/QaController.java index 6641ffe..5c30b9c 100644 --- a/src/main/java/com/label/controller/QaController.java +++ b/src/main/java/com/label/controller/QaController.java @@ -1,13 +1,13 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.service.QaService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.web.bind.annotation.*; import java.util.Map; @@ -26,7 +26,7 @@ public class QaController { /** GET /api/qa/{taskId} — 获取候选问答对 */ @Operation(summary = "获取候选问答对") @GetMapping("/{taskId}") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result> getResult(@PathVariable Long taskId, HttpServletRequest request) { return Result.success(qaService.getResult(taskId, principal(request))); @@ -35,7 +35,7 @@ public class QaController { /** PUT /api/qa/{taskId} — 整体覆盖问答对 */ @Operation(summary = "更新候选问答对") @PutMapping("/{taskId}") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result updateResult(@PathVariable Long taskId, @RequestBody String body, HttpServletRequest request) { @@ -46,7 +46,7 @@ public class QaController { /** POST /api/qa/{taskId}/submit — 提交问答对 */ @Operation(summary = "提交问答对") @PostMapping("/{taskId}/submit") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result submit(@PathVariable Long taskId, HttpServletRequest request) { qaService.submit(taskId, principal(request)); @@ -56,7 +56,7 @@ public class QaController { /** POST /api/qa/{taskId}/approve — 审批通过(REVIEWER) */ @Operation(summary = "审批通过问答对") @PostMapping("/{taskId}/approve") - @RequiresRoles("REVIEWER") + @RequireRole("REVIEWER") public Result approve(@PathVariable Long taskId, HttpServletRequest request) { qaService.approve(taskId, principal(request)); @@ -66,7 +66,7 @@ public class QaController { /** POST /api/qa/{taskId}/reject — 驳回(REVIEWER) */ @Operation(summary = "驳回答案对") @PostMapping("/{taskId}/reject") - @RequiresRoles("REVIEWER") + @RequireRole("REVIEWER") public Result reject(@PathVariable Long taskId, @RequestBody Map body, HttpServletRequest request) { diff --git a/src/main/java/com/label/controller/SourceController.java b/src/main/java/com/label/controller/SourceController.java index e4abfa6..e134d7b 100644 --- a/src/main/java/com/label/controller/SourceController.java +++ b/src/main/java/com/label/controller/SourceController.java @@ -1,15 +1,15 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.PageResult; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.dto.SourceResponse; import com.label.service.SourceService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.http.HttpStatus; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; @@ -35,7 +35,7 @@ public class SourceController { */ @Operation(summary = "上传原始资料", description = "dataType: text,image, video") @PostMapping("/upload") - @RequiresRoles("UPLOADER") + @RequireRole("UPLOADER") @ResponseStatus(HttpStatus.CREATED) public Result upload( @RequestParam("file") MultipartFile file, @@ -51,7 +51,7 @@ public class SourceController { */ @Operation(summary = "分页查询资料列表") @GetMapping("/list") - @RequiresRoles("UPLOADER") + @RequireRole("UPLOADER") public Result> list( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -67,7 +67,7 @@ public class SourceController { */ @Operation(summary = "查询资料详情") @GetMapping("/{id}") - @RequiresRoles("UPLOADER") + @RequireRole("UPLOADER") public Result findById(@PathVariable Long id) { return Result.success(sourceService.findById(id)); } @@ -78,7 +78,7 @@ public class SourceController { */ @Operation(summary = "删除资料") @DeleteMapping("/{id}") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result delete(@PathVariable Long id, HttpServletRequest request) { TokenPrincipal principal = (TokenPrincipal) request.getAttribute("__token_principal__"); sourceService.delete(id, principal.getCompanyId()); diff --git a/src/main/java/com/label/controller/SysConfigController.java b/src/main/java/com/label/controller/SysConfigController.java index c632453..b0d775e 100644 --- a/src/main/java/com/label/controller/SysConfigController.java +++ b/src/main/java/com/label/controller/SysConfigController.java @@ -1,14 +1,14 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.entity.SysConfig; import com.label.service.SysConfigService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.web.bind.annotation.*; import java.util.List; @@ -36,7 +36,7 @@ public class SysConfigController { */ @Operation(summary = "查询合并后的系统配置") @GetMapping("/api/config") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result>> listConfig(HttpServletRequest request) { TokenPrincipal principal = principal(request); return Result.success(sysConfigService.list(principal.getCompanyId())); @@ -49,7 +49,7 @@ public class SysConfigController { */ @Operation(summary = "更新或创建公司专属配置") @PutMapping("/api/config/{key}") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result updateConfig(@PathVariable String key, @RequestBody Map body, HttpServletRequest request) { diff --git a/src/main/java/com/label/controller/TaskController.java b/src/main/java/com/label/controller/TaskController.java index 99f15a6..a63ed8a 100644 --- a/src/main/java/com/label/controller/TaskController.java +++ b/src/main/java/com/label/controller/TaskController.java @@ -1,8 +1,9 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.PageResult; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.dto.TaskResponse; import com.label.service.TaskClaimService; import com.label.service.TaskService; @@ -10,7 +11,6 @@ import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.web.bind.annotation.*; import java.util.Map; @@ -30,7 +30,7 @@ public class TaskController { /** GET /api/tasks/pool — 查询可领取任务池(角色感知) */ @Operation(summary = "查询可领取任务池") @GetMapping("/pool") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result> getPool( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -41,7 +41,7 @@ public class TaskController { /** GET /api/tasks/mine — 查询我的任务 */ @Operation(summary = "查询我的任务") @GetMapping("/mine") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result> getMine( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -53,7 +53,7 @@ public class TaskController { /** GET /api/tasks/pending-review — 待审批队列(REVIEWER 专属) */ @Operation(summary = "查询待审批任务") @GetMapping("/pending-review") - @RequiresRoles("REVIEWER") + @RequireRole("REVIEWER") public Result> getPendingReview( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -64,7 +64,7 @@ public class TaskController { /** GET /api/tasks — 查询全部任务(ADMIN) */ @Operation(summary = "管理员查询全部任务") @GetMapping - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> getAll( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -76,7 +76,7 @@ public class TaskController { /** POST /api/tasks — 创建任务(ADMIN) */ @Operation(summary = "管理员创建任务") @PostMapping - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result createTask(@RequestBody Map body, HttpServletRequest request) { Long sourceId = Long.parseLong(body.get("sourceId").toString()); @@ -89,7 +89,7 @@ public class TaskController { /** GET /api/tasks/{id} — 查询任务详情 */ @Operation(summary = "查询任务详情") @GetMapping("/{id}") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result getById(@PathVariable Long id) { return Result.success(taskService.toPublicResponse(taskService.getById(id))); } @@ -97,7 +97,7 @@ public class TaskController { /** POST /api/tasks/{id}/claim — 领取任务 */ @Operation(summary = "领取任务") @PostMapping("/{id}/claim") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result claim(@PathVariable Long id, HttpServletRequest request) { taskClaimService.claim(id, principal(request)); return Result.success(null); @@ -106,7 +106,7 @@ public class TaskController { /** POST /api/tasks/{id}/unclaim — 放弃任务 */ @Operation(summary = "放弃任务") @PostMapping("/{id}/unclaim") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result unclaim(@PathVariable Long id, HttpServletRequest request) { taskClaimService.unclaim(id, principal(request)); return Result.success(null); @@ -115,7 +115,7 @@ public class TaskController { /** POST /api/tasks/{id}/reclaim — 重领被驳回的任务 */ @Operation(summary = "重领被驳回的任务") @PostMapping("/{id}/reclaim") - @RequiresRoles("ANNOTATOR") + @RequireRole("ANNOTATOR") public Result reclaim(@PathVariable Long id, HttpServletRequest request) { taskClaimService.reclaim(id, principal(request)); return Result.success(null); @@ -124,7 +124,7 @@ public class TaskController { /** PUT /api/tasks/{id}/reassign — ADMIN 强制指派 */ @Operation(summary = "管理员强制指派任务") @PutMapping("/{id}/reassign") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result reassign(@PathVariable Long id, @RequestBody Map body, HttpServletRequest request) { diff --git a/src/main/java/com/label/controller/UserController.java b/src/main/java/com/label/controller/UserController.java index 4b61a2e..8e450e9 100644 --- a/src/main/java/com/label/controller/UserController.java +++ b/src/main/java/com/label/controller/UserController.java @@ -2,7 +2,6 @@ package com.label.controller; import java.util.Map; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; @@ -12,9 +11,10 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.PageResult; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.entity.SysUser; import com.label.service.UserService; @@ -37,7 +37,7 @@ public class UserController { /** GET /api/users — 分页查询用户列表 */ @Operation(summary = "分页查询用户列表") @GetMapping - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result> listUsers( @RequestParam(defaultValue = "1") int page, @RequestParam(defaultValue = "20") int pageSize, @@ -48,7 +48,7 @@ public class UserController { /** POST /api/users — 创建用户 */ @Operation(summary = "创建用户") @PostMapping - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result createUser(@RequestBody Map body, HttpServletRequest request) { return Result.success(userService.createUser( @@ -62,7 +62,7 @@ public class UserController { /** PUT /api/users/{id} — 更新用户基本信息 */ @Operation(summary = "更新用户基本信息") @PutMapping("/{id}") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result updateUser(@PathVariable Long id, @RequestBody Map body, HttpServletRequest request) { @@ -76,7 +76,7 @@ public class UserController { /** PUT /api/users/{id}/status — 变更用户状态 */ @Operation(summary = "变更用户状态", description = "status:ACTIVE、DISABLED") @PutMapping("/{id}/status") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result updateStatus(@PathVariable Long id, @RequestBody Map body, HttpServletRequest request) { @@ -87,7 +87,7 @@ public class UserController { /** PUT /api/users/{id}/role — 变更用户角色 */ @Operation(summary = "变更用户角色", description = "role:ADMIN、UPLOADER、VIEWER") @PutMapping("/{id}/role") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result updateRole(@PathVariable Long id, @RequestBody Map body, HttpServletRequest request) { diff --git a/src/main/java/com/label/controller/VideoController.java b/src/main/java/com/label/controller/VideoController.java index 41dc239..1749f51 100644 --- a/src/main/java/com/label/controller/VideoController.java +++ b/src/main/java/com/label/controller/VideoController.java @@ -1,7 +1,8 @@ package com.label.controller; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; import com.label.common.result.Result; -import com.label.common.shiro.TokenPrincipal; import com.label.entity.VideoProcessJob; import com.label.service.VideoProcessService; import io.swagger.v3.oas.annotations.Operation; @@ -9,7 +10,6 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.apache.shiro.authz.annotation.RequiresRoles; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.*; @@ -21,7 +21,7 @@ import java.util.Map; * POST /api/video/process — 触发视频处理(ADMIN) * GET /api/video/jobs/{jobId} — 查询任务状态(ADMIN) * POST /api/video/jobs/{jobId}/reset — 重置失败任务(ADMIN) - * POST /api/video/callback — AI 回调接口(无需认证,已在 TokenFilter 中排除) + * POST /api/video/callback — AI 回调接口(无需认证,已在 AuthInterceptor 中排除) */ @Tag(name = "视频处理", description = "视频处理任务创建、查询、重置和回调") @Slf4j @@ -37,7 +37,7 @@ public class VideoController { /** POST /api/video/process — 触发视频处理任务 */ @Operation(summary = "触发视频处理任务") @PostMapping("/api/video/process") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result createJob(@RequestBody Map body, HttpServletRequest request) { Object sourceIdVal = body.get("sourceId"); @@ -57,7 +57,7 @@ public class VideoController { /** GET /api/video/jobs/{jobId} — 查询视频处理任务 */ @Operation(summary = "查询视频处理任务状态") @GetMapping("/api/video/jobs/{jobId}") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result getJob(@PathVariable Long jobId, HttpServletRequest request) { return Result.success(videoProcessService.getJob(jobId, principal(request).getCompanyId())); @@ -66,7 +66,7 @@ public class VideoController { /** POST /api/video/jobs/{jobId}/reset — 管理员重置失败任务 */ @Operation(summary = "重置失败的视频处理任务") @PostMapping("/api/video/jobs/{jobId}/reset") - @RequiresRoles("ADMIN") + @RequireRole("ADMIN") public Result resetJob(@PathVariable Long jobId, HttpServletRequest request) { return Result.success(videoProcessService.reset(jobId, principal(request).getCompanyId())); @@ -75,7 +75,7 @@ public class VideoController { /** * POST /api/video/callback — AI 服务回调(无需 Bearer Token)。 * - * 此端点已在 TokenFilter.shouldNotFilter() 中排除认证, + * 此端点已在 AuthInterceptor 中排除认证, * 由 AI 服务直接调用,携带 jobId、status、outputPath 等参数。 * * Body 示例: diff --git a/src/main/java/com/label/interceptor/AuthInterceptor.java b/src/main/java/com/label/interceptor/AuthInterceptor.java new file mode 100644 index 0000000..fa18903 --- /dev/null +++ b/src/main/java/com/label/interceptor/AuthInterceptor.java @@ -0,0 +1,182 @@ +package com.label.interceptor; + +import java.io.IOException; +import java.util.Map; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.annotation.AnnotatedElementUtils; +import org.springframework.http.MediaType; +import org.springframework.stereotype.Component; +import org.springframework.web.method.HandlerMethod; +import org.springframework.web.servlet.HandlerInterceptor; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; +import com.label.common.context.CompanyContext; +import com.label.common.context.UserContext; +import com.label.common.result.Result; +import com.label.service.RedisService; +import com.label.util.RedisUtil; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@Component +@RequiredArgsConstructor +public class AuthInterceptor implements HandlerInterceptor { + + private final RedisService redisService; + private final ObjectMapper objectMapper; + + @Value("${auth.enabled:true}") + private boolean authEnabled; + + @Value("${auth.mock-company-id:1}") + private Long mockCompanyId; + + @Value("${auth.mock-user-id:1}") + private Long mockUserId; + + @Value("${auth.mock-role:ADMIN}") + private String mockRole; + + @Value("${auth.mock-username:mock}") + private String mockUsername; + + @Value("${token.ttl-seconds:7200}") + private long tokenTtlSeconds; + + @Override + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) + throws Exception { + String path = requestPath(request); + if (isPublicPath(path)) { + return true; + } + + TokenPrincipal principal = authEnabled + ? resolvePrincipal(request, response) + : new TokenPrincipal(mockUserId, mockRole, mockCompanyId, mockUsername, "mock-token"); + if (principal == null) { + return false; + } + + bindPrincipal(request, principal); + + RequireRole requiredRole = requiredRole(handler); + if (requiredRole != null && !hasRole(principal.getRole(), requiredRole.value())) { + writeFailure(response, HttpServletResponse.SC_FORBIDDEN, "FORBIDDEN", "权限不足"); + return false; + } + + return true; + } + + @Override + public void afterCompletion(HttpServletRequest request, HttpServletResponse response, + Object handler, Exception ex) { + UserContext.clear(); + CompanyContext.clear(); + } + + private TokenPrincipal resolvePrincipal(HttpServletRequest request, HttpServletResponse response) + throws IOException { + String authHeader = request.getHeader("Authorization"); + if (authHeader == null || !authHeader.toLowerCase().startsWith("bearer ")) { + writeFailure(response, HttpServletResponse.SC_UNAUTHORIZED, + "UNAUTHORIZED", "缺少或无效的认证令牌"); + return null; + } + + String[] parts = authHeader.split("\\s+"); + if (parts.length != 2 || !"Bearer".equalsIgnoreCase(parts[0])) { + writeFailure(response, HttpServletResponse.SC_UNAUTHORIZED, + "UNAUTHORIZED", "无效的认证格式"); + return null; + } + + String token = parts[1]; + Map tokenData = redisService.hGetAll(RedisUtil.tokenKey(token)); + if (tokenData == null || tokenData.isEmpty()) { + writeFailure(response, HttpServletResponse.SC_UNAUTHORIZED, + "UNAUTHORIZED", "令牌已过期或不存在"); + return null; + } + + try { + Long userId = Long.parseLong(tokenData.get("userId").toString()); + String role = tokenData.get("role").toString(); + Long companyId = Long.parseLong(tokenData.get("companyId").toString()); + String username = tokenData.get("username").toString(); + redisService.expire(RedisUtil.tokenKey(token), tokenTtlSeconds); + redisService.expire(RedisUtil.userSessionsKey(userId), tokenTtlSeconds); + return new TokenPrincipal(userId, role, companyId, username, token); + } catch (Exception e) { + log.warn("解析 Token 数据失败: {}", e.getMessage()); + writeFailure(response, HttpServletResponse.SC_UNAUTHORIZED, + "UNAUTHORIZED", "令牌数据格式错误"); + return null; + } + } + + private void bindPrincipal(HttpServletRequest request, TokenPrincipal principal) { + CompanyContext.set(principal.getCompanyId()); + UserContext.set(principal); + request.setAttribute("__token_principal__", principal); + } + + private RequireRole requiredRole(Object handler) { + if (!(handler instanceof HandlerMethod handlerMethod)) { + return null; + } + + RequireRole methodRole = AnnotatedElementUtils.findMergedAnnotation( + handlerMethod.getMethod(), RequireRole.class); + if (methodRole != null) { + return methodRole; + } + return AnnotatedElementUtils.findMergedAnnotation( + handlerMethod.getBeanType(), RequireRole.class); + } + + private boolean hasRole(String actualRole, String requiredRole) { + return roleLevel(actualRole) >= roleLevel(requiredRole); + } + + private int roleLevel(String role) { + return switch (role) { + case "ADMIN" -> 4; + case "REVIEWER" -> 3; + case "ANNOTATOR" -> 2; + case "UPLOADER" -> 1; + default -> 0; + }; + } + + private boolean isPublicPath(String path) { + return !path.startsWith("/api/") + || path.equals("/api/auth/login") + || path.equals("/api/video/callback") + || path.startsWith("/swagger-ui") + || path.startsWith("/v3/api-docs"); + } + + private String requestPath(HttpServletRequest request) { + String path = request.getServletPath(); + if (path == null || path.isBlank()) { + path = request.getRequestURI(); + } + return path != null ? path : ""; + } + + private void writeFailure(HttpServletResponse response, int status, String code, String message) + throws IOException { + response.setStatus(status); + response.setContentType(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + response.getWriter().write(objectMapper.writeValueAsString(Result.failure(code, message))); + } +} diff --git a/src/main/java/com/label/listener/ExtractionApprovedEventListener.java b/src/main/java/com/label/listener/ExtractionApprovedEventListener.java index 3cac541..89964f6 100644 --- a/src/main/java/com/label/listener/ExtractionApprovedEventListener.java +++ b/src/main/java/com/label/listener/ExtractionApprovedEventListener.java @@ -1,18 +1,9 @@ package com.label.listener; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.label.common.ai.AiServiceClient; -import com.label.common.context.CompanyContext; -import com.label.entity.TrainingDataset; -import com.label.mapper.AnnotationResultMapper; -import com.label.mapper.TrainingDatasetMapper; -import com.label.entity.SourceData; -import com.label.mapper.SourceDataMapper; -import com.label.service.TaskClaimService; -import com.label.service.TaskService; -import com.label.event.ExtractionApprovedEvent; -import lombok.RequiredArgsConstructor; -import lombok.extern.slf4j.Slf4j; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import org.springframework.transaction.annotation.Propagation; @@ -20,9 +11,18 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.event.TransactionPhase; import org.springframework.transaction.event.TransactionalEventListener; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.common.ai.AiServiceClient; +import com.label.common.context.CompanyContext; +import com.label.entity.SourceData; +import com.label.entity.TrainingDataset; +import com.label.event.ExtractionApprovedEvent; +import com.label.mapper.SourceDataMapper; +import com.label.mapper.TrainingDatasetMapper; +import com.label.service.TaskService; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; /** * 提取审批通过后的异步处理器。 @@ -89,7 +89,8 @@ public class ExtractionApprovedEventListener { ? aiServiceClient.genImageQa(req) : aiServiceClient.genTextQa(req); qaPairs = response != null && response.getQaPairs() != null - ? response.getQaPairs() : Collections.emptyList(); + ? response.getQaPairs() + : Collections.emptyList(); } catch (Exception e) { log.warn("AI 问答生成失败(taskId={}):{},将使用空问答对", event.getTaskId(), e.getMessage()); qaPairs = Collections.emptyList(); diff --git a/src/main/java/com/label/mapper/SysUserMapper.java b/src/main/java/com/label/mapper/SysUserMapper.java index 79033ff..be06306 100644 --- a/src/main/java/com/label/mapper/SysUserMapper.java +++ b/src/main/java/com/label/mapper/SysUserMapper.java @@ -31,4 +31,8 @@ public interface SysUserMapper extends BaseMapper { @Select("SELECT * FROM sys_user WHERE company_id = #{companyId} AND username = #{username} AND status = 'ACTIVE'") SysUser selectByCompanyAndUsername(@Param("companyId") Long companyId, @Param("username") String username); + + @InterceptorIgnore(tenantLine = "true") + @Select("SELECT COUNT(1) FROM sys_user WHERE company_id = #{companyId}") + Long countByCompanyId(@Param("companyId") Long companyId); } diff --git a/src/main/java/com/label/service/AuthService.java b/src/main/java/com/label/service/AuthService.java index 98921a7..b2d427a 100644 --- a/src/main/java/com/label/service/AuthService.java +++ b/src/main/java/com/label/service/AuthService.java @@ -1,7 +1,7 @@ package com.label.service; import com.label.common.exception.BusinessException; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.dto.LoginRequest; import com.label.dto.LoginResponse; import com.label.dto.UserInfoResponse; @@ -117,7 +117,7 @@ public class AuthService { /** * 获取当前登录用户详情(含 realName、companyName)。 * - * @param principal TokenFilter 注入的当前用户主体 + * @param principal AuthInterceptor 注入的当前用户主体 * @return 用户信息响应体 */ public UserInfoResponse me(TokenPrincipal principal) { diff --git a/src/main/java/com/label/service/CompanyService.java b/src/main/java/com/label/service/CompanyService.java new file mode 100644 index 0000000..ed5c7c4 --- /dev/null +++ b/src/main/java/com/label/service/CompanyService.java @@ -0,0 +1,122 @@ +package com.label.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.label.common.exception.BusinessException; +import com.label.common.result.PageResult; +import com.label.entity.SysCompany; +import com.label.mapper.SysCompanyMapper; +import com.label.mapper.SysUserMapper; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Slf4j +@Service +@RequiredArgsConstructor +public class CompanyService { + + private final SysCompanyMapper companyMapper; + private final SysUserMapper userMapper; + + public PageResult list(int page, int pageSize, String status) { + pageSize = Math.min(pageSize, 100); + LambdaQueryWrapper wrapper = new LambdaQueryWrapper() + .orderByDesc(SysCompany::getCreatedAt); + if (status != null && !status.isBlank()) { + wrapper.eq(SysCompany::getStatus, status); + } + + Page result = companyMapper.selectPage(new Page<>(page, pageSize), wrapper); + return PageResult.of(result.getRecords(), result.getTotal(), page, pageSize); + } + + @Transactional + public SysCompany create(String companyName, String companyCode) { + String normalizedName = requireText(companyName, "公司名称不能为空"); + String normalizedCode = normalizeCode(companyCode); + + ensureUniqueCode(null, normalizedCode); + ensureUniqueName(null, normalizedName); + + SysCompany company = new SysCompany(); + company.setCompanyName(normalizedName); + company.setCompanyCode(normalizedCode); + company.setStatus("ACTIVE"); + companyMapper.insert(company); + log.info("公司已创建: id={}, code={}", company.getId(), normalizedCode); + return company; + } + + @Transactional + public SysCompany update(Long companyId, String companyName, String companyCode) { + SysCompany company = getExistingCompany(companyId); + String normalizedName = requireText(companyName, "公司名称不能为空"); + String normalizedCode = normalizeCode(companyCode); + + ensureUniqueCode(companyId, normalizedCode); + ensureUniqueName(companyId, normalizedName); + + company.setCompanyName(normalizedName); + company.setCompanyCode(normalizedCode); + companyMapper.updateById(company); + return company; + } + + @Transactional + public void updateStatus(Long companyId, String status) { + SysCompany company = getExistingCompany(companyId); + if (!"ACTIVE".equals(status) && !"DISABLED".equals(status)) { + throw new BusinessException("INVALID_COMPANY_STATUS", "公司状态不合法", HttpStatus.BAD_REQUEST); + } + company.setStatus(status); + companyMapper.updateById(company); + } + + @Transactional + public void delete(Long companyId) { + getExistingCompany(companyId); + Long userCount = userMapper.countByCompanyId(companyId); + if (userCount != null && userCount > 0) { + throw new BusinessException("COMPANY_HAS_USERS", "公司下仍存在用户,无法删除", HttpStatus.CONFLICT); + } + companyMapper.deleteById(companyId); + } + + private SysCompany getExistingCompany(Long companyId) { + SysCompany company = companyMapper.selectById(companyId); + if (company == null) { + throw new BusinessException("NOT_FOUND", "公司不存在: " + companyId, HttpStatus.NOT_FOUND); + } + return company; + } + + private void ensureUniqueCode(Long companyId, String companyCode) { + SysCompany existing = companyMapper.selectByCompanyCode(companyCode); + if (existing != null && !existing.getId().equals(companyId)) { + throw new BusinessException("DUPLICATE_COMPANY_CODE", "公司代码已存在", HttpStatus.CONFLICT); + } + } + + private void ensureUniqueName(Long companyId, String companyName) { + SysCompany existing = companyMapper.selectOne(new LambdaQueryWrapper() + .eq(SysCompany::getCompanyName, companyName) + .last("LIMIT 1")); + if (existing != null && !existing.getId().equals(companyId)) { + throw new BusinessException("DUPLICATE_COMPANY_NAME", "公司名称已存在", HttpStatus.CONFLICT); + } + } + + private String requireText(String text, String message) { + if (text == null || text.isBlank()) { + throw new BusinessException("INVALID_COMPANY_FIELD", message, HttpStatus.BAD_REQUEST); + } + return text.trim(); + } + + private String normalizeCode(String companyCode) { + return requireText(companyCode, "公司代码不能为空").toUpperCase(); + } +} diff --git a/src/main/java/com/label/service/ExportService.java b/src/main/java/com/label/service/ExportService.java index 6641afa..da5232c 100644 --- a/src/main/java/com/label/service/ExportService.java +++ b/src/main/java/com/label/service/ExportService.java @@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.label.common.exception.BusinessException; import com.label.common.result.PageResult; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.common.storage.RustFsClient; import com.label.entity.TrainingDataset; import com.label.mapper.TrainingDatasetMapper; diff --git a/src/main/java/com/label/service/ExtractionService.java b/src/main/java/com/label/service/ExtractionService.java index e6d11e1..0959770 100644 --- a/src/main/java/com/label/service/ExtractionService.java +++ b/src/main/java/com/label/service/ExtractionService.java @@ -4,7 +4,7 @@ import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.fasterxml.jackson.databind.ObjectMapper; import com.label.common.ai.AiServiceClient; import com.label.common.exception.BusinessException; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.common.statemachine.StateValidator; import com.label.common.statemachine.TaskStatus; import com.label.entity.AnnotationResult; diff --git a/src/main/java/com/label/service/FinetuneService.java b/src/main/java/com/label/service/FinetuneService.java index 8f687df..d8d55e6 100644 --- a/src/main/java/com/label/service/FinetuneService.java +++ b/src/main/java/com/label/service/FinetuneService.java @@ -2,7 +2,7 @@ package com.label.service; import com.label.common.ai.AiServiceClient; import com.label.common.exception.BusinessException; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.entity.ExportBatch; import com.label.mapper.ExportBatchMapper; import lombok.RequiredArgsConstructor; diff --git a/src/main/java/com/label/service/QaService.java b/src/main/java/com/label/service/QaService.java index 976d07f..3df42de 100644 --- a/src/main/java/com/label/service/QaService.java +++ b/src/main/java/com/label/service/QaService.java @@ -3,7 +3,7 @@ package com.label.service; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.fasterxml.jackson.databind.ObjectMapper; import com.label.common.exception.BusinessException; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.common.statemachine.StateValidator; import com.label.common.statemachine.TaskStatus; import com.label.entity.TrainingDataset; diff --git a/src/main/java/com/label/service/SourceService.java b/src/main/java/com/label/service/SourceService.java index 6f567c0..7c80d79 100644 --- a/src/main/java/com/label/service/SourceService.java +++ b/src/main/java/com/label/service/SourceService.java @@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.label.common.exception.BusinessException; import com.label.common.result.PageResult; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.common.storage.RustFsClient; import com.label.dto.SourceResponse; import com.label.entity.SourceData; diff --git a/src/main/java/com/label/service/TaskClaimService.java b/src/main/java/com/label/service/TaskClaimService.java index 2db50c3..f38de3d 100644 --- a/src/main/java/com/label/service/TaskClaimService.java +++ b/src/main/java/com/label/service/TaskClaimService.java @@ -2,7 +2,7 @@ package com.label.service; import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.label.common.exception.BusinessException; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.common.statemachine.StateValidator; import com.label.common.statemachine.TaskStatus; import com.label.entity.AnnotationTask; diff --git a/src/main/java/com/label/service/TaskService.java b/src/main/java/com/label/service/TaskService.java index 529da41..3771985 100644 --- a/src/main/java/com/label/service/TaskService.java +++ b/src/main/java/com/label/service/TaskService.java @@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.label.common.exception.BusinessException; import com.label.common.result.PageResult; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.dto.TaskResponse; import com.label.entity.AnnotationTask; import com.label.mapper.AnnotationTaskMapper; diff --git a/src/main/java/com/label/service/UserService.java b/src/main/java/com/label/service/UserService.java index a3f28ab..b292b64 100644 --- a/src/main/java/com/label/service/UserService.java +++ b/src/main/java/com/label/service/UserService.java @@ -12,7 +12,7 @@ import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.label.common.exception.BusinessException; import com.label.common.result.PageResult; -import com.label.common.shiro.TokenPrincipal; +import com.label.common.auth.TokenPrincipal; import com.label.entity.SysUser; import com.label.mapper.SysUserMapper; import com.label.util.RedisUtil; diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index be46b7f..eaae6e4 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -5,9 +5,9 @@ spring: application: name: label-backend datasource: - url: ${SPRING_DATASOURCE_URL:jdbc:postgresql://39.107.112.174:5432/labeldb} + url: ${SPRING_DATASOURCE_URL:jdbc:postgresql://localhost:5432/labeldb} username: ${SPRING_DATASOURCE_USERNAME:postgres} - password: ${SPRING_DATASOURCE_PASSWORD:postgres!Pw} + password: ${SPRING_DATASOURCE_PASSWORD:} driver-class-name: org.postgresql.Driver hikari: maximum-pool-size: 20 @@ -16,9 +16,9 @@ spring: data: redis: - host: ${SPRING_DATA_REDIS_HOST:39.107.112.174} + host: ${SPRING_DATA_REDIS_HOST:localhost} port: ${SPRING_DATA_REDIS_PORT:6379} - password: ${SPRING_DATA_REDIS_PASSWORD:jsti@2024} + password: ${SPRING_DATA_REDIS_PASSWORD:} timeout: 5000ms lettuce: pool: @@ -33,7 +33,7 @@ spring: mvc: pathmatch: - matching-strategy: ant_path_matcher # Shiro 与 Spring Boot 3 兼容性需要 + matching-strategy: ant_path_matcher springdoc: api-docs: @@ -45,7 +45,7 @@ springdoc: mybatis-plus: mapper-locations: classpath*:mapper/**/*.xml - type-aliases-package: com.label.module + type-aliases-package: com.label.entity configuration: map-underscore-to-camel-case: true log-impl: org.apache.ibatis.logging.slf4j.Slf4jImpl @@ -54,31 +54,29 @@ mybatis-plus: id-type: auto rustfs: - endpoint: ${RUSTFS_ENDPOINT:http://39.107.112.174:9000} + endpoint: ${RUSTFS_ENDPOINT:http://localhost:9000} access-key: ${RUSTFS_ACCESS_KEY:admin} - secret-key: ${RUSTFS_SECRET_KEY:your_strong_password} + secret-key: ${RUSTFS_SECRET_KEY:local-secret-key} region: us-east-1 ai-service: base-url: ${AI_SERVICE_BASE_URL:http://localhost:8000} - timeout: 30000 # milliseconds + timeout: 30000 -shiro: - auth: - enabled: false - mock-company-id: 1 - mock-user-id: 1 - mock-role: ADMIN - mock-username: mock +auth: + enabled: true + mock-company-id: 1 + mock-user-id: 1 + mock-role: ADMIN + mock-username: mock token: - ttl-seconds: 7200 # Token 默认有效期(秒),与 sys_config token_ttl_seconds 保持一致 + ttl-seconds: 7200 video: - callback-secret: ${VIDEO_CALLBACK_SECRET:} # AI 服务回调共享密钥,为空时跳过校验(开发环境) + callback-secret: ${VIDEO_CALLBACK_SECRET:} logging: level: com.label: INFO - org.apache.shiro: INFO com.baomidou.mybatisplus: INFO diff --git a/src/test/java/com/label/integration/ShiroFilterIntegrationTest.java b/src/test/java/com/label/integration/ShiroFilterIntegrationTest.java deleted file mode 100644 index ef85234..0000000 --- a/src/test/java/com/label/integration/ShiroFilterIntegrationTest.java +++ /dev/null @@ -1,160 +0,0 @@ -package com.label.integration; - -import com.label.AbstractIntegrationTest; -import com.label.common.result.Result; -import com.label.service.RedisService; -import com.label.util.RedisUtil; - -import org.apache.shiro.SecurityUtils; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.TestConfiguration; -import org.springframework.boot.test.web.client.TestRestTemplate; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Import; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RestController; - -import java.util.HashMap; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Shiro 过滤器集成测试: - * - 无 Token → 401 Unauthorized - * - Token 不存在(已过期或伪造)→ 401 Unauthorized - * - 有效 Token 但角色不足(ANNOTATOR 访问 REVIEWER 端点)→ 403 Forbidden - * - 有效 Token 且角色满足(REVIEWER 访问 REVIEWER 端点)→ 200 OK - */ -@Import(ShiroFilterIntegrationTest.TestConfig.class) -public class ShiroFilterIntegrationTest extends AbstractIntegrationTest { - - /** 仅供测试的临时 Token,测试结束后清理 */ - private static final String REVIEWER_TOKEN = "test-reviewer-token-uuid-fixed"; - private static final String ANNOTATOR_TOKEN = "test-annotator-token-uuid-fixed"; - - @Autowired - private TestRestTemplate restTemplate; - - @Autowired - private RedisService redisService; - - // ------------------------------------------------------------------ 测试 Controller -- - - /** - * 测试专用配置:注册仅在测试环境存在的端点 - */ - @TestConfiguration - static class TestConfig { - @Bean - public ReviewerOnlyController reviewerOnlyController() { - return new ReviewerOnlyController(); - } - } - - /** - * 需要 REVIEWER 角色的测试端点。 - * 调用 subject.checkRole() —— 角色不足时抛出 AuthorizationException → 403。 - */ - @RestController - static class ReviewerOnlyController { - @GetMapping("/api/test/reviewer-only") - public Result reviewerOnly() { - // 验证当前 Subject 是否持有 REVIEWER 角色 - SecurityUtils.getSubject().checkRole("REVIEWER"); - return Result.success("ok"); - } - } - - // ------------------------------------------------------------------ 测试前后置 -- - - @BeforeEach - void setupTokens() { - // REVIEWER Token:companyId=1, userId=2 - Map reviewerData = new HashMap<>(); - reviewerData.put("userId", "2"); - reviewerData.put("role", "REVIEWER"); - reviewerData.put("companyId", "1"); - reviewerData.put("username", "reviewer01"); - redisService.hSetAll(RedisUtil.tokenKey(REVIEWER_TOKEN), reviewerData, 3600L); - - // ANNOTATOR Token:companyId=1, userId=3 - Map annotatorData = new HashMap<>(); - annotatorData.put("userId", "3"); - annotatorData.put("role", "ANNOTATOR"); - annotatorData.put("companyId", "1"); - annotatorData.put("username", "annotator01"); - redisService.hSetAll(RedisUtil.tokenKey(ANNOTATOR_TOKEN), annotatorData, 3600L); - } - - @AfterEach - void cleanupTokens() { - redisService.delete(RedisUtil.tokenKey(REVIEWER_TOKEN)); - redisService.delete(RedisUtil.tokenKey(ANNOTATOR_TOKEN)); - } - - // ------------------------------------------------------------------ 测试用例 -- - - @Test - @DisplayName("无 Authorization 头 → 401 Unauthorized") - void noToken_returns401() { - ResponseEntity response = restTemplate.getForEntity( - baseUrl("/api/test/reviewer-only"), String.class); - - assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); - } - - @Test - @DisplayName("Token 不存在于 Redis → 401 Unauthorized") - void expiredToken_returns401() { - ResponseEntity response = restTemplate.exchange( - baseUrl("/api/test/reviewer-only"), - HttpMethod.GET, - bearerRequest("non-existent-token-xyz"), - String.class); - - assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); - } - - @Test - @DisplayName("有效 Token 但角色不足(ANNOTATOR 访问 REVIEWER 端点)→ 403 Forbidden") - void annotatorToken_onReviewerEndpoint_returns403() { - ResponseEntity response = restTemplate.exchange( - baseUrl("/api/test/reviewer-only"), - HttpMethod.GET, - bearerRequest(ANNOTATOR_TOKEN), - String.class); - - assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); - } - - @Test - @DisplayName("有效 Token 且角色满足(REVIEWER 访问 REVIEWER 端点)→ 200 OK") - void reviewerToken_onReviewerEndpoint_returns200() { - ResponseEntity response = restTemplate.exchange( - baseUrl("/api/test/reviewer-only"), - HttpMethod.GET, - bearerRequest(REVIEWER_TOKEN), - String.class); - - assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); - } - - // ------------------------------------------------------------------ 工具方法 -- - - /** 构造带 Bearer Token 的请求实体 */ - private HttpEntity bearerRequest(String token) { - HttpHeaders headers = new HttpHeaders(); - headers.set("Authorization", "Bearer " + token); - return new HttpEntity<>(headers); - } -} diff --git a/src/test/java/com/label/unit/ApplicationConfigTest.java b/src/test/java/com/label/unit/ApplicationConfigTest.java index fa64482..2ca303e 100644 --- a/src/test/java/com/label/unit/ApplicationConfigTest.java +++ b/src/test/java/com/label/unit/ApplicationConfigTest.java @@ -14,8 +14,8 @@ import static org.assertj.core.api.Assertions.assertThat; class ApplicationConfigTest { @Test - @DisplayName("application.yml 提供 Swagger 和 shiro.auth 测试开关配置") - void applicationYaml_containsSwaggerAndShiroAuthToggle() throws Exception { + @DisplayName("application.yml 提供 Swagger 和 auth 测试开关配置") + void applicationYaml_containsSwaggerAndAuthToggle() throws Exception { PropertySource source = new YamlPropertySourceLoader() .load("application", new ClassPathResource("application.yml")) .get(0); @@ -24,10 +24,10 @@ class ApplicationConfigTest { assertThat(source.getProperty("springdoc.api-docs.path")).isEqualTo("/v3/api-docs"); assertThat(source.getProperty("springdoc.swagger-ui.enabled")).isEqualTo(true); assertThat(source.getProperty("springdoc.swagger-ui.path")).isEqualTo("/swagger-ui.html"); - assertThat(source.getProperty("shiro.auth.enabled")).isEqualTo(true); - assertThat(source.getProperty("shiro.auth.mock-company-id")).isEqualTo(1); - assertThat(source.getProperty("shiro.auth.mock-user-id")).isEqualTo(1); - assertThat(source.getProperty("shiro.auth.mock-role")).isEqualTo("ADMIN"); + assertThat(source.getProperty("auth.enabled")).isEqualTo(true); + assertThat(source.getProperty("auth.mock-company-id")).isEqualTo(1); + assertThat(source.getProperty("auth.mock-user-id")).isEqualTo(1); + assertThat(source.getProperty("auth.mock-role")).isEqualTo("ADMIN"); assertThat(source.getProperty("logging.level.com.label")).isEqualTo("INFO"); } diff --git a/src/test/java/com/label/unit/AuthInterceptorTest.java b/src/test/java/com/label/unit/AuthInterceptorTest.java new file mode 100644 index 0000000..37b7df5 --- /dev/null +++ b/src/test/java/com/label/unit/AuthInterceptorTest.java @@ -0,0 +1,151 @@ +package com.label.unit; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.label.annotation.RequireAuth; +import com.label.annotation.RequireRole; +import com.label.common.auth.TokenPrincipal; +import com.label.common.context.CompanyContext; +import com.label.common.context.UserContext; +import com.label.interceptor.AuthInterceptor; +import com.label.service.RedisService; +import com.label.util.RedisUtil; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.web.method.HandlerMethod; + +import java.lang.reflect.Method; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@DisplayName("自定义认证鉴权拦截器测试") +class AuthInterceptorTest { + + private final RedisService redisService = mock(RedisService.class); + private final AuthInterceptor interceptor = new AuthInterceptor(redisService, new ObjectMapper()); + + @AfterEach + void tearDown() { + CompanyContext.clear(); + UserContext.clear(); + } + + @Test + @DisplayName("有效 Token 会注入 Principal、租户上下文并刷新 TTL") + void validTokenInjectsPrincipalAndRefreshesTtl() throws Exception { + ReflectionTestUtils.setField(interceptor, "authEnabled", true); + ReflectionTestUtils.setField(interceptor, "tokenTtlSeconds", 7200L); + when(redisService.hGetAll(RedisUtil.tokenKey("valid-token"))).thenReturn(Map.of( + "userId", "10", + "role", "ADMIN", + "companyId", "20", + "username", "admin" + )); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/test/admin"); + request.addHeader("Authorization", "Bearer valid-token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + boolean proceed = interceptor.preHandle(request, response, handler("adminOnly")); + + assertThat(proceed).isTrue(); + TokenPrincipal principal = (TokenPrincipal) request.getAttribute("__token_principal__"); + assertThat(principal.getUserId()).isEqualTo(10L); + assertThat(principal.getRole()).isEqualTo("ADMIN"); + assertThat(CompanyContext.get()).isEqualTo(20L); + assertThat(UserContext.get()).isSameAs(principal); + verify(redisService).expire(RedisUtil.tokenKey("valid-token"), 7200L); + verify(redisService).expire(RedisUtil.userSessionsKey(10L), 7200L); + } + + @Test + @DisplayName("角色继承规则允许 ADMIN 访问 REVIEWER 接口") + void adminRoleInheritsReviewerRole() throws Exception { + ReflectionTestUtils.setField(interceptor, "authEnabled", true); + when(redisService.hGetAll(RedisUtil.tokenKey("admin-token"))).thenReturn(Map.of( + "userId", "1", + "role", "ADMIN", + "companyId", "1", + "username", "admin" + )); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/test/reviewer"); + request.addHeader("Authorization", "Bearer admin-token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + assertThat(interceptor.preHandle(request, response, handler("reviewerOnly"))).isTrue(); + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + } + + @Test + @DisplayName("角色不足时返回 403") + void insufficientRoleReturnsForbidden() throws Exception { + ReflectionTestUtils.setField(interceptor, "authEnabled", true); + when(redisService.hGetAll(RedisUtil.tokenKey("annotator-token"))).thenReturn(Map.of( + "userId", "2", + "role", "ANNOTATOR", + "companyId", "1", + "username", "annotator" + )); + + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/test/reviewer"); + request.addHeader("Authorization", "Bearer annotator-token"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + assertThat(interceptor.preHandle(request, response, handler("reviewerOnly"))).isFalse(); + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); + } + + @Test + @DisplayName("缺少 Token 时返回 401") + void missingTokenReturnsUnauthorized() throws Exception { + ReflectionTestUtils.setField(interceptor, "authEnabled", true); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/test/admin"); + MockHttpServletResponse response = new MockHttpServletResponse(); + + assertThat(interceptor.preHandle(request, response, handler("adminOnly"))).isFalse(); + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_UNAUTHORIZED); + verify(redisService, never()).hGetAll(org.mockito.ArgumentMatchers.anyString()); + } + + @Test + @DisplayName("请求完成后清理用户和公司 ThreadLocal") + void afterCompletionClearsContexts() throws Exception { + CompanyContext.set(20L); + UserContext.set(new TokenPrincipal(10L, "ADMIN", 20L, "admin", "token")); + + interceptor.afterCompletion(new MockHttpServletRequest(), new MockHttpServletResponse(), + handler("adminOnly"), null); + + assertThat(CompanyContext.get()).isEqualTo(-1L); + assertThat(UserContext.get()).isNull(); + } + + private static HandlerMethod handler(String methodName) throws NoSuchMethodException { + Method method = TestController.class.getDeclaredMethod(methodName); + return new HandlerMethod(new TestController(), method); + } + + private static class TestController { + @RequireRole("ADMIN") + void adminOnly() { + } + + @RequireRole("REVIEWER") + void reviewerOnly() { + } + + @RequireAuth + void authenticatedOnly() { + } + } +} diff --git a/src/test/java/com/label/unit/CompanyServiceTest.java b/src/test/java/com/label/unit/CompanyServiceTest.java new file mode 100644 index 0000000..6dd8dd6 --- /dev/null +++ b/src/test/java/com/label/unit/CompanyServiceTest.java @@ -0,0 +1,73 @@ +package com.label.unit; + +import com.label.common.exception.BusinessException; +import com.label.entity.SysCompany; +import com.label.mapper.SysCompanyMapper; +import com.label.mapper.SysUserMapper; +import com.label.service.CompanyService; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@DisplayName("公司管理服务测试") +class CompanyServiceTest { + + private final SysCompanyMapper companyMapper = mock(SysCompanyMapper.class); + private final SysUserMapper userMapper = mock(SysUserMapper.class); + private final CompanyService companyService = new CompanyService(companyMapper, userMapper); + + @Test + @DisplayName("创建公司时写入 ACTIVE 状态并保存公司代码") + void createCompanyInsertsActiveCompany() { + SysCompany company = companyService.create("测试公司", "TEST"); + + assertThat(company.getCompanyName()).isEqualTo("测试公司"); + assertThat(company.getCompanyCode()).isEqualTo("TEST"); + assertThat(company.getStatus()).isEqualTo("ACTIVE"); + verify(companyMapper).insert(any(SysCompany.class)); + } + + @Test + @DisplayName("创建公司时拒绝重复公司代码") + void createCompanyRejectsDuplicateCode() { + SysCompany existing = new SysCompany(); + existing.setId(1L); + when(companyMapper.selectByCompanyCode("DEMO")).thenReturn(existing); + + assertThatThrownBy(() -> companyService.create("演示公司", "DEMO")) + .isInstanceOf(BusinessException.class) + .hasMessageContaining("公司代码已存在"); + } + + @Test + @DisplayName("禁用公司时只允许 ACTIVE 或 DISABLED") + void updateStatusRejectsInvalidStatus() { + SysCompany existing = new SysCompany(); + existing.setId(1L); + existing.setStatus("ACTIVE"); + when(companyMapper.selectById(1L)).thenReturn(existing); + + assertThatThrownBy(() -> companyService.updateStatus(1L, "DELETED")) + .isInstanceOf(BusinessException.class) + .hasMessageContaining("公司状态不合法"); + } + + @Test + @DisplayName("删除公司时若仍有关联用户则拒绝删除") + void deleteRejectsCompanyWithUsers() { + SysCompany existing = new SysCompany(); + existing.setId(1L); + when(companyMapper.selectById(1L)).thenReturn(existing); + when(userMapper.countByCompanyId(1L)).thenReturn(2L); + + assertThatThrownBy(() -> companyService.delete(1L)) + .isInstanceOf(BusinessException.class) + .hasMessageContaining("公司下仍存在用户"); + } +} diff --git a/src/test/java/com/label/unit/OpenApiAnnotationTest.java b/src/test/java/com/label/unit/OpenApiAnnotationTest.java index 35c7f99..230c67a 100644 --- a/src/test/java/com/label/unit/OpenApiAnnotationTest.java +++ b/src/test/java/com/label/unit/OpenApiAnnotationTest.java @@ -1,6 +1,7 @@ package com.label.unit; import com.label.controller.AuthController; +import com.label.controller.CompanyController; import com.label.controller.ExportController; import com.label.controller.ExtractionController; import com.label.controller.QaController; @@ -37,6 +38,7 @@ class OpenApiAnnotationTest { private static final List> CONTROLLERS = List.of( AuthController.class, + CompanyController.class, UserController.class, SourceController.class, TaskController.class, diff --git a/src/test/java/com/label/unit/ShiroConfigTest.java b/src/test/java/com/label/unit/ShiroConfigTest.java deleted file mode 100644 index dc82ca4..0000000 --- a/src/test/java/com/label/unit/ShiroConfigTest.java +++ /dev/null @@ -1,40 +0,0 @@ -package com.label.unit; - -import com.label.common.shiro.UserRealm; -import com.label.config.ShiroConfig; -import com.label.service.RedisService; - -import org.apache.shiro.SecurityUtils; -import org.apache.shiro.mgt.SecurityManager; -import org.apache.shiro.util.ThreadContext; -import org.apache.shiro.web.mgt.DefaultWebSecurityManager; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.Mockito.mock; - -@DisplayName("ShiroConfig 单元测试") -class ShiroConfigTest { - - @AfterEach - void tearDown() { - org.apache.shiro.util.ThreadContext.remove(); - } - - @Test - @DisplayName("securityManager 不应依赖 DefaultWebSecurityManager,以避免 javax.servlet 兼容性问题") - void securityManager_shouldNotDependOnDefaultWebSecurityManager() { - ShiroConfig config = new ShiroConfig(); - RedisService redisService = mock(RedisService.class); - UserRealm realm = config.userRealm(redisService); - - SecurityManager securityManager = config.securityManager(realm); - - assertThat(securityManager).isNotInstanceOf(DefaultWebSecurityManager.class); - ThreadContext.bind(securityManager); - assertThatCode(SecurityUtils::getSubject).doesNotThrowAnyException(); - } -} diff --git a/src/test/java/com/label/unit/TokenFilterTest.java b/src/test/java/com/label/unit/TokenFilterTest.java deleted file mode 100644 index 2098f64..0000000 --- a/src/test/java/com/label/unit/TokenFilterTest.java +++ /dev/null @@ -1,127 +0,0 @@ -package com.label.unit; - -import com.fasterxml.jackson.databind.ObjectMapper; -import com.label.common.context.CompanyContext; -import com.label.common.shiro.TokenFilter; -import com.label.common.shiro.TokenPrincipal; -import com.label.common.shiro.UserRealm; -import com.label.config.ShiroConfig; -import com.label.service.RedisService; -import com.label.util.RedisUtil; - -import org.apache.shiro.SecurityUtils; -import org.apache.shiro.mgt.DefaultSecurityManager; -import org.apache.shiro.util.ThreadContext; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.test.util.ReflectionTestUtils; - -import jakarta.servlet.FilterChain; -import jakarta.servlet.ServletRequest; -import jakarta.servlet.ServletResponse; -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.*; - -@DisplayName("TokenFilter 单元测试") -class TokenFilterTest { - - private RedisService redisService; - private TestableTokenFilter filter; - private DefaultSecurityManager securityManager; - - @BeforeEach - void setUp() { - redisService = mock(RedisService.class); - UserRealm userRealm = new UserRealm(redisService); - securityManager = (DefaultSecurityManager) new ShiroConfig().securityManager(userRealm); - filter = new TestableTokenFilter(redisService, new ObjectMapper(), securityManager); - SecurityUtils.setSecurityManager(securityManager); - } - - @AfterEach - void tearDown() { - CompanyContext.clear(); - ThreadContext.remove(); - } - - @Test - @DisplayName("有效 Token 请求会刷新 token TTL,实现滑动过期") - void validToken_refreshesTokenTtl() throws Exception { - ReflectionTestUtils.setField(filter, "authEnabled", true); - ReflectionTestUtils.setField(filter, "tokenTtlSeconds", 7200L); - String token = "valid-token"; - when(redisService.hGetAll(RedisUtil.tokenKey(token))).thenReturn(Map.of( - "userId", "10", - "role", "ADMIN", - "companyId", "20", - "username", "admin" - )); - - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/tasks"); - request.addHeader("Authorization", "Bearer " + token); - MockHttpServletResponse response = new MockHttpServletResponse(); - RecordingChain chain = new RecordingChain(); - - filter.invoke(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(200); - assertThat(chain.principal).isInstanceOf(TokenPrincipal.class); - assertThat(chain.roleChecked).isTrue(); - verify(redisService).expire(RedisUtil.tokenKey(token), 7200L); - } - - @Test - @DisplayName("shiro.auth.enabled=false 时注入 mock Principal 并跳过 Redis 校验") - void authDisabled_injectsMockPrincipalWithoutRedisLookup() throws Exception { - ReflectionTestUtils.setField(filter, "authEnabled", false); - ReflectionTestUtils.setField(filter, "mockCompanyId", 3L); - ReflectionTestUtils.setField(filter, "mockUserId", 4L); - ReflectionTestUtils.setField(filter, "mockRole", "ADMIN"); - ReflectionTestUtils.setField(filter, "mockUsername", "mock-admin"); - - MockHttpServletRequest request = new MockHttpServletRequest("GET", "/api/tasks"); - MockHttpServletResponse response = new MockHttpServletResponse(); - RecordingChain chain = new RecordingChain(); - - filter.invoke(request, response, chain); - - assertThat(response.getStatus()).isEqualTo(200); - TokenPrincipal principal = chain.principal; - assertThat(principal.getCompanyId()).isEqualTo(3L); - assertThat(principal.getUserId()).isEqualTo(4L); - assertThat(principal.getRole()).isEqualTo("ADMIN"); - assertThat(principal.getUsername()).isEqualTo("mock-admin"); - assertThat(chain.roleChecked).isTrue(); - verify(redisService, never()).hGetAll(anyString()); - } - - private static final class RecordingChain implements FilterChain { - private TokenPrincipal principal; - private boolean roleChecked; - - @Override - public void doFilter(ServletRequest request, ServletResponse response) { - principal = (TokenPrincipal) request.getAttribute("__token_principal__"); - SecurityUtils.getSubject().checkRole(principal.getRole()); - roleChecked = true; - } - } - - private static final class TestableTokenFilter extends TokenFilter { - private TestableTokenFilter(RedisService redisService, ObjectMapper objectMapper, - DefaultSecurityManager securityManager) { - super(redisService, objectMapper); - } - - private void invoke(MockHttpServletRequest request, MockHttpServletResponse response, FilterChain chain) - throws Exception { - super.doFilterInternal(request, response, chain); - } - } -}