测试并修复api导入路径错误

This commit is contained in:
2026-02-02 11:09:43 +08:00
parent 807e634318
commit 35abaa1ebb
13 changed files with 211 additions and 250 deletions

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Request, Depends from fastapi import APIRouter, Request, Depends
from typing import Any, List, Dict, Union from typing import Any, List, Dict, Union
from app.services.tjnetwork import * from app.services.tjnetwork import *
from app.api.v1.endpoints.auth import verify_token from app.auth.dependencies import get_current_user as verify_token
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
import msgpack import msgpack

View File

@@ -43,7 +43,7 @@ from app.infra.db.timescaledb import router as timescaledb_router
api_router = APIRouter() api_router = APIRouter()
# Core Services # Core Services
api_router.include_router(auth.router, tags=["Auth"]) api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增 api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增 api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
api_router.include_router(project.router, tags=["Project"]) api_router.include_router(project.router, tags=["Project"])

View File

@@ -9,6 +9,7 @@ from app.infra.db.postgresql.database import Database
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
# 数据库依赖 # 数据库依赖
async def get_db(request: Request) -> Database: async def get_db(request: Request) -> Database:
""" """
@@ -19,17 +20,19 @@ async def get_db(request: Request) -> Database:
if not hasattr(request.app.state, "db"): if not hasattr(request.app.state, "db"):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database not initialized" detail="Database not initialized",
) )
return request.app.state.db return request.app.state.db
async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository: async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository:
"""获取用户仓储实例""" """获取用户仓储实例"""
return UserRepository(db) return UserRepository(db)
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
user_repo: UserRepository = Depends(get_user_repository) user_repo: UserRepository = Depends(get_user_repository),
) -> UserInDB: ) -> UserInDB:
""" """
获取当前登录用户 获取当前登录用户
@@ -43,7 +46,9 @@ async def get_current_user(
) )
try: try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub") username: str = payload.get("sub")
token_type: str = payload.get("type", "access") token_type: str = payload.get("type", "access")
@@ -67,6 +72,7 @@ async def get_current_user(
return user return user
async def get_current_active_user( async def get_current_active_user(
current_user: UserInDB = Depends(get_current_user), current_user: UserInDB = Depends(get_current_user),
) -> UserInDB: ) -> UserInDB:
@@ -75,11 +81,11 @@ async def get_current_active_user(
""" """
if not current_user.is_active: if not current_user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
detail="Inactive user"
) )
return current_user return current_user
async def get_current_superuser( async def get_current_superuser(
current_user: UserInDB = Depends(get_current_user), current_user: UserInDB = Depends(get_current_user),
) -> UserInDB: ) -> UserInDB:
@@ -89,6 +95,6 @@ async def get_current_superuser(
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough privileges. Superuser access required." detail="Not enough privileges. Superuser access required.",
) )
return current_user return current_user

View File

@@ -32,5 +32,6 @@ class Settings(BaseSettings):
class Config: class Config:
env_file = ".env" env_file = ".env"
extra = "ignore"
settings = Settings() settings = Settings()

View File

@@ -1,12 +1,16 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional, Union, Any from typing import Optional, Union, Any
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from app.core.config import settings from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
""" """
创建 JWT Access Token 创建 JWT Access Token
@@ -18,19 +22,24 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede
JWT token 字符串 JWT token 字符串
""" """
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = datetime.now() + expires_delta
else: else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = { to_encode = {
"exp": expire, "exp": expire,
"sub": str(subject), "sub": str(subject),
"type": "access", "type": "access",
"iat": datetime.utcnow() "iat": datetime.now(),
} }
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt return encoded_jwt
def create_refresh_token(subject: Union[str, Any]) -> str: def create_refresh_token(subject: Union[str, Any]) -> str:
""" """
创建 JWT Refresh Token长期有效 创建 JWT Refresh Token长期有效
@@ -41,17 +50,20 @@ def create_refresh_token(subject: Union[str, Any]) -> str:
Returns: Returns:
JWT refresh token 字符串 JWT refresh token 字符串
""" """
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = { to_encode = {
"exp": expire, "exp": expire,
"sub": str(subject), "sub": str(subject),
"type": "refresh", "type": "refresh",
"iat": datetime.utcnow() "iat": datetime.now(),
} }
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
""" """
验证密码 验证密码
@@ -65,6 +77,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
""" """
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
""" """
生成密码哈希 生成密码哈希

