From 35abaa1ebbf0db6bf126b2b5160b2da85e0744bc Mon Sep 17 00:00:00 2001 From: Jiang Date: Mon, 2 Feb 2026 11:09:43 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=B9=B6=E4=BF=AE=E5=A4=8Dap?= =?UTF-8?q?i=E5=AF=BC=E5=85=A5=E8=B7=AF=E5=BE=84=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/v1/endpoints/network/geometry.py | 2 +- app/api/v1/router.py | 2 +- app/auth/dependencies.py | 34 +++-- app/core/config.py | 1 + app/core/security.py | 49 ++++--- app/infra/db/influxdb/api.py | 4 +- app/main.py | 2 +- scripts/get_data.py | 55 +++---- scripts/get_hist_data.py | 57 ++++---- test_api_integration.py | 152 -------------------- tests/api/test_api_integration.py | 77 ++++++++++ tests/{ => auth}/test_encryption.py | 18 ++- tests/unit/test_pipeline_health_analyzer.py | 8 +- 13 files changed, 211 insertions(+), 250 deletions(-) delete mode 100755 test_api_integration.py create mode 100755 tests/api/test_api_integration.py rename tests/{ => auth}/test_encryption.py (90%) diff --git a/app/api/v1/endpoints/network/geometry.py b/app/api/v1/endpoints/network/geometry.py index 52470d2..3248e92 100644 --- a/app/api/v1/endpoints/network/geometry.py +++ b/app/api/v1/endpoints/network/geometry.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Request, Depends from typing import Any, List, Dict, Union from app.services.tjnetwork import * -from app.api.v1.endpoints.auth import verify_token +from app.auth.dependencies import get_current_user as verify_token from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime import msgpack diff --git a/app/api/v1/router.py b/app/api/v1/router.py index d83a52c..40d80e5 100644 --- a/app/api/v1/router.py +++ b/app/api/v1/router.py @@ -43,7 +43,7 @@ from app.infra.db.timescaledb import router as timescaledb_router api_router = APIRouter() # Core Services -api_router.include_router(auth.router, tags=["Auth"]) +api_router.include_router(auth.router, prefix="/auth", tags=["Auth"]) api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增 api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增 api_router.include_router(project.router, tags=["Project"]) diff --git a/app/auth/dependencies.py b/app/auth/dependencies.py index 299bf3b..8b7bff3 100644 --- a/app/auth/dependencies.py +++ b/app/auth/dependencies.py @@ -9,31 +9,34 @@ from app.infra.db.postgresql.database import Database oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") + # 数据库依赖 async def get_db(request: Request) -> Database: """ 获取数据库实例 - + 从 FastAPI app.state 中获取在启动时初始化的数据库连接 """ if not hasattr(request.app.state, "db"): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database not initialized" + detail="Database not initialized", ) return request.app.state.db + async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository: """获取用户仓储实例""" return UserRepository(db) + async def get_current_user( token: str = Depends(oauth2_scheme), - user_repo: UserRepository = Depends(get_user_repository) + user_repo: UserRepository = Depends(get_user_repository), ) -> UserInDB: """ 获取当前登录用户 - + 从 JWT Token 中解析用户信息,并从数据库验证 """ credentials_exception = HTTPException( @@ -41,32 +44,35 @@ async def get_current_user( detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) - + try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) username: str = payload.get("sub") token_type: str = payload.get("type", "access") - + if username is None: raise credentials_exception - + if token_type != "access": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type. Access token required.", headers={"WWW-Authenticate": "Bearer"}, ) - + except JWTError: raise credentials_exception - + # 从数据库获取用户 user = await user_repo.get_user_by_username(username) if user is None: raise credentials_exception - + return user + async def get_current_active_user( current_user: UserInDB = Depends(get_current_user), ) -> UserInDB: @@ -75,11 +81,11 @@ async def get_current_active_user( """ if not current_user.is_active: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Inactive user" + status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user" ) return current_user + async def get_current_superuser( current_user: UserInDB = Depends(get_current_user), ) -> UserInDB: @@ -89,6 +95,6 @@ async def get_current_superuser( if not current_user.is_superuser: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Not enough privileges. Superuser access required." + detail="Not enough privileges. Superuser access required.", ) return current_user diff --git a/app/core/config.py b/app/core/config.py index 52d4b03..4f0151f 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -32,5 +32,6 @@ class Settings(BaseSettings): class Config: env_file = ".env" + extra = "ignore" settings = Settings() diff --git a/app/core/security.py b/app/core/security.py index f29920a..802e837 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -1,77 +1,90 @@ from datetime import datetime, timedelta from typing import Optional, Union, Any + from jose import jwt from passlib.context import CryptContext from app.core.config import settings pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str: + +def create_access_token( + subject: Union[str, Any], expires_delta: Optional[timedelta] = None +) -> str: """ 创建 JWT Access Token - + Args: subject: 用户标识(通常是用户名或用户ID) expires_delta: 过期时间增量 - + Returns: JWT token 字符串 """ if expires_delta: - expire = datetime.utcnow() + expires_delta + expire = datetime.now() + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - + expire = datetime.now() + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) + to_encode = { "exp": expire, "sub": str(subject), "type": "access", - "iat": datetime.utcnow() + "iat": datetime.now(), } - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + encoded_jwt = jwt.encode( + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM + ) return encoded_jwt + def create_refresh_token(subject: Union[str, Any]) -> str: """ 创建 JWT Refresh Token(长期有效) - + Args: subject: 用户标识 - + Returns: JWT refresh token 字符串 """ - expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) - + expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + to_encode = { "exp": expire, "sub": str(subject), "type": "refresh", - "iat": datetime.utcnow() + "iat": datetime.now(), } - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + encoded_jwt = jwt.encode( + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM + ) return encoded_jwt + def verify_password(plain_password: str, hashed_password: str) -> bool: """ 验证密码 - + Args: plain_password: 明文密码 hashed_password: 密码哈希 - + Returns: 是否匹配 """ return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: """ 生成密码哈希 - + Args: password: 明文密码 - + Returns: bcrypt 哈希字符串 """ diff --git a/app/infra/db/influxdb/api.py b/app/infra/db/influxdb/api.py index e8fc7f7..6926ddc 100644 --- a/app/infra/db/influxdb/api.py +++ b/app/infra/db/influxdb/api.py @@ -13,8 +13,8 @@ from typing import List, Dict from datetime import datetime, timedelta, timezone from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS from dateutil import parser -import get_realValue -import get_data +# import get_realValue +# import get_data import psycopg import time import app.services.simulation as simulation diff --git a/app/main.py b/app/main.py index 1dc07af..2a3e1fc 100644 --- a/app/main.py +++ b/app/main.py @@ -71,7 +71,7 @@ app.add_middleware(GZipMiddleware, minimum_size=1000) # 添加审计中间件(可选,记录关键操作) # 如果需要启用审计日志,取消下面的注释 -# app.add_middleware(AuditMiddleware) +app.add_middleware(AuditMiddleware) # Include Routers app.include_router(api_router, prefix="/api/v1") diff --git a/scripts/get_data.py b/scripts/get_data.py index 090153b..6cfd87d 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -7,18 +7,19 @@ import csv # get_data 是用来获取 历史数据,也就是非实时数据的接口 # get_realtime 是用来获取 实时数据 + def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime: # 将毫秒级时间戳转换为秒级时间戳 timestamp_seconds = timestamp / 1000 # 将时间戳转换为datetime对象 - utc_time = datetime.utcfromtimestamp(timestamp_seconds) + utc_time = datetime.fromtimestamp(timestamp_seconds) # 设定UTC时区 - utc_timezone = pytz.timezone('UTC') + utc_timezone = pytz.timezone("UTC") # 转换为北京时间 - beijing_timezone = pytz.timezone('Asia/Shanghai') + beijing_timezone = pytz.timezone("Asia/Shanghai") beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone) return beijing_time @@ -26,10 +27,10 @@ def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime: def beijing_time_to_utc(beijing_time_str: str) -> str: # 定义北京时区 - beijing_timezone = pytz.timezone('Asia/Shanghai') + beijing_timezone = pytz.timezone("Asia/Shanghai") # 将字符串转换为datetime对象 - beijing_time = datetime.strptime(beijing_time_str, '%Y-%m-%d %H:%M:%S') + beijing_time = datetime.strptime(beijing_time_str, "%Y-%m-%d %H:%M:%S") # 本地化时间对象 beijing_time = beijing_timezone.localize(beijing_time) @@ -38,29 +39,31 @@ def beijing_time_to_utc(beijing_time_str: str) -> str: utc_time = beijing_time.astimezone(pytz.utc) # 转换为ISO 8601格式的字符串 - return utc_time.strftime('%Y-%m-%dT%H:%M:%SZ') + return utc_time.strftime("%Y-%m-%dT%H:%M:%SZ") -def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> List[Dict[str, Union[str, datetime, int, float]]]: -# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None: +def get_history_data( + ids: str, begin_date: str, end_date: str, downsample: Optional[str] +) -> List[Dict[str, Union[str, datetime, int, float]]]: + # def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None: # 转换输入的北京时间为UTC时间 begin_date_utc = beijing_time_to_utc(begin_date) end_date_utc = beijing_time_to_utc(end_date) # 数据接口的地址 - url = 'http://183.64.62.100:9057/loong/api/curves/data' + url = "http://183.64.62.100:9057/loong/api/curves/data" # url = 'http://10.101.15.16:9000/loong/api/curves/data' # url_path = 'http://10.101.15.16:9000/loong' # 内网 # 设置 GET 请求的参数 params = { - 'ids': ids, - 'beginDate': begin_date_utc, - 'endDate': end_date_utc, - 'downsample': downsample + "ids": ids, + "beginDate": begin_date_utc, + "endDate": end_date_utc, + "downsample": downsample, } - history_data_list =[] + history_data_list = [] try: # 发送 GET 请求获取数据 @@ -73,24 +76,26 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio # 这里可以对获取到的数据进行进一步处理 # 打印 'mpointId' 和 'mpointName' - for item in data['items']: - mpoint_id = str(item['mpointId']) - mpoint_name = item['mpointName'] + for item in data["items"]: + mpoint_id = str(item["mpointId"]) + mpoint_name = item["mpointName"] # print("mpointId:", item['mpointId']) # print("mpointName:", item['mpointName']) # 打印 'dataDate' 和 'dataValue' - for item_data in item['data']: + for item_data in item["data"]: # 将时间戳转换为北京时间 - beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate']) - data_value = item_data['dataValue'] + beijing_time = convert_timestamp_to_beijing_time( + item_data["dataDate"] + ) + data_value = item_data["dataValue"] # 创建一个字典存储每条数据 data_dict = { - 'time': beijing_time, - 'device_ID': str(mpoint_id), - 'description': mpoint_name, + "time": beijing_time, + "device_ID": str(mpoint_id), + "description": mpoint_name, # 'dataDate (Beijing Time)': beijing_time.strftime('%Y-%m-%d %H:%M:%S'), - 'monitored_value': data_value # 保留原有类型 + "monitored_value": data_value, # 保留原有类型 } history_data_list.append(data_dict) @@ -164,4 +169,4 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio # for data in data_list1: # writer.writerow([data['measurement'], data['mpointId'], data['date'], data['dataValue'], data['datetime']]) # -# print(f"筛选后的数据已保存到 {filtered_csv_file_path}") \ No newline at end of file +# print(f"筛选后的数据已保存到 {filtered_csv_file_path}") diff --git a/scripts/get_hist_data.py b/scripts/get_hist_data.py index 3626505..04386e5 100644 --- a/scripts/get_hist_data.py +++ b/scripts/get_hist_data.py @@ -8,35 +8,33 @@ def convert_timestamp_to_beijing_time(timestamp): timestamp_seconds = timestamp / 1000 # 将时间戳转换为datetime对象 - utc_time = datetime.utcfromtimestamp(timestamp_seconds) + utc_time = datetime.fromtimestamp(timestamp_seconds) # 设定UTC时区 - utc_timezone = pytz.timezone('UTC') + utc_timezone = pytz.timezone("UTC") # 转换为北京时间 - beijing_timezone = pytz.timezone('Asia/Shanghai') + beijing_timezone = pytz.timezone("Asia/Shanghai") beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone) return beijing_time -def conver_beingtime_to_ucttime(timestr:str): - beijing_time=datetime.strptime(timestr,'%Y-%m-%d %H:%M:%S') - utc_time=beijing_time.astimezone(pytz.utc) - str_utc=utc_time.strftime('%Y-%m-%dT%H:%M:%SZ') - #print(str_utc) + +def conver_beingtime_to_ucttime(timestr: str): + beijing_time = datetime.strptime(timestr, "%Y-%m-%d %H:%M:%S") + utc_time = beijing_time.astimezone(pytz.utc) + str_utc = utc_time.strftime("%Y-%m-%dT%H:%M:%SZ") + # print(str_utc) return str_utc -def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]: + +def get_hist_data(ids, begin_date, end_date) -> dict[str, dict[datetime, float]]: # 数据接口的地址 - url = 'http://183.64.62.100:9057/loong/api/curves/data' + url = "http://183.64.62.100:9057/loong/api/curves/data" # 设置 GET 请求的参数 - params = { - 'ids': ids, - 'beginDate': begin_date, - 'endDate': end_date - } - lst_data={} + params = {"ids": ids, "beginDate": begin_date, "endDate": end_date} + lst_data = {} try: # 发送 GET 请求获取数据 response = requests.get(url, params=params) @@ -48,22 +46,27 @@ def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]: # 这里可以对获取到的数据进行进一步处理 # 打印 'mpointId' 和 'mpointName' - for item in data['items']: - #print("mpointId:", item['mpointId']) - #print("mpointName:", item['mpointName']) + for item in data["items"]: + # print("mpointId:", item['mpointId']) + # print("mpointName:", item['mpointName']) # 打印 'dataDate' 和 'dataValue' - data_seriers={} - for item_data in item['data']: + data_seriers = {} + for item_data in item["data"]: # print("dataDate:", item_data['dataDate']) # 将时间戳转换为北京时间 - beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate']) - print("dataDate (Beijing Time):", beijing_time.strftime('%Y-%m-%d %H:%M:%S')) - print("dataValue:", item_data['dataValue']) + beijing_time = convert_timestamp_to_beijing_time( + item_data["dataDate"] + ) + print( + "dataDate (Beijing Time):", + beijing_time.strftime("%Y-%m-%d %H:%M:%S"), + ) + print("dataValue:", item_data["dataValue"]) print() # 打印空行分隔不同条目 - r=float(item_data['dataValue']) - data_seriers[beijing_time]=r - lst_data[item['mpointId']]=data_seriers + r = float(item_data["dataValue"]) + data_seriers[beijing_time] = r + lst_data[item["mpointId"]] = data_seriers return lst_data else: # 如果请求不成功,打印错误信息 diff --git a/test_api_integration.py b/test_api_integration.py deleted file mode 100755 index a2dcf30..0000000 --- a/test_api_integration.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python -""" -测试新增 API 集成 - -验证新的认证、用户管理和审计日志接口是否正确集成 -""" -import sys -import subprocess -import time - -def check_imports(): - """检查关键模块是否可以导入""" - print("=" * 60) - print("步骤 1: 检查模块导入") - print("=" * 60) - - modules = [ - ("app.core.encryption", "加密模块"), - ("app.core.security", "安全模块"), - ("app.core.audit", "审计模块"), - ("app.domain.models.role", "角色模型"), - ("app.domain.schemas.user", "用户Schema"), - ("app.domain.schemas.audit", "审计Schema"), - ("app.auth.permissions", "权限控制"), - ("app.api.v1.endpoints.auth", "认证接口"), - ("app.api.v1.endpoints.user_management", "用户管理接口"), - ("app.api.v1.endpoints.audit", "审计日志接口"), - ("app.infra.repositories.user_repository", "用户仓储"), - ("app.infra.repositories.audit_repository", "审计仓储"), - ("app.infra.audit.middleware", "审计中间件"), - ] - - success = 0 - failed = 0 - - for module_name, desc in modules: - try: - __import__(module_name) - print(f"✓ {desc:20s} ({module_name})") - success += 1 - except Exception as e: - print(f"✗ {desc:20s} ({module_name})") - print(f" 错误: {e}") - failed += 1 - - print(f"\n结果: {success} 成功, {failed} 失败") - print() - return failed == 0 - -def check_router(): - """检查路由配置""" - print("=" * 60) - print("步骤 2: 检查路由配置") - print("=" * 60) - - try: - from app.api.v1 import router - from app.api.v1.endpoints import auth, user_management, audit - - print("✓ router 模块已导入") - print("✓ auth 端点已导入") - print("✓ user_management 端点已导入") - print("✓ audit 端点已导入") - - # 检查 router 中是否包含新增的路由 - api_router = router.api_router - print(f"\n已注册的路由数量: {len(api_router.routes)}") - - # 查找新增的路由 - auth_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/auth' in r.path] - user_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/users' in r.path] - audit_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/audit' in r.path] - - print(f"认证相关路由: {len(auth_routes)} 个") - print(f"用户管理路由: {len(user_routes)} 个") - print(f"审计日志路由: {len(audit_routes)} 个") - - return True - except Exception as e: - print(f"✗ 路由配置检查失败: {e}") - import traceback - traceback.print_exc() - return False - -def check_main_app(): - """检查 main.py 配置""" - print("\n" + "=" * 60) - print("步骤 3: 检查 main.py 配置") - print("=" * 60) - - try: - from app.main import app - - print("✓ FastAPI app 已创建") - print(f" 标题: {app.title}") - print(f" 版本: {app.version}") - - # 检查中间件 - middleware_names = [m.__class__.__name__ for m in app.user_middleware] - print(f"\n已注册的中间件: {len(middleware_names)} 个") - for name in middleware_names: - print(f" - {name}") - - # 检查路由 - print(f"\n已注册的路由: {len(app.routes)} 个") - - return True - except Exception as e: - print(f"✗ main.py 配置检查失败: {e}") - import traceback - traceback.print_exc() - return False - -def main(): - print("\n🔍 TJWater Server API 集成测试\n") - - results = [] - - # 测试 1: 模块导入 - results.append(("模块导入", check_imports())) - - # 测试 2: 路由配置 - results.append(("路由配置", check_router())) - - # 测试 3: main.py - results.append(("main.py配置", check_main_app())) - - # 总结 - print("\n" + "=" * 60) - print("测试总结") - print("=" * 60) - - for name, success in results: - status = "✓ 通过" if success else "✗ 失败" - print(f"{status:8s} - {name}") - - all_passed = all(success for _, success in results) - - if all_passed: - print("\n✅ 所有测试通过!") - print("\n下一步:") - print(" 1. 确保数据库迁移已执行") - print(" 2. 配置 .env 文件") - print(" 3. 启动服务: uvicorn app.main:app --reload") - print(" 4. 访问文档: http://localhost:8000/docs") - return 0 - else: - print("\n❌ 部分测试失败,请检查错误信息") - return 1 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/tests/api/test_api_integration.py b/tests/api/test_api_integration.py new file mode 100755 index 0000000..741555c --- /dev/null +++ b/tests/api/test_api_integration.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +""" +测试新增 API 集成 + +验证新的认证、用户管理和审计日志接口是否正确集成 +""" + +import sys +import os +import pytest + +# 将项目根目录添加到 sys.path +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) + + +@pytest.mark.parametrize( + "module_name, desc", + [ + ("app.core.encryption", "加密模块"), + ("app.core.security", "安全模块"), + ("app.core.audit", "审计模块"), + ("app.domain.models.role", "角色模型"), + ("app.domain.schemas.user", "用户Schema"), + ("app.domain.schemas.audit", "审计Schema"), + ("app.auth.permissions", "权限控制"), + ("app.api.v1.endpoints.auth", "认证接口"), + ("app.api.v1.endpoints.user_management", "用户管理接口"), + ("app.api.v1.endpoints.audit", "审计日志接口"), + ("app.infra.repositories.user_repository", "用户仓储"), + ("app.infra.repositories.audit_repository", "审计仓储"), + ("app.infra.audit.middleware", "审计中间件"), + ], +) +def test_module_imports(module_name, desc): + """检查关键模块是否可以导入""" + try: + __import__(module_name) + except ImportError as e: + pytest.fail(f"无法导入 {desc} ({module_name}): {e}") + + +def test_router_configuration(): + """检查路由配置""" + try: + from app.api.v1 import router + + # 检查 router 中是否包含新增的路由 + api_router = router.api_router + routes = [r.path for r in api_router.routes if hasattr(r, "path")] + + # 验证基础路径是否存在 + assert any("/auth" in r for r in routes), "缺少认证相关路由 (/auth)" + assert any("/users" in r for r in routes), "缺少用户管理路由 (/users)" + assert any("/audit" in r for r in routes), "缺少审计日志路由 (/audit)" + + except Exception as e: + pytest.fail(f"路由配置检查失败: {e}") + + +def test_main_app_initialization(): + """检查 main.py 配置""" + try: + from app.main import app + + assert app is not None + assert app.title != "" + + # 检查中间件 (简单检查是否存在) + middleware_names = [m.cls.__name__ for m in app.user_middleware] + # 检查是否包含审计中间件或其他关键中间件(根据实际类名修改) + assert "AuditMiddleware" in middleware_names, "缺少审计中间件" + + # 检查路由总数 + assert len(app.routes) > 0 + + except Exception as e: + pytest.fail(f"main.py 配置检查失败: {e}") diff --git a/tests/test_encryption.py b/tests/auth/test_encryption.py similarity index 90% rename from tests/test_encryption.py rename to tests/auth/test_encryption.py index 3544747..11907dc 100644 --- a/tests/test_encryption.py +++ b/tests/auth/test_encryption.py @@ -1,37 +1,41 @@ """ 测试加密功能 """ + import os import sys -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + def test_encryption(): """测试加密和解密功能""" from app.core.encryption import Encryptor - + # 生成测试密钥 key = Encryptor.generate_key() print(f"✓ 生成密钥: {key}") - + # 创建加密器 encryptor = Encryptor(key=key.encode()) - + # 测试加密 test_data = "这是敏感数据 - 数据库密码: password123" encrypted = encryptor.encrypt(test_data) print(f"✓ 加密成功: {encrypted[:50]}...") - + # 测试解密 decrypted = encryptor.decrypt(encrypted) assert decrypted == test_data, "解密数据不匹配!" print(f"✓ 解密成功: {decrypted}") - + # 测试空数据 assert encryptor.encrypt("") == "" assert encryptor.decrypt("") == "" print("✓ 空数据处理正确") - + print("\n✅ 所有加密测试通过!") + if __name__ == "__main__": test_encryption() diff --git a/tests/unit/test_pipeline_health_analyzer.py b/tests/unit/test_pipeline_health_analyzer.py index a6413a2..518f51f 100644 --- a/tests/unit/test_pipeline_health_analyzer.py +++ b/tests/unit/test_pipeline_health_analyzer.py @@ -1,7 +1,11 @@ -from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer +""" +tests.unit.test_pipeline_health_analyzer 的 Docstring +""" def test_pipeline_health_analyzer(): + from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer + # 初始化分析器,假设模型文件路径为'models/rsf_model.joblib' analyzer = PipelineHealthAnalyzer() # 创建示例输入数据(9个样本) @@ -51,7 +55,7 @@ def test_pipeline_health_analyzer(): ), "每个生存函数应包含x和y属性" # 可选:测试绘图功能(不显示图表) - analyzer.plot_survival(survival_functions, show_plot=True) + analyzer.plot_survival(survival_functions, show_plot=False) if __name__ == "__main__":