From 0cd99aa22c4d3807aa01199494f84b7dd6338f5f Mon Sep 17 00:00:00 2001 From: wh Date: Thu, 9 Apr 2026 13:54:35 +0800 Subject: [PATCH] On branch 001-label-backend-spec Changes to be committed: new file: src/main/java/com/label/common/shiro/BearerToken.java new file: src/main/java/com/label/common/shiro/ShiroConfig.java new file: src/main/java/com/label/common/shiro/TokenFilter.java new file: src/main/java/com/label/common/shiro/TokenPrincipal.java new file: src/main/java/com/label/common/shiro/UserRealm.java modified: src/main/java/com/label/common/statemachine/DatasetStatus.java new file: src/test/java/com/label/AbstractIntegrationTest.java new file: src/test/java/com/label/unit/StateMachineTest.java new file: src/test/resources/db/init.sql --- .../com/label/common/shiro/BearerToken.java | 26 ++ .../com/label/common/shiro/ShiroConfig.java | 71 ++++ .../com/label/common/shiro/TokenFilter.java | 95 +++++ .../label/common/shiro/TokenPrincipal.java | 18 + .../com/label/common/shiro/UserRealm.java | 87 +++++ .../common/statemachine/DatasetStatus.java | 5 +- .../com/label/AbstractIntegrationTest.java | 87 +++++ .../java/com/label/unit/StateMachineTest.java | 265 ++++++++++++++ src/test/resources/db/init.sql | 332 ++++++++++++++++++ 9 files changed, 984 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/label/common/shiro/BearerToken.java create mode 100644 src/main/java/com/label/common/shiro/ShiroConfig.java create mode 100644 src/main/java/com/label/common/shiro/TokenFilter.java create mode 100644 src/main/java/com/label/common/shiro/TokenPrincipal.java create mode 100644 src/main/java/com/label/common/shiro/UserRealm.java create mode 100644 src/test/java/com/label/AbstractIntegrationTest.java create mode 100644 src/test/java/com/label/unit/StateMachineTest.java create mode 100644 src/test/resources/db/init.sql diff --git a/src/main/java/com/label/common/shiro/BearerToken.java b/src/main/java/com/label/common/shiro/BearerToken.java new file mode 100644 
index 0000000..5febfc9
--- /dev/null
+++ b/src/main/java/com/label/common/shiro/BearerToken.java
@@ -0,0 +1,35 @@
+package com.label.common.shiro;
+
+import org.apache.shiro.authc.AuthenticationToken;
+
+/**
+ * Shiro {@link AuthenticationToken} that carries a raw Bearer token string together with
+ * the {@link TokenPrincipal} resolved for it.
+ */
+public class BearerToken implements AuthenticationToken {
+    private static final long serialVersionUID = 1L;
+
+    private final String token;
+    private final TokenPrincipal principal;
+
+    /**
+     * @param token     raw token string taken from the Authorization header (the credential)
+     * @param principal identity data resolved for that token
+     */
+    public BearerToken(String token, TokenPrincipal principal) {
+        this.token = token;
+        this.principal = principal;
+    }
+
+    /** Returns the {@link TokenPrincipal} supplied at construction. */
+    @Override
+    public Object getPrincipal() {
+        return principal;
+    }
+
+    /** Returns the raw token string supplied at construction. */
+    @Override
+    public Object getCredentials() {
+        return token;
+    }
+}
diff --git a/src/main/java/com/label/common/shiro/ShiroConfig.java b/src/main/java/com/label/common/shiro/ShiroConfig.java
new file mode 100644
index 0000000..b199f5d
--- /dev/null
+++ b/src/main/java/com/label/common/shiro/ShiroConfig.java
@@ -0,0 +1,102 @@
+package com.label.common.shiro;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.label.common.redis.RedisService;
+import org.apache.shiro.mgt.DefaultSessionStorageEvaluator;
+import org.apache.shiro.mgt.DefaultSubjectDAO;
+import org.apache.shiro.mgt.SecurityManager;
+import org.apache.shiro.realm.Realm;
+import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
+import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
+import org.springframework.boot.web.servlet.FilterRegistrationBean;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import jakarta.servlet.Filter;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Shiro security configuration.
+ *
+ * Filter chain (evaluated in declaration order, first match wins):
+ *   /api/auth/login  → anon (no auth required)
+ *   /actuator/**     → anon (health check)
+ *   /api/**          → tokenFilter (all other API endpoints, including /api/auth/logout)
+ *   /**              → anon (default)
+ *
+ * TokenFilter must run only inside this chain, so its automatic registration as a
+ * servlet-container filter is disabled via tokenFilterRegistration.
+ *
+ * NOTE: spring.mvc.pathmatch.matching-strategy=ant_path_matcher MUST be set
+ * in application.yml for Shiro to work correctly with Spring Boot 3.
+ */
+@Configuration
+public class ShiroConfig {
+
+    /** Realm that resolves the roles of the principal carried by a {@link BearerToken}. */
+    @Bean
+    public UserRealm userRealm(RedisService redisService) {
+        return new UserRealm(redisService);
+    }
+
+    /**
+     * Security manager backed by {@link UserRealm}. Subject state is never written to a
+     * session because every request is authenticated from its own Bearer token.
+     */
+    @Bean
+    public SecurityManager securityManager(UserRealm userRealm) {
+        DefaultWebSecurityManager manager = new DefaultWebSecurityManager();
+        manager.setRealms(List.of(userRealm));
+
+        // Stateless token auth: do not persist the authenticated subject in a session
+        DefaultSessionStorageEvaluator sessionStorageEvaluator = new DefaultSessionStorageEvaluator();
+        sessionStorageEvaluator.setSessionStorageEnabled(false);
+        DefaultSubjectDAO subjectDAO = new DefaultSubjectDAO();
+        subjectDAO.setSessionStorageEvaluator(sessionStorageEvaluator);
+        manager.setSubjectDAO(subjectDAO);
+        return manager;
+    }
+
+    /** The token filter instance that is referenced as "tokenFilter" in the Shiro chain. */
+    @Bean
+    public TokenFilter tokenFilter(RedisService redisService, ObjectMapper objectMapper) {
+        return new TokenFilter(redisService, objectMapper);
+    }
+
+    /**
+     * Disables Spring Boot's automatic servlet-container registration of the TokenFilter bean.
+     * Without this the filter would also run outside the Shiro chain and, because its applied
+     * paths include /api/**, would reject the anonymous /api/auth/login endpoint.
+     */
+    @Bean
+    public FilterRegistrationBean<TokenFilter> tokenFilterRegistration(TokenFilter tokenFilter) {
+        FilterRegistrationBean<TokenFilter> registration = new FilterRegistrationBean<>(tokenFilter);
+        registration.setEnabled(false);
+        return registration;
+    }
+
+    /** Builds the Shiro filter with the custom tokenFilter and the URL → filter chain map. */
+    @Bean
+    public ShiroFilterFactoryBean shiroFilterFactoryBean(SecurityManager securityManager,
+                                                         TokenFilter tokenFilter) {
+        ShiroFilterFactoryBean factory = new ShiroFilterFactoryBean();
+        factory.setSecurityManager(securityManager);
+
+        // Register custom filters
+        Map<String, Filter> filters = new LinkedHashMap<>();
+        filters.put("tokenFilter", tokenFilter);
+        factory.setFilters(filters);
+
+        // Filter chain definition (ORDER MATTERS - first match wins)
+        Map<String, String> filterChainDef = new LinkedHashMap<>();
+        filterChainDef.put("/api/auth/login", "anon");
+        filterChainDef.put("/actuator/**", "anon");
+        filterChainDef.put("/api/**", "tokenFilter");
+        filterChainDef.put("/**", "anon");
+        factory.setFilterChainDefinitionMap(filterChainDef);
+
+        return factory;
+    }
+}
diff --git a/src/main/java/com/label/common/shiro/TokenFilter.java b/src/main/java/com/label/common/shiro/TokenFilter.java
new file mode 100644
index 0000000..19ce508
--- /dev/null
+++ b/src/main/java/com/label/common/shiro/TokenFilter.java
@@ -0,0 +1,132 @@
+package com.label.common.shiro;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.label.common.context.CompanyContext;
+import com.label.common.redis.RedisKeyManager;
+import com.label.common.redis.RedisService;
+import com.label.common.result.Result;
+import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletRequest;
+import jakarta.servlet.ServletResponse;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.shiro.SecurityUtils;
+import org.apache.shiro.authc.AuthenticationException;
+import org.apache.shiro.web.filter.PathMatchingFilter;
+import org.springframework.http.MediaType;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Shiro filter: parses "Authorization: Bearer {uuid}", validates the token against Redis,
+ * then sets CompanyContext and logs the request's Shiro subject in with a BearerToken.
+ *
+ * KEY DESIGN:
+ * - CompanyContext.clear() MUST be called in finally block to prevent thread pool leakage
+ * - Token lookup is from Redis Hash token:{uuid} → {userId, role, companyId, username}
+ * - 401 on missing/invalid token; filter continues for valid token
+ * - The subject login is what makes role checks reach UserRealm; the principal is also
+ *   exposed as the "__token_principal__" request attribute
+ */
+@Slf4j
+@RequiredArgsConstructor
+public class TokenFilter extends PathMatchingFilter {
+
+    /** Authorization scheme prefix; compared case-insensitively. */
+    private static final String BEARER_PREFIX = "Bearer ";
+
+    private final RedisService redisService;
+    private final ObjectMapper objectMapper;
+
+    /**
+     * Authenticates the request from its Bearer token.
+     *
+     * @return true to continue the chain; false after a 401 response has been written
+     */
+    @Override
+    protected boolean onPreHandle(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
+        HttpServletRequest req = (HttpServletRequest) request;
+        HttpServletResponse resp = (HttpServletResponse) response;
+
+        String authHeader = req.getHeader("Authorization");
+        if (authHeader == null
+                || !authHeader.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
+            writeUnauthorized(resp, "缺少或无效的认证令牌");
+            return false;
+        }
+
+        String token = authHeader.substring(BEARER_PREFIX.length()).trim();
+        if (token.isEmpty()) {
+            // "Bearer " with nothing after it: reject without a Redis round trip
+            writeUnauthorized(resp, "缺少或无效的认证令牌");
+            return false;
+        }
+
+        Map<?, ?> tokenData = redisService.hGetAll(RedisKeyManager.tokenKey(token));
+        if (tokenData == null || tokenData.isEmpty()) {
+            writeUnauthorized(resp, "令牌已过期或不存在");
+            return false;
+        }
+
+        TokenPrincipal principal;
+        try {
+            principal = new TokenPrincipal(
+                    Long.valueOf(requireField(tokenData, "userId")),
+                    requireField(tokenData, "role"),
+                    Long.valueOf(requireField(tokenData, "companyId")),
+                    requireField(tokenData, "username"),
+                    token);
+        } catch (IllegalStateException | NumberFormatException e) {
+            // Pass the exception itself so the stack trace is logged, not just the message
+            log.error("解析 Token 数据失败: {}", e.getMessage(), e);
+            writeUnauthorized(resp, "令牌数据格式错误");
+            return false;
+        }
+
+        try {
+            // Authenticate the request's subject so that role checks are served by UserRealm
+            SecurityUtils.getSubject().login(new BearerToken(token, principal));
+        } catch (AuthenticationException e) {
+            log.warn("Shiro login rejected for userId={}: {}", principal.getUserId(), e.getMessage());
+            writeUnauthorized(resp, "缺少或无效的认证令牌");
+            return false;
+        }
+
+        // Inject company context (cleared in the finally block of doFilterInternal)
+        CompanyContext.set(principal.getCompanyId());
+        request.setAttribute("__token_principal__", principal);
+        return true;
+    }
+
+    /** Runs the filter and always clears CompanyContext afterwards. */
+    @Override
+    public void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain)
+            throws ServletException, IOException {
+        try {
+            super.doFilterInternal(request, response, chain);
+        } finally {
+            // CRITICAL: Always clear ThreadLocal to prevent leakage in thread pool
+            CompanyContext.clear();
+        }
+    }
+
+    /** Returns the named token hash field as a string; fails with the field name if absent. */
+    private static String requireField(Map<?, ?> tokenData, String field) {
+        Object value = tokenData.get(field);
+        if (value == null) {
+            throw new IllegalStateException("missing token field: " + field);
+        }
+        return value.toString();
+    }
+
+    /** Writes a 401 JSON response whose body is Result.failure("UNAUTHORIZED", message). */
+    private void writeUnauthorized(HttpServletResponse resp, String message) throws IOException {
+        resp.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
+        resp.setContentType(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8");
+        resp.getWriter().write(objectMapper.writeValueAsString(Result.failure("UNAUTHORIZED", message)));
+    }
+}
diff --git a/src/main/java/com/label/common/shiro/TokenPrincipal.java b/src/main/java/com/label/common/shiro/TokenPrincipal.java
new file mode 100644
index 0000000..39aa63e
--- /dev/null
+++ b/src/main/java/com/label/common/shiro/TokenPrincipal.java
@@ -0,0 +1,18 @@
+package com.label.common.shiro;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import java.io.Serializable;
+
+/**
+ * Shiro principal carrying the authenticated user's session data.
+ */
+@Getter
+@AllArgsConstructor
+public class TokenPrincipal implements Serializable {
+    private final Long userId;
+    private final String role;
+    private final Long companyId;
+    private final String username;
+    private final String token;
+}
diff --git a/src/main/java/com/label/common/shiro/UserRealm.java b/src/main/java/com/label/common/shiro/UserRealm.java
new file mode 100644
index 0000000..0fb11d9
--- /dev/null
+++ b/src/main/java/com/label/common/shiro/UserRealm.java
@@ -0,0 +1,108 @@
+package com.label.common.shiro;
+
+import com.label.common.redis.RedisKeyManager;
+import com.label.common.redis.RedisService;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.shiro.authc.AuthenticationException;
+import org.apache.shiro.authc.AuthenticationInfo;
+import org.apache.shiro.authc.AuthenticationToken;
+import org.apache.shiro.authc.SimpleAuthenticationInfo;
+import org.apache.shiro.authz.AuthorizationInfo;
+import org.apache.shiro.authz.SimpleAuthorizationInfo;
+import org.apache.shiro.realm.AuthorizingRealm;
+import org.apache.shiro.subject.PrincipalCollection;
+
+import java.util.List;
+
+/**
+ * Shiro Realm for role-based authorization using token-based authentication.
+ *
+ * Role hierarchy (addInheritedRoles):
+ *   ADMIN ⊃ REVIEWER ⊃ ANNOTATOR ⊃ UPLOADER
+ *
+ * Role lookup order:
+ *   1. Redis user:perm:{userId} (TTL 5 min)
+ *   2. If miss: use role from TokenPrincipal and write it back to the cache
+ */
+@Slf4j
+@RequiredArgsConstructor
+public class UserRealm extends AuthorizingRealm {
+
+    private static final long PERM_CACHE_TTL = 300L; // 5 minutes
+
+    /** Roles from most to least privileged; a role implies every role listed after it. */
+    private static final List<String> ROLE_HIERARCHY =
+            List.of("ADMIN", "REVIEWER", "ANNOTATOR", "UPLOADER");
+
+    private final RedisService redisService;
+
+    /** Only {@link BearerToken} instances are handled by this realm. */
+    @Override
+    public boolean supports(AuthenticationToken token) {
+        return token instanceof BearerToken;
+    }
+
+    /**
+     * Accepts a token whose principal is a {@link TokenPrincipal}. The token string itself is
+     * not re-validated here: that is done in TokenFilter before the login is attempted.
+     *
+     * @throws AuthenticationException if the token carries no TokenPrincipal
+     */
+    @Override
+    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
+        if (!(token.getPrincipal() instanceof TokenPrincipal)) {
+            throw new AuthenticationException("Token does not carry a TokenPrincipal");
+        }
+        return new SimpleAuthenticationInfo(token.getPrincipal(), token.getCredentials(), getName());
+    }
+
+    /**
+     * Resolves the principal's role plus every role it inherits. Returns an empty
+     * authorization info when there is no TokenPrincipal or no role.
+     */
+    @Override
+    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
+        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
+        Object primary = principals.getPrimaryPrincipal();
+        if (!(primary instanceof TokenPrincipal)) {
+            return info;
+        }
+
+        String role = getRoleFromCacheOrPrincipal((TokenPrincipal) primary);
+        if (role == null || role.isEmpty()) {
+            return info;
+        }
+        info.addRole(role);
+        addInheritedRoles(info, role);
+        return info;
+    }
+
+    /** Returns the cached role for the user, falling back to (and caching) the token's role. */
+    private String getRoleFromCacheOrPrincipal(TokenPrincipal principal) {
+        String permKey = RedisKeyManager.userPermKey(principal.getUserId());
+        String cachedRole = redisService.get(permKey);
+        if (cachedRole != null && !cachedRole.isEmpty()) {
+            return cachedRole;
+        }
+        // Cache miss: use role from token, then refresh cache (never cache a missing role)
+        String role = principal.getRole();
+        if (role != null && !role.isEmpty()) {
+            redisService.set(permKey, role, PERM_CACHE_TTL);
+        }
+        return role;
+    }
+
+    /**
+     * Adds every role below {@code role} in the hierarchy
+     * ADMIN ⊃ REVIEWER ⊃ ANNOTATOR ⊃ UPLOADER. Unknown roles inherit nothing.
+     */
+    private void addInheritedRoles(SimpleAuthorizationInfo info, String role) {
+        int position = ROLE_HIERARCHY.indexOf(role);
+        if (position < 0) {
+            log.debug("Role '{}' is not part of the role hierarchy; no roles inherited", role);
+            return;
+        }
+        info.addRoles(ROLE_HIERARCHY.subList(position + 1, ROLE_HIERARCHY.size()));
+    }
+}
diff --git a/src/main/java/com/label/common/statemachine/DatasetStatus.java b/src/main/java/com/label/common/statemachine/DatasetStatus.java
index 753d508..e1eca1c 100644
--- a/src/main/java/com/label/common/statemachine/DatasetStatus.java
+++ b/src/main/java/com/label/common/statemachine/DatasetStatus.java
@@ -7,7 +7,8 @@ public enum DatasetStatus {
     PENDING_REVIEW, APPROVED, REJECTED;
 
     public static final Map<DatasetStatus, Set<DatasetStatus>> TRANSITIONS = Map.of(
-        PENDING_REVIEW, Set.of(APPROVED, REJECTED)
-        // APPROVED/REJECTED: terminal states
+        PENDING_REVIEW, Set.of(APPROVED, REJECTED),
+        REJECTED, Set.of(PENDING_REVIEW) // resubmit for review
+        // APPROVED: terminal state
     );
 }