View File

@@ -13,8 +13,8 @@ from typing import List, Dict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS
from dateutil import parser from dateutil import parser
import get_realValue # import get_realValue
import get_data # import get_data
import psycopg import psycopg
import time import time
import app.services.simulation as simulation import app.services.simulation as simulation

View File

@@ -71,7 +71,7 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
# 添加审计中间件(可选,记录关键操作) # 添加审计中间件(可选,记录关键操作)
# 如果需要启用审计日志,取消下面的注释 # 如果需要启用审计日志,取消下面的注释
# app.add_middleware(AuditMiddleware) app.add_middleware(AuditMiddleware)
# Include Routers # Include Routers
app.include_router(api_router, prefix="/api/v1") app.include_router(api_router, prefix="/api/v1")

View File

@@ -7,18 +7,19 @@ import csv
# get_data 是用来获取 历史数据,也就是非实时数据的接口 # get_data 是用来获取 历史数据,也就是非实时数据的接口
# get_realtime 是用来获取 实时数据 # get_realtime 是用来获取 实时数据
def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime: def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
# 将毫秒级时间戳转换为秒级时间戳 # 将毫秒级时间戳转换为秒级时间戳
timestamp_seconds = timestamp / 1000 timestamp_seconds = timestamp / 1000
# 将时间戳转换为datetime对象 # 将时间戳转换为datetime对象
utc_time = datetime.utcfromtimestamp(timestamp_seconds) utc_time = datetime.fromtimestamp(timestamp_seconds)
# 设定UTC时区 # 设定UTC时区
utc_timezone = pytz.timezone('UTC') utc_timezone = pytz.timezone("UTC")
# 转换为北京时间 # 转换为北京时间
beijing_timezone = pytz.timezone('Asia/Shanghai') beijing_timezone = pytz.timezone("Asia/Shanghai")
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone) beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
return beijing_time return beijing_time
@@ -26,10 +27,10 @@ def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
def beijing_time_to_utc(beijing_time_str: str) -> str: def beijing_time_to_utc(beijing_time_str: str) -> str:
# 定义北京时区 # 定义北京时区
beijing_timezone = pytz.timezone('Asia/Shanghai') beijing_timezone = pytz.timezone("Asia/Shanghai")
# 将字符串转换为datetime对象 # 将字符串转换为datetime对象
beijing_time = datetime.strptime(beijing_time_str, '%Y-%m-%d %H:%M:%S') beijing_time = datetime.strptime(beijing_time_str, "%Y-%m-%d %H:%M:%S")
# 本地化时间对象 # 本地化时间对象
beijing_time = beijing_timezone.localize(beijing_time) beijing_time = beijing_timezone.localize(beijing_time)
@@ -38,29 +39,31 @@ def beijing_time_to_utc(beijing_time_str: str) -> str:
utc_time = beijing_time.astimezone(pytz.utc) utc_time = beijing_time.astimezone(pytz.utc)
# 转换为ISO 8601格式的字符串 # 转换为ISO 8601格式的字符串
return utc_time.strftime('%Y-%m-%dT%H:%M:%SZ') return utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> List[Dict[str, Union[str, datetime, int, float]]]: def get_history_data(
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None: ids: str, begin_date: str, end_date: str, downsample: Optional[str]
) -> List[Dict[str, Union[str, datetime, int, float]]]:
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None:
# 转换输入的北京时间为UTC时间 # 转换输入的北京时间为UTC时间
begin_date_utc = beijing_time_to_utc(begin_date) begin_date_utc = beijing_time_to_utc(begin_date)
end_date_utc = beijing_time_to_utc(end_date) end_date_utc = beijing_time_to_utc(end_date)
# 数据接口的地址 # 数据接口的地址
url = 'http://183.64.62.100:9057/loong/api/curves/data' url = "http://183.64.62.100:9057/loong/api/curves/data"
# url = 'http://10.101.15.16:9000/loong/api/curves/data' # url = 'http://10.101.15.16:9000/loong/api/curves/data'
# url_path = 'http://10.101.15.16:9000/loong' # 内网 # url_path = 'http://10.101.15.16:9000/loong' # 内网
# 设置 GET 请求的参数 # 设置 GET 请求的参数
params = { params = {
'ids': ids, "ids": ids,
'beginDate': begin_date_utc, "beginDate": begin_date_utc,
'endDate': end_date_utc, "endDate": end_date_utc,
'downsample': downsample "downsample": downsample,
} }
history_data_list =[] history_data_list = []
try: try:
# 发送 GET 请求获取数据 # 发送 GET 请求获取数据
@@ -73,24 +76,26 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio
# 这里可以对获取到的数据进行进一步处理 # 这里可以对获取到的数据进行进一步处理
# 打印 'mpointId' 和 'mpointName' # 打印 'mpointId' 和 'mpointName'
for item in data['items']: for item in data["items"]:
mpoint_id = str(item['mpointId']) mpoint_id = str(item["mpointId"])
mpoint_name = item['mpointName'] mpoint_name = item["mpointName"]
# print("mpointId:", item['mpointId']) # print("mpointId:", item['mpointId'])
# print("mpointName:", item['mpointName']) # print("mpointName:", item['mpointName'])
# 打印 'dataDate' 和 'dataValue' # 打印 'dataDate' 和 'dataValue'
for item_data in item['data']: for item_data in item["data"]:
# 将时间戳转换为北京时间 # 将时间戳转换为北京时间
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate']) beijing_time = convert_timestamp_to_beijing_time(
data_value = item_data['dataValue'] item_data["dataDate"]
)
data_value = item_data["dataValue"]
# 创建一个字典存储每条数据 # 创建一个字典存储每条数据
data_dict = { data_dict = {
'time': beijing_time, "time": beijing_time,
'device_ID': str(mpoint_id), "device_ID": str(mpoint_id),
'description': mpoint_name, "description": mpoint_name,
# 'dataDate (Beijing Time)': beijing_time.strftime('%Y-%m-%d %H:%M:%S'), # 'dataDate (Beijing Time)': beijing_time.strftime('%Y-%m-%d %H:%M:%S'),
'monitored_value': data_value # 保留原有类型 "monitored_value": data_value, # 保留原有类型
} }
history_data_list.append(data_dict) history_data_list.append(data_dict)

