修改shiro 兼容性问题

This commit is contained in:
wh
2026-04-13 19:58:49 +08:00
parent 5d74578aa3
commit e8235eeec5
6 changed files with 94 additions and 61 deletions

View File

@@ -2,9 +2,10 @@ package com.label.common.shiro;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.label.common.redis.RedisService; import com.label.common.redis.RedisService;
import org.apache.shiro.SecurityUtils; import org.apache.shiro.mgt.DefaultSessionStorageEvaluator;
import org.apache.shiro.mgt.DefaultSecurityManager;
import org.apache.shiro.mgt.DefaultSubjectDAO;
import org.apache.shiro.mgt.SecurityManager; import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
@@ -12,18 +13,7 @@ import org.springframework.context.annotation.Configuration;
import java.util.List; import java.util.List;
/** /**
* Shiro 安全配置。 * Shiro security configuration for the Jakarta servlet stack.
*
* 设计说明:
* - 使用 Spring 的 FilterRegistrationBean 注册 TokenFilterjakarta.servlet
* 替代 Shiro 的 ShiroFilterFactoryBeanjavax.servlet避免 Shiro 1.x 与
* Spring Boot 3.x 之间的 javax/jakarta 命名空间冲突。
* - URL 路由逻辑内聚于 TokenFilter.shouldNotFilter()
* /api/auth/login → 跳过(公开)
* 非 /api/ 路径 → 跳过(公开)
* /api/** → 强制校验 Bearer Token
* - SecurityUtils.setSecurityManager() 必须在此处调用,
* 以便 @RequiresRoles 等 AOP 注解和 SecurityUtils.getSubject() 可正常工作。
*/ */
@Configuration @Configuration
public class ShiroConfig { public class ShiroConfig {
@@ -35,22 +25,28 @@ public class ShiroConfig {
@Bean @Bean
public SecurityManager securityManager(UserRealm userRealm) { public SecurityManager securityManager(UserRealm userRealm) {
DefaultWebSecurityManager manager = new DefaultWebSecurityManager(); // Keep Shiro on the core stack. Shiro 1.x web classes depend on javax.servlet.
DefaultSecurityManager manager = new DefaultSecurityManager();
manager.setRealms(List.of(userRealm)); manager.setRealms(List.of(userRealm));
// 设置全局 SecurityManager使 SecurityUtils.getSubject() 及 AOP 注解可用 manager.setSubjectDAO(statelessSubjectDao());
SecurityUtils.setSecurityManager(manager);
return manager; return manager;
} }
@Bean private DefaultSubjectDAO statelessSubjectDao() {
public TokenFilter tokenFilter(RedisService redisService, ObjectMapper objectMapper) { DefaultSessionStorageEvaluator evaluator = new DefaultSessionStorageEvaluator();
return new TokenFilter(redisService, objectMapper); evaluator.setSessionStorageEnabled(false);
DefaultSubjectDAO subjectDAO = new DefaultSubjectDAO();
subjectDAO.setSessionStorageEvaluator(evaluator);
return subjectDAO;
}
@Bean
public TokenFilter tokenFilter(RedisService redisService, ObjectMapper objectMapper,
SecurityManager securityManager) {
return new TokenFilter(redisService, objectMapper, securityManager);
} }
/**
* 将 TokenFilter 注册为 Servlet 过滤器,覆盖所有路径。
* 实际的路径过滤逻辑由 TokenFilter.shouldNotFilter() 控制。
*/
@Bean @Bean
public FilterRegistrationBean<TokenFilter> tokenFilterRegistration(TokenFilter tokenFilter) { public FilterRegistrationBean<TokenFilter> tokenFilterRegistration(TokenFilter tokenFilter) {
FilterRegistrationBean<TokenFilter> registration = new FilterRegistrationBean<>(); FilterRegistrationBean<TokenFilter> registration = new FilterRegistrationBean<>();

View File

@@ -11,7 +11,9 @@ import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.shiro.SecurityUtils; import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.subject.SimplePrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext; import org.apache.shiro.util.ThreadContext;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
@@ -38,6 +40,7 @@ public class TokenFilter extends OncePerRequestFilter {
private final RedisService redisService; private final RedisService redisService;
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
private final SecurityManager securityManager;
@Value("${shiro.auth.enabled:true}") @Value("${shiro.auth.enabled:true}")
private boolean authEnabled; private boolean authEnabled;
@@ -78,7 +81,7 @@ public class TokenFilter extends OncePerRequestFilter {
TokenPrincipal principal = new TokenPrincipal( TokenPrincipal principal = new TokenPrincipal(
mockUserId, mockRole, mockCompanyId, mockUsername, "mock-token"); mockUserId, mockRole, mockCompanyId, mockUsername, "mock-token");
CompanyContext.set(mockCompanyId); CompanyContext.set(mockCompanyId);
SecurityUtils.getSubject().login(new BearerToken("mock-token", principal)); bindSubject(principal);
request.setAttribute("__token_principal__", principal); request.setAttribute("__token_principal__", principal);
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
return; return;
@@ -113,7 +116,7 @@ public class TokenFilter extends OncePerRequestFilter {
// 创建 TokenPrincipal 并登录 Shiro Subject使 @RequiresRoles 等注解生效 // 创建 TokenPrincipal 并登录 Shiro Subject使 @RequiresRoles 等注解生效
TokenPrincipal principal = new TokenPrincipal(userId, role, companyId, username, token); TokenPrincipal principal = new TokenPrincipal(userId, role, companyId, username, token);
SecurityUtils.getSubject().login(new BearerToken(token, principal)); bindSubject(principal);
request.setAttribute("__token_principal__", principal); request.setAttribute("__token_principal__", principal);
redisService.expire(RedisKeyManager.tokenKey(token), tokenTtlSeconds); redisService.expire(RedisKeyManager.tokenKey(token), tokenTtlSeconds);
redisService.expire(RedisKeyManager.userSessionsKey(userId), tokenTtlSeconds); redisService.expire(RedisKeyManager.userSessionsKey(userId), tokenTtlSeconds);
@@ -126,9 +129,21 @@ public class TokenFilter extends OncePerRequestFilter {
// 关键:必须清除 ThreadLocal防止线程池复用时数据串漏 // 关键:必须清除 ThreadLocal防止线程池复用时数据串漏
CompanyContext.clear(); CompanyContext.clear();
ThreadContext.unbindSubject(); ThreadContext.unbindSubject();
ThreadContext.unbindSecurityManager();
} }
} }
private void bindSubject(TokenPrincipal principal) {
SimplePrincipalCollection principals = new SimplePrincipalCollection(principal, UserRealm.class.getName());
Subject subject = new Subject.Builder(securityManager)
.principals(principals)
.authenticated(true)
.sessionCreationEnabled(false)
.buildSubject();
ThreadContext.bind(securityManager);
ThreadContext.bind(subject);
}
private void writeUnauthorized(HttpServletResponse resp, String message) throws IOException { private void writeUnauthorized(HttpServletResponse resp, String message) throws IOException {
resp.setStatus(HttpServletResponse.SC_UNAUTHORIZED); resp.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
resp.setContentType(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); resp.setContentType(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8");

View File

@@ -33,7 +33,7 @@ public class SourceController {
* 上传文件multipart/form-data * 上传文件multipart/form-data
* 返回 201 Created + 资料摘要。 * 返回 201 Created + 资料摘要。
*/ */
@Operation(summary = "上传原始资料") @Operation(summary = "上传原始资料", description = "dataType: text,image, video")
@PostMapping("/upload") @PostMapping("/upload")
@RequiresRoles("UPLOADER") @RequiresRoles("UPLOADER")
@ResponseStatus(HttpStatus.CREATED) @ResponseStatus(HttpStatus.CREATED)

View File

@@ -65,7 +65,7 @@ ai-service:
shiro: shiro:
auth: auth:
enabled: true enabled: false
mock-company-id: 1 mock-company-id: 1
mock-user-id: 1 mock-user-id: 1
mock-role: ADMIN mock-role: ADMIN

View File

@@ -0,0 +1,39 @@
package com.label.unit;
import com.label.common.redis.RedisService;
import com.label.common.shiro.ShiroConfig;
import com.label.common.shiro.UserRealm;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.util.ThreadContext;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.Mockito.mock;
@DisplayName("ShiroConfig 单元测试")
class ShiroConfigTest {
@AfterEach
void tearDown() {
org.apache.shiro.util.ThreadContext.remove();
}
@Test
@DisplayName("securityManager 不应依赖 DefaultWebSecurityManager以避免 javax.servlet 兼容性问题")
void securityManager_shouldNotDependOnDefaultWebSecurityManager() {
ShiroConfig config = new ShiroConfig();
RedisService redisService = mock(RedisService.class);
UserRealm realm = config.userRealm(redisService);
SecurityManager securityManager = config.securityManager(realm);
assertThat(securityManager).isNotInstanceOf(DefaultWebSecurityManager.class);
ThreadContext.bind(securityManager);
assertThatCode(SecurityUtils::getSubject).doesNotThrowAnyException();
}
}

View File

@@ -4,18 +4,12 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.label.common.context.CompanyContext; import com.label.common.context.CompanyContext;
import com.label.common.redis.RedisKeyManager; import com.label.common.redis.RedisKeyManager;
import com.label.common.redis.RedisService; import com.label.common.redis.RedisService;
import com.label.common.shiro.BearerToken; import com.label.common.shiro.ShiroConfig;
import com.label.common.shiro.TokenFilter; import com.label.common.shiro.TokenFilter;
import com.label.common.shiro.TokenPrincipal; import com.label.common.shiro.TokenPrincipal;
import com.label.common.shiro.UserRealm;
import org.apache.shiro.SecurityUtils; import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.AuthenticationInfo;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.authc.SimpleAuthenticationInfo;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.mgt.DefaultSecurityManager; import org.apache.shiro.mgt.DefaultSecurityManager;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.ThreadContext; import org.apache.shiro.util.ThreadContext;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@@ -38,12 +32,15 @@ class TokenFilterTest {
private RedisService redisService; private RedisService redisService;
private TestableTokenFilter filter; private TestableTokenFilter filter;
private DefaultSecurityManager securityManager;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
redisService = mock(RedisService.class); redisService = mock(RedisService.class);
filter = new TestableTokenFilter(redisService, new ObjectMapper()); UserRealm userRealm = new UserRealm(redisService);
SecurityUtils.setSecurityManager(new DefaultSecurityManager(new BearerTokenRealm())); securityManager = (DefaultSecurityManager) new ShiroConfig().securityManager(userRealm);
filter = new TestableTokenFilter(redisService, new ObjectMapper(), securityManager);
SecurityUtils.setSecurityManager(securityManager);
} }
@AfterEach @AfterEach
@@ -74,6 +71,7 @@ class TokenFilterTest {
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
assertThat(chain.principal).isInstanceOf(TokenPrincipal.class); assertThat(chain.principal).isInstanceOf(TokenPrincipal.class);
assertThat(chain.roleChecked).isTrue();
verify(redisService).expire(RedisKeyManager.tokenKey(token), 7200L); verify(redisService).expire(RedisKeyManager.tokenKey(token), 7200L);
} }
@@ -98,41 +96,26 @@ class TokenFilterTest {
assertThat(principal.getUserId()).isEqualTo(4L); assertThat(principal.getUserId()).isEqualTo(4L);
assertThat(principal.getRole()).isEqualTo("ADMIN"); assertThat(principal.getRole()).isEqualTo("ADMIN");
assertThat(principal.getUsername()).isEqualTo("mock-admin"); assertThat(principal.getUsername()).isEqualTo("mock-admin");
assertThat(chain.roleChecked).isTrue();
verify(redisService, never()).hGetAll(anyString()); verify(redisService, never()).hGetAll(anyString());
} }
private static final class BearerTokenRealm extends AuthorizingRealm {
@Override
public boolean supports(AuthenticationToken token) {
return token instanceof BearerToken;
}
@Override
protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) {
return new SimpleAuthenticationInfo(token.getPrincipal(), token.getCredentials(), getName());
}
@Override
protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
TokenPrincipal principal = (TokenPrincipal) principals.getPrimaryPrincipal();
SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
info.addRole(principal.getRole());
return info;
}
}
private static final class RecordingChain implements FilterChain { private static final class RecordingChain implements FilterChain {
private TokenPrincipal principal; private TokenPrincipal principal;
private boolean roleChecked;
@Override @Override
public void doFilter(ServletRequest request, ServletResponse response) { public void doFilter(ServletRequest request, ServletResponse response) {
principal = (TokenPrincipal) request.getAttribute("__token_principal__"); principal = (TokenPrincipal) request.getAttribute("__token_principal__");
SecurityUtils.getSubject().checkRole(principal.getRole());
roleChecked = true;
} }
} }
private static final class TestableTokenFilter extends TokenFilter { private static final class TestableTokenFilter extends TokenFilter {
private TestableTokenFilter(RedisService redisService, ObjectMapper objectMapper) { private TestableTokenFilter(RedisService redisService, ObjectMapper objectMapper,
super(redisService, objectMapper); DefaultSecurityManager securityManager) {
super(redisService, objectMapper, securityManager);
} }
private void invoke(MockHttpServletRequest request, MockHttpServletResponse response, FilterChain chain) private void invoke(MockHttpServletRequest request, MockHttpServletResponse response, FilterChain chain)