diff --git a/src/test/java/com/label/AbstractIntegrationTest.java b/src/test/java/com/label/AbstractIntegrationTest.java
new file mode 100644
index 0000000..8679173
--- /dev/null
+++ b/src/test/java/com/label/AbstractIntegrationTest.java
@@ -0,0 +1,87 @@
+package com.label;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.web.server.LocalServerPort;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.test.context.DynamicPropertyRegistry;
+import org.springframework.test.context.DynamicPropertySource;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.PostgreSQLContainer;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.utility.DockerImageName;
+import org.testcontainers.utility.MountableFile;
+
+/**
+ * Base class for all integration tests.
+ *
+ * Starts real PostgreSQL 16 and Redis 7 containers (shared across test class instances).
+ * Executes db/init.sql to initialize schema and seed data.
+ *
+ * DESIGN:
+ * - Singleton containers (started once per JVM, never stopped) → the cached Spring context never sees stale ports
+ * - @DynamicPropertySource → overrides datasource/redis properties at runtime
+ * - @BeforeEach cleanData() → truncates business tables (not sys_company/sys_user) between tests
+ */
+@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
+public abstract class AbstractIntegrationTest {
+
+    @LocalServerPort
+    protected int port;
+
+    @Autowired
+    protected JdbcTemplate jdbcTemplate;
+
+    @SuppressWarnings("resource")
+    protected static final PostgreSQLContainer<?> postgres =
+            new PostgreSQLContainer<>(DockerImageName.parse("postgres:16-alpine"))
+                    .withDatabaseName("label_db")
+                    .withUsername("label")
+                    .withPassword("label_password")
+                    .withCopyFileToContainer(
+                            MountableFile.forClasspathResource("db/init.sql"),
+                            "/docker-entrypoint-initdb.d/init.sql");
+
+    @SuppressWarnings("resource")
+    protected static final GenericContainer<?> redis =
+            new GenericContainer<>(DockerImageName.parse("redis:7-alpine"))
+                    .withExposedPorts(6379)
+                    .withCommand("redis-server", "--requirepass", "test_redis_password");
+
+    @DynamicPropertySource
+    static void configureProperties(DynamicPropertyRegistry registry) {
+        // start() returns immediately if the container is already running
+        postgres.start();
+        redis.start();
+        registry.add("spring.datasource.url", postgres::getJdbcUrl);
+        registry.add("spring.datasource.username", postgres::getUsername);
+        registry.add("spring.datasource.password", postgres::getPassword);
+        registry.add("spring.data.redis.host", redis::getHost);
+        registry.add("spring.data.redis.port", () -> redis.getMappedPort(6379));
+        registry.add("spring.data.redis.password", () -> "test_redis_password");
+    }
+
+    /**
+     * Clean only business data between tests to keep schema intact.
+     * Keep sys_company and sys_user since init.sql seeds them.
+     */
+    @BeforeEach
+    void cleanData() {
+        jdbcTemplate.execute("TRUNCATE TABLE video_process_job, annotation_task_history, " +
+                "sys_operation_log, sys_config, export_batch, training_dataset, " +
+                "annotation_result, annotation_task, source_data RESTART IDENTITY CASCADE");
+        // Re-insert global sys_config entries that were truncated
+        jdbcTemplate.execute("INSERT INTO sys_config (company_id, config_key, config_value) VALUES " +
+                "(NULL, 'token_ttl_seconds', '7200'), " +
+                "(NULL, 'model_default', 'glm-4'), " +
+                "(NULL, 'video_frame_interval', '30') " +
+                "ON CONFLICT DO NOTHING");
+    }
+
+    /** Helper: get base URL for REST calls */
+    protected String baseUrl(String path) {
+        return "http://localhost:" + port + path;
+    }
+}
diff --git a/src/test/java/com/label/unit/StateMachineTest.java b/src/test/java/com/label/unit/StateMachineTest.java
new file mode 100644
index 0000000..a563970
--- /dev/null
+++ b/src/test/java/com/label/unit/StateMachineTest.java
@@ -0,0 +1,265 @@
+package com.label.unit;
+
+import com.label.common.exception.BusinessException;
+import com.label.common.statemachine.*;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * Unit tests for all state machine enums and StateValidator.
+ * No Spring context needed - pure unit tests.
+ */
+@DisplayName("状态机单元测试")
+class StateMachineTest {
+
+    // ===== SourceStatus: transitions asserted legal / illegal =====
+    @Nested
+    @DisplayName("SourceStatus 状态机")
+    class SourceStatusTest {
+
+        @Test
+        @DisplayName("合法转换:PENDING → EXTRACTING(文本/图片直接提取)")
+        void pendingToExtracting() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.PENDING, SourceStatus.EXTRACTING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:PENDING → PREPROCESSING(视频上传)")
+        void pendingToPreprocessing() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.PENDING, SourceStatus.PREPROCESSING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:PREPROCESSING → PENDING(视频预处理完成)")
+        void preprocessingToPending() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.PREPROCESSING, SourceStatus.PENDING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:EXTRACTING → QA_REVIEW(提取审批通过)")
+        void extractingToQaReview() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.EXTRACTING, SourceStatus.QA_REVIEW)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:QA_REVIEW → APPROVED(QA 审批通过)")
+        void qaReviewToApproved() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.QA_REVIEW, SourceStatus.APPROVED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("非法转换:APPROVED → PENDING 抛出异常")
+        void approvedToPendingFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.APPROVED, SourceStatus.PENDING)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+
+        @Test
+        @DisplayName("非法转换:PENDING → APPROVED(跳过中间状态)抛出异常")
+        void pendingToApprovedFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(SourceStatus.TRANSITIONS, SourceStatus.PENDING, SourceStatus.APPROVED)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+    }
+
+    // ===== TaskStatus: transitions asserted legal / illegal =====
+    @Nested
+    @DisplayName("TaskStatus 状态机")
+    class TaskStatusTest {
+
+        @Test
+        @DisplayName("合法转换:UNCLAIMED → IN_PROGRESS(领取)")
+        void unclaimedToInProgress() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.UNCLAIMED, TaskStatus.IN_PROGRESS)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:IN_PROGRESS → SUBMITTED(提交)")
+        void inProgressToSubmitted() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.IN_PROGRESS, TaskStatus.SUBMITTED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:IN_PROGRESS → UNCLAIMED(放弃)")
+        void inProgressToUnclaimed() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.IN_PROGRESS, TaskStatus.UNCLAIMED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:IN_PROGRESS → IN_PROGRESS(ADMIN 强制转移,持有人变更)")
+        void inProgressToInProgress() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.IN_PROGRESS, TaskStatus.IN_PROGRESS)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:SUBMITTED → APPROVED(审批通过)")
+        void submittedToApproved() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.SUBMITTED, TaskStatus.APPROVED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:SUBMITTED → REJECTED(审批驳回)")
+        void submittedToRejected() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.SUBMITTED, TaskStatus.REJECTED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:REJECTED → IN_PROGRESS(标注员重领)")
+        void rejectedToInProgress() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.REJECTED, TaskStatus.IN_PROGRESS)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("非法转换:APPROVED → IN_PROGRESS 抛出异常")
+        void approvedToInProgressFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.APPROVED, TaskStatus.IN_PROGRESS)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+
+        @Test
+        @DisplayName("非法转换:UNCLAIMED → SUBMITTED(跳过 IN_PROGRESS)抛出异常")
+        void unclaimedToSubmittedFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(TaskStatus.TRANSITIONS, TaskStatus.UNCLAIMED, TaskStatus.SUBMITTED)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+    }
+
+    // ===== DatasetStatus: transitions asserted legal / illegal =====
+    @Nested
+    @DisplayName("DatasetStatus 状态机")
+    class DatasetStatusTest {
+
+        @Test
+        @DisplayName("合法转换:PENDING_REVIEW → APPROVED")
+        void pendingReviewToApproved() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(DatasetStatus.TRANSITIONS, DatasetStatus.PENDING_REVIEW, DatasetStatus.APPROVED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:PENDING_REVIEW → REJECTED")
+        void pendingReviewToRejected() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(DatasetStatus.TRANSITIONS, DatasetStatus.PENDING_REVIEW, DatasetStatus.REJECTED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:REJECTED → PENDING_REVIEW(重新提交)")
+        void rejectedToPendingReview() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(DatasetStatus.TRANSITIONS, DatasetStatus.REJECTED, DatasetStatus.PENDING_REVIEW)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("非法转换:APPROVED → REJECTED 抛出异常")
+        void approvedToRejectedFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(DatasetStatus.TRANSITIONS, DatasetStatus.APPROVED, DatasetStatus.REJECTED)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+    }
+
+    // ===== VideoJobStatus: transitions asserted legal / illegal =====
+    @Nested
+    @DisplayName("VideoJobStatus 状态机")
+    class VideoJobStatusTest {
+
+        @Test
+        @DisplayName("合法转换:PENDING → RUNNING")
+        void pendingToRunning() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.PENDING, VideoJobStatus.RUNNING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:RUNNING → SUCCESS")
+        void runningToSuccess() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.RUNNING, VideoJobStatus.SUCCESS)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:RUNNING → RETRYING(失败且未超重试次数)")
+        void runningToRetrying() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.RUNNING, VideoJobStatus.RETRYING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:RUNNING → FAILED(失败且超过最大重试)")
+        void runningToFailed() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.RUNNING, VideoJobStatus.FAILED)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("合法转换:RETRYING → RUNNING(AI 重试)")
+        void retryingToRunning() {
+            assertThatCode(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.RETRYING, VideoJobStatus.RUNNING)
+            ).doesNotThrowAnyException();
+        }
+
+        @Test
+        @DisplayName("非法转换:FAILED → PENDING 不在状态机内(ADMIN 手动触发,不走 StateValidator)")
+        void failedToPendingNotInStateMachine() {
+            // FAILED → PENDING is intentionally NOT in TRANSITIONS (ADMIN manual reset via special API), so it must be rejected here
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.FAILED, VideoJobStatus.PENDING)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+
+        @Test
+        @DisplayName("非法转换:SUCCESS → RUNNING 抛出异常")
+        void successToRunningFails() {
+            assertThatThrownBy(() ->
+                StateValidator.assertTransition(VideoJobStatus.TRANSITIONS, VideoJobStatus.SUCCESS, VideoJobStatus.RUNNING)
+            ).isInstanceOf(BusinessException.class)
+             .extracting("code").isEqualTo("INVALID_STATE_TRANSITION");
+        }
+    }
+}
diff --git a/src/test/resources/db/init.sql b/src/test/resources/db/init.sql
new file mode 100644
index 0000000..1824039
--- /dev/null
+++ b/src/test/resources/db/init.sql
@@ -0,0 +1,332 @@
+-- label_backend init.sql
+-- PostgreSQL 14+
+-- Creates all 11 tables in dependency order:
+-- sys_company → sys_user → source_data → annotation_task → annotation_result
+-- → training_dataset → export_batch → sys_config → sys_operation_log
+-- → annotation_task_history → video_process_job
+-- Includes all indexes and the initial configuration data
+
+-- ============================================================
+-- Extensions
+-- ============================================================
+CREATE EXTENSION IF NOT EXISTS pgcrypto;
+
+-- ============================================================
+-- 1. sys_company (tenants)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS sys_company (
+    id BIGSERIAL PRIMARY KEY,
+    company_name VARCHAR(100) NOT NULL,
+    company_code VARCHAR(50) NOT NULL,
+    status VARCHAR(10) NOT NULL DEFAULT 'ACTIVE', -- ACTIVE / DISABLED
+    created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
+    CONSTRAINT uk_sys_company_name UNIQUE (company_name),
+    CONSTRAINT uk_sys_company_code UNIQUE (company_code)
+);
+
+-- ============================================================
+-- 2.
sys_user(用户) +-- ============================================================ +CREATE TABLE IF NOT EXISTS sys_user ( + id BIGSERIAL PRIMARY KEY, + company_id BIGINT NOT NULL REFERENCES sys_company(id), + username VARCHAR(50) NOT NULL, + password_hash VARCHAR(255) NOT NULL, -- BCrypt, strength >= 10 + real_name VARCHAR(50), + role VARCHAR(20) NOT NULL, -- UPLOADER / ANNOTATOR / REVIEWER / ADMIN + status VARCHAR(10) NOT NULL DEFAULT 'ACTIVE', -- ACTIVE / DISABLED + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + CONSTRAINT uk_sys_user_company_username UNIQUE (company_id, username) +); + +CREATE INDEX IF NOT EXISTS idx_sys_user_company_id + ON sys_user (company_id); + +-- ============================================================ +-- 3. source_data(原始资料) +-- ============================================================ +CREATE TABLE IF NOT EXISTS source_data ( + id BIGSERIAL PRIMARY KEY, + company_id BIGINT NOT NULL REFERENCES sys_company(id), + uploader_id BIGINT REFERENCES sys_user(id), + data_type VARCHAR(20) NOT NULL, -- TEXT / IMAGE / VIDEO + file_path VARCHAR(500) NOT NULL, -- RustFS object path + file_name VARCHAR(255) NOT NULL, + file_size BIGINT, + bucket_name VARCHAR(100) NOT NULL, + parent_source_id BIGINT REFERENCES source_data(id), -- 视频帧 / 文本片段 + status VARCHAR(20) NOT NULL DEFAULT 'PENDING', + -- PENDING / PREPROCESSING / EXTRACTING / QA_REVIEW / APPROVED + reject_reason TEXT, -- 保留字段(当前无 REJECTED 状态) + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_source_data_company_id + ON source_data (company_id); +CREATE INDEX IF NOT EXISTS idx_source_data_company_status + ON source_data (company_id, status); +CREATE INDEX IF NOT EXISTS idx_source_data_parent_source_id + ON source_data (parent_source_id); + +-- ============================================================ +-- 4. 
annotation_task(标注任务) +-- ============================================================ +CREATE TABLE IF NOT EXISTS annotation_task ( + id BIGSERIAL PRIMARY KEY, + company_id BIGINT NOT NULL REFERENCES sys_company(id), + source_id BIGINT NOT NULL REFERENCES source_data(id), + task_type VARCHAR(30) NOT NULL, -- EXTRACTION / QA_GENERATION + status VARCHAR(20) NOT NULL DEFAULT 'UNCLAIMED', + -- UNCLAIMED / IN_PROGRESS / SUBMITTED / APPROVED / REJECTED + claimed_by BIGINT REFERENCES sys_user(id), + claimed_at TIMESTAMP, + submitted_at TIMESTAMP, + completed_at TIMESTAMP, + is_final BOOLEAN NOT NULL DEFAULT FALSE, -- true 即 APPROVED 且无需再审 + ai_model VARCHAR(50), + reject_reason TEXT, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_annotation_task_company_status + ON annotation_task (company_id, status); +CREATE INDEX IF NOT EXISTS idx_annotation_task_source_id + ON annotation_task (source_id); +CREATE INDEX IF NOT EXISTS idx_annotation_task_claimed_by + ON annotation_task (claimed_by); + +-- ============================================================ +-- 5. annotation_result(标注结果,JSONB) +-- ============================================================ +CREATE TABLE IF NOT EXISTS annotation_result ( + id BIGSERIAL NOT NULL, + task_id BIGINT NOT NULL REFERENCES annotation_task(id), + company_id BIGINT NOT NULL REFERENCES sys_company(id), + result_json JSONB NOT NULL DEFAULT '[]'::jsonb, -- 整体替换语义 + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW(), + CONSTRAINT pk_annotation_result PRIMARY KEY (id), + CONSTRAINT uk_annotation_result_task_id UNIQUE (task_id) +); + +CREATE INDEX IF NOT EXISTS idx_annotation_result_task_id + ON annotation_result (task_id); +CREATE INDEX IF NOT EXISTS idx_annotation_result_company_id + ON annotation_result (company_id); + +-- ============================================================ +-- 6. 
training_dataset(训练数据集) +-- export_batch_id FK 在 export_batch 建完后补加 +-- ============================================================ +CREATE TABLE IF NOT EXISTS training_dataset ( + id BIGSERIAL PRIMARY KEY, + company_id BIGINT NOT NULL REFERENCES sys_company(id), + task_id BIGINT NOT NULL REFERENCES annotation_task(id), + source_id BIGINT NOT NULL REFERENCES source_data(id), + sample_type VARCHAR(20) NOT NULL, -- TEXT / IMAGE / VIDEO_FRAME + glm_format_json JSONB NOT NULL, -- GLM fine-tune 格式 + status VARCHAR(20) NOT NULL DEFAULT 'PENDING_REVIEW', + -- PENDING_REVIEW / APPROVED / REJECTED + export_batch_id BIGINT, -- 导出后填写,FK 在下方补加 + exported_at TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT NOW(), + updated_at TIMESTAMP NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_training_dataset_company_status + ON training_dataset (company_id, status); +CREATE INDEX IF NOT EXISTS idx_training_dataset_task_id + ON training_dataset (task_id); + +-- ============================================================ +-- 7. 
export_batch (export batches)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS export_batch (
+    id                BIGSERIAL    PRIMARY KEY,
+    company_id        BIGINT       NOT NULL REFERENCES sys_company(id),
+    batch_uuid        UUID         NOT NULL DEFAULT gen_random_uuid(),
+    sample_count      INT          NOT NULL DEFAULT 0,
+    dataset_file_path VARCHAR(500),                            -- RustFS path of the exported JSONL
+    glm_job_id        VARCHAR(100),                            -- GLM fine-tune job ID
+    finetune_status   VARCHAR(20)  NOT NULL DEFAULT 'NOT_STARTED',
+                      -- NOT_STARTED / RUNNING / COMPLETED / FAILED
+    created_at        TIMESTAMP    NOT NULL DEFAULT NOW(),
+    updated_at        TIMESTAMP    NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_export_batch_company_id
+    ON export_batch (company_id);
+
+-- Attach the deferred training_dataset.export_batch_id FK.
+-- ADD CONSTRAINT has no IF NOT EXISTS form, so any earlier copy is dropped first;
+-- this keeps the script re-runnable like the CREATE ... IF NOT EXISTS statements.
+ALTER TABLE training_dataset
+    DROP CONSTRAINT IF EXISTS fk_training_dataset_export_batch;
+ALTER TABLE training_dataset
+    ADD CONSTRAINT fk_training_dataset_export_batch
+    FOREIGN KEY (export_batch_id) REFERENCES export_batch(id)
+    NOT VALID; -- existing rows are not checked retroactively; new and updated rows are still enforced
+
+-- ============================================================
+-- 8. sys_config (system configuration)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS sys_config (
+    id           BIGSERIAL    PRIMARY KEY,
+    company_id   BIGINT       REFERENCES sys_company(id),     -- NULL = global default
+    config_key   VARCHAR(100) NOT NULL,
+    config_value TEXT         NOT NULL,
+    description  VARCHAR(255),
+    created_at   TIMESTAMP    NOT NULL DEFAULT NOW(),
+    updated_at   TIMESTAMP    NOT NULL DEFAULT NOW()
+);
+
+-- Uniqueness of company-level settings: one row per (company_id, config_key)
+CREATE UNIQUE INDEX IF NOT EXISTS uk_sys_config_company_key
+    ON sys_config (company_id, config_key)
+    WHERE company_id IS NOT NULL;
+
+-- Uniqueness of global settings: one row per config_key when company_id IS NULL
+CREATE UNIQUE INDEX IF NOT EXISTS uk_sys_config_global_key
+    ON sys_config (config_key)
+    WHERE company_id IS NULL;
+
+CREATE INDEX IF NOT EXISTS idx_sys_config_company_key
+    ON sys_config (company_id, config_key);
+
+-- ============================================================
+-- 9. 
sys_operation_log (operation log, append-only)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS sys_operation_log (
+    id             BIGSERIAL   PRIMARY KEY,
+    company_id     BIGINT      NOT NULL REFERENCES sys_company(id),
+    operator_id    BIGINT      REFERENCES sys_user(id),
+    -- operation_type examples: EXTRACTION_APPROVE / USER_LOGIN
+    operation_type VARCHAR(50) NOT NULL,
+    target_id      BIGINT,
+    target_type    VARCHAR(50),
+    detail         JSONB,
+    -- result values: SUCCESS / FAILURE
+    result         VARCHAR(10),
+    error_message  TEXT,
+    operated_at    TIMESTAMP   NOT NULL DEFAULT NOW()
+    -- no updated_at column: this table is append-only and rows are never updated
+);
+
+CREATE INDEX IF NOT EXISTS idx_sys_operation_log_company_operated_at ON sys_operation_log (company_id, operated_at);
+CREATE INDEX IF NOT EXISTS idx_sys_operation_log_operator_id ON sys_operation_log (operator_id);
+
+-- ============================================================
+-- 10. annotation_task_history (task status history, append-only)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS annotation_task_history (
+    id            BIGSERIAL   PRIMARY KEY,
+    task_id       BIGINT      NOT NULL REFERENCES annotation_task(id),
+    company_id    BIGINT      NOT NULL REFERENCES sys_company(id),
+    from_status   VARCHAR(20),
+    to_status     VARCHAR(20) NOT NULL,
+    operator_id   BIGINT      REFERENCES sys_user(id),
+    operator_role VARCHAR(20),
+    comment       TEXT,
+    created_at    TIMESTAMP   NOT NULL DEFAULT NOW()
+    -- no updated_at column: this table is append-only and rows are never updated
+);
+
+CREATE INDEX IF NOT EXISTS idx_annotation_task_history_task_id ON annotation_task_history (task_id);
+CREATE INDEX IF NOT EXISTS idx_annotation_task_history_company_id ON annotation_task_history (company_id);
+
+-- ============================================================
+-- 11. 
video_process_job (video processing jobs)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS video_process_job (
+    id            BIGSERIAL    PRIMARY KEY,
+    company_id    BIGINT       NOT NULL REFERENCES sys_company(id),
+    source_id     BIGINT       NOT NULL REFERENCES source_data(id),
+    -- job_type values: FRAME_EXTRACT / VIDEO_TO_TEXT
+    job_type      VARCHAR(30)  NOT NULL,
+    -- status values: PENDING / RUNNING / SUCCESS / FAILED / RETRYING
+    status        VARCHAR(20)  NOT NULL DEFAULT 'PENDING',
+    -- job parameters, e.g. {"frameInterval": 30, "mode": "FRAME"}
+    params        JSONB,
+    -- RustFS output path, set once the job has completed
+    output_path   VARCHAR(500),
+    retry_count   INTEGER      NOT NULL DEFAULT 0,
+    max_retries   INTEGER      NOT NULL DEFAULT 3,
+    error_message TEXT,
+    started_at    TIMESTAMP,
+    completed_at  TIMESTAMP,
+    created_at    TIMESTAMP    NOT NULL DEFAULT NOW(),
+    updated_at    TIMESTAMP    NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_video_process_job_company_id ON video_process_job (company_id);
+CREATE INDEX IF NOT EXISTS idx_video_process_job_source_id ON video_process_job (source_id);
+CREATE INDEX IF NOT EXISTS idx_video_process_job_status ON video_process_job (status);
+
+-- ============================================================
+-- Seed data
+-- ============================================================
+
+-- 1. Demo company
+INSERT INTO sys_company
+    (company_name, company_code, status)
+VALUES
+    ('演示公司', 'DEMO', 'ACTIVE')
+ON CONFLICT DO NOTHING;
+
+-- 2. 
Initial users (BCrypt strength=10)
+-- admin       / admin123
+-- reviewer01  / review123
+-- annotator01 / annot123
+-- uploader01  / upload123
+INSERT INTO sys_user (company_id, username, password_hash, real_name, role, status)
+SELECT co.id, seed.username, seed.password_hash, seed.real_name, seed.role, 'ACTIVE'
+FROM sys_company co
+CROSS JOIN (VALUES
+    ('admin',       '$2a$10$B8iR5z43URiNPm.eut3JvufIPBuvGx5ZZmqyUqE1A1WdbZppX5bmi', '管理员',   'ADMIN'),
+    ('reviewer01',  '$2a$10$euOJZRfUtYNW7WHpfW1Ciee5b3rjkYFe3yQHT/uCQWrYVc0XQcukm', '审核员01', 'REVIEWER'),
+    ('annotator01', '$2a$10$8UKwHPNASauKMTrqosR0Reg1X1gkFzFlGa/HBwNLXUELaj4e/zcqu', '标注员01', 'ANNOTATOR'),
+    ('uploader01',  '$2a$10$o2d7jsT31vyxIJHUo50mUefoZLLvGqft97zaL9OQCjRxn9ie1H/1O', '上传员01', 'UPLOADER')
+) AS seed(username, password_hash, real_name, role)
+WHERE co.company_code = 'DEMO'
+ON CONFLICT (company_id, username) DO NOTHING;
+
+-- 3. Global system configuration (company_id is NULL on every row)
+INSERT INTO sys_config (company_id, config_key, config_value, description)
+VALUES
+    -- session token lifetime
+    (NULL, 'token_ttl_seconds', '7200', '会话凭证有效期(秒)'),
+    -- default model for AI assistance
+    (NULL, 'model_default', 'glm-4', 'AI 辅助默认模型'),
+    -- video frame extraction interval
+    (NULL, 'video_frame_interval', '30', '视频帧提取间隔(帧数)'),
+    -- prompt template: text triple extraction
+    (NULL,
+     'prompt_extract_text',
+     '请提取以下文本中的主语-谓语-宾语三元组,以JSON数组格式返回,每个元素包含subject、predicate、object、sourceText、startOffset、endOffset字段。',
+     '文本三元组提取 Prompt 模板'),
+    -- prompt template: image quadruple extraction
+    (NULL,
+     'prompt_extract_image',
+     '请提取图片中的实体关系四元组,以JSON数组格式返回,每个元素包含subject、relation、object、modifier、confidence字段。',
+     '图片四元组提取 Prompt 模板'),
+    -- prompt template: QA generation from text triples
+    (NULL,
+     'prompt_qa_gen_text',
+     '根据以下文本三元组生成高质量问答对,以JSON数组格式返回,每个元素包含question、answer、difficulty字段。',
+     '文本问答生成 Prompt 模板'),
+    -- prompt template: QA generation from image quadruples
+    (NULL,
+     'prompt_qa_gen_image',
+     '根据以下图片四元组生成高质量问答对,以JSON数组格式返回,每个元素包含question、answer、imageRef、difficulty字段。',
+     '图片问答生成 Prompt 模板')
+ON CONFLICT DO NOTHING;