View File

@@ -8,35 +8,33 @@ def convert_timestamp_to_beijing_time(timestamp):
timestamp_seconds = timestamp / 1000 timestamp_seconds = timestamp / 1000
# 将时间戳转换为datetime对象 # 将时间戳转换为datetime对象
utc_time = datetime.utcfromtimestamp(timestamp_seconds) utc_time = datetime.fromtimestamp(timestamp_seconds)
# 设定UTC时区 # 设定UTC时区
utc_timezone = pytz.timezone('UTC') utc_timezone = pytz.timezone("UTC")
# 转换为北京时间 # 转换为北京时间
beijing_timezone = pytz.timezone('Asia/Shanghai') beijing_timezone = pytz.timezone("Asia/Shanghai")
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone) beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
return beijing_time return beijing_time
def conver_beingtime_to_ucttime(timestr:str):
beijing_time=datetime.strptime(timestr,'%Y-%m-%d %H:%M:%S') def conver_beingtime_to_ucttime(timestr: str):
utc_time=beijing_time.astimezone(pytz.utc) beijing_time = datetime.strptime(timestr, "%Y-%m-%d %H:%M:%S")
str_utc=utc_time.strftime('%Y-%m-%dT%H:%M:%SZ') utc_time = beijing_time.astimezone(pytz.utc)
#print(str_utc) str_utc = utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
# print(str_utc)
return str_utc return str_utc
def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
def get_hist_data(ids, begin_date, end_date) -> dict[str, dict[datetime, float]]:
# 数据接口的地址 # 数据接口的地址
url = 'http://183.64.62.100:9057/loong/api/curves/data' url = "http://183.64.62.100:9057/loong/api/curves/data"
# 设置 GET 请求的参数 # 设置 GET 请求的参数
params = { params = {"ids": ids, "beginDate": begin_date, "endDate": end_date}
'ids': ids, lst_data = {}
'beginDate': begin_date,
'endDate': end_date
}
lst_data={}
try: try:
# 发送 GET 请求获取数据 # 发送 GET 请求获取数据
response = requests.get(url, params=params) response = requests.get(url, params=params)
@@ -48,22 +46,27 @@ def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
# 这里可以对获取到的数据进行进一步处理 # 这里可以对获取到的数据进行进一步处理
# 打印 'mpointId' 和 'mpointName' # 打印 'mpointId' 和 'mpointName'
for item in data['items']: for item in data["items"]:
#print("mpointId:", item['mpointId']) # print("mpointId:", item['mpointId'])
#print("mpointName:", item['mpointName']) # print("mpointName:", item['mpointName'])
# 打印 'dataDate' 和 'dataValue' # 打印 'dataDate' 和 'dataValue'
data_seriers={} data_seriers = {}
for item_data in item['data']: for item_data in item["data"]:
# print("dataDate:", item_data['dataDate']) # print("dataDate:", item_data['dataDate'])
# 将时间戳转换为北京时间 # 将时间戳转换为北京时间
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate']) beijing_time = convert_timestamp_to_beijing_time(
print("dataDate (Beijing Time):", beijing_time.strftime('%Y-%m-%d %H:%M:%S')) item_data["dataDate"]
print("dataValue:", item_data['dataValue']) )
print(
"dataDate (Beijing Time):",
beijing_time.strftime("%Y-%m-%d %H:%M:%S"),
)
print("dataValue:", item_data["dataValue"])
print() # 打印空行分隔不同条目 print() # 打印空行分隔不同条目
r=float(item_data['dataValue']) r = float(item_data["dataValue"])
data_seriers[beijing_time]=r data_seriers[beijing_time] = r
lst_data[item['mpointId']]=data_seriers lst_data[item["mpointId"]] = data_seriers
return lst_data return lst_data
else: else:
# 如果请求不成功,打印错误信息 # 如果请求不成功,打印错误信息

View File

@@ -1,152 +0,0 @@
#!/usr/bin/env python
"""
测试新增 API 集成
验证新的认证、用户管理和审计日志接口是否正确集成
"""
import sys
import subprocess
import time
def check_imports():
"""检查关键模块是否可以导入"""
print("=" * 60)
print("步骤 1: 检查模块导入")
print("=" * 60)
modules = [
("app.core.encryption", "加密模块"),
("app.core.security", "安全模块"),
("app.core.audit", "审计模块"),
("app.domain.models.role", "角色模型"),
("app.domain.schemas.user", "用户Schema"),
("app.domain.schemas.audit", "审计Schema"),
("app.auth.permissions", "权限控制"),
("app.api.v1.endpoints.auth", "认证接口"),
("app.api.v1.endpoints.user_management", "用户管理接口"),
("app.api.v1.endpoints.audit", "审计日志接口"),
("app.infra.repositories.user_repository", "用户仓储"),
("app.infra.repositories.audit_repository", "审计仓储"),
("app.infra.audit.middleware", "审计中间件"),
]
success = 0
failed = 0
for module_name, desc in modules:
try:
__import__(module_name)
print(f"{desc:20s} ({module_name})")
success += 1
except Exception as e:
print(f"{desc:20s} ({module_name})")
print(f" 错误: {e}")
failed += 1
print(f"\n结果: {success} 成功, {failed} 失败")
print()
return failed == 0
def check_router():
"""检查路由配置"""
print("=" * 60)
print("步骤 2: 检查路由配置")
print("=" * 60)
try:
from app.api.v1 import router
from app.api.v1.endpoints import auth, user_management, audit
print("✓ router 模块已导入")
print("✓ auth 端点已导入")
print("✓ user_management 端点已导入")
print("✓ audit 端点已导入")
# 检查 router 中是否包含新增的路由
api_router = router.api_router
print(f"\n已注册的路由数量: {len(api_router.routes)}")
# 查找新增的路由
auth_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/auth' in r.path]
user_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/users' in r.path]
audit_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/audit' in r.path]
print(f"认证相关路由: {len(auth_routes)}")
print(f"用户管理路由: {len(user_routes)}")
print(f"审计日志路由: {len(audit_routes)}")
return True
except Exception as e:
print(f"✗ 路由配置检查失败: {e}")
import traceback
traceback.print_exc()
return False
def check_main_app():
"""检查 main.py 配置"""
print("\n" + "=" * 60)
print("步骤 3: 检查 main.py 配置")
print("=" * 60)
try:
from app.main import app
print("✓ FastAPI app 已创建")
print(f" 标题: {app.title}")
print(f" 版本: {app.version}")
# 检查中间件
middleware_names = [m.__class__.__name__ for m in app.user_middleware]
print(f"\n已注册的中间件: {len(middleware_names)}")
for name in middleware_names:
print(f" - {name}")
# 检查路由
print(f"\n已注册的路由: {len(app.routes)}")
return True
except Exception as e:
print(f"✗ main.py 配置检查失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
print("\n🔍 TJWater Server API 集成测试\n")
results = []
# 测试 1: 模块导入
results.append(("模块导入", check_imports()))
# 测试 2: 路由配置
results.append(("路由配置", check_router()))
# 测试 3: main.py
results.append(("main.py配置", check_main_app()))
# 总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
for name, success in results:
status = "✓ 通过" if success else "✗ 失败"
print(f"{status:8s} - {name}")
all_passed = all(success for _, success in results)
if all_passed:
print("\n✅ 所有测试通过!")
print("\n下一步:")
print(" 1. 确保数据库迁移已执行")
print(" 2. 配置 .env 文件")
print(" 3. 启动服务: uvicorn app.main:app --reload")
print(" 4. 访问文档: http://localhost:8000/docs")
return 0
else:
print("\n❌ 部分测试失败,请检查错误信息")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python
"""
测试新增 API 集成
验证新的认证、用户管理和审计日志接口是否正确集成
"""
import sys
import os
import pytest
# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
@pytest.mark.parametrize(
"module_name, desc",
[
("app.core.encryption", "加密模块"),
("app.core.security", "安全模块"),
("app.core.audit", "审计模块"),
("app.domain.models.role", "角色模型"),
("app.domain.schemas.user", "用户Schema"),
("app.domain.schemas.audit", "审计Schema"),
("app.auth.permissions", "权限控制"),
("app.api.v1.endpoints.auth", "认证接口"),
("app.api.v1.endpoints.user_management", "用户管理接口"),
("app.api.v1.endpoints.audit", "审计日志接口"),
("app.infra.repositories.user_repository", "用户仓储"),
("app.infra.repositories.audit_repository", "审计仓储"),
("app.infra.audit.middleware", "审计中间件"),
],
)
def test_module_imports(module_name, desc):
"""检查关键模块是否可以导入"""
try:
__import__(module_name)
except ImportError as e:
pytest.fail(f"无法导入 {desc} ({module_name}): {e}")
def test_router_configuration():
"""检查路由配置"""
try:
from app.api.v1 import router
# 检查 router 中是否包含新增的路由
api_router = router.api_router
routes = [r.path for r in api_router.routes if hasattr(r, "path")]
# 验证基础路径是否存在
assert any("/auth" in r for r in routes), "缺少认证相关路由 (/auth)"
assert any("/users" in r for r in routes), "缺少用户管理路由 (/users)"
assert any("/audit" in r for r in routes), "缺少审计日志路由 (/audit)"
except Exception as e:
pytest.fail(f"路由配置检查失败: {e}")
def test_main_app_initialization():
"""检查 main.py 配置"""
try:
from app.main import app
assert app is not None
assert app.title != ""
# 检查中间件 (简单检查是否存在)
middleware_names = [m.cls.__name__ for m in app.user_middleware]
# 检查是否包含审计中间件或其他关键中间件(根据实际类名修改)
assert "AuditMiddleware" in middleware_names, "缺少审计中间件"
# 检查路由总数
assert len(app.routes) > 0
except Exception as e:
pytest.fail(f"main.py 配置检查失败: {e}")

View File

@@ -1,9 +1,12 @@
""" """
测试加密功能 测试加密功能
""" """
import os import os
import sys import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
def test_encryption(): def test_encryption():
"""测试加密和解密功能""" """测试加密和解密功能"""
@@ -33,5 +36,6 @@ def test_encryption():
print("\n✅ 所有加密测试通过!") print("\n✅ 所有加密测试通过!")
if __name__ == "__main__": if __name__ == "__main__":
test_encryption() test_encryption()

View File

@@ -1,7 +1,11 @@
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer """
tests.unit.test_pipeline_health_analyzer 的 Docstring
"""
def test_pipeline_health_analyzer(): def test_pipeline_health_analyzer():
from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib' # 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
analyzer = PipelineHealthAnalyzer() analyzer = PipelineHealthAnalyzer()
# 创建示例输入数据9个样本 # 创建示例输入数据9个样本
@@ -51,7 +55,7 @@ def test_pipeline_health_analyzer():
), "每个生存函数应包含x和y属性" ), "每个生存函数应包含x和y属性"
# 可选:测试绘图功能(不显示图表) # 可选:测试绘图功能(不显示图表)
analyzer.plot_survival(survival_functions, show_plot=True) analyzer.plot_survival(survival_functions, show_plot=False)
if __name__ == "__main__": if __name__ == "__main__":