测试并修复api导入路径错误
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.api.v1.endpoints.auth import verify_token
|
||||
from app.auth.dependencies import get_current_user as verify_token
|
||||
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
|
||||
import msgpack
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from app.infra.db.timescaledb import router as timescaledb_router
|
||||
api_router = APIRouter()
|
||||
|
||||
# Core Services
|
||||
api_router.include_router(auth.router, tags=["Auth"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||
api_router.include_router(project.router, tags=["Project"])
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.infra.db.postgresql.database import Database
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
|
||||
|
||||
|
||||
# 数据库依赖
|
||||
async def get_db(request: Request) -> Database:
|
||||
"""
|
||||
@@ -19,17 +20,19 @@ async def get_db(request: Request) -> Database:
|
||||
if not hasattr(request.app.state, "db"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database not initialized"
|
||||
detail="Database not initialized",
|
||||
)
|
||||
return request.app.state.db
|
||||
|
||||
|
||||
async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository:
|
||||
"""获取用户仓储实例"""
|
||||
return UserRepository(db)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前登录用户
|
||||
@@ -43,7 +46,9 @@ async def get_current_user(
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("type", "access")
|
||||
|
||||
@@ -67,6 +72,7 @@ async def get_current_user(
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> UserInDB:
|
||||
@@ -75,11 +81,11 @@ async def get_current_active_user(
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_superuser(
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> UserInDB:
|
||||
@@ -89,6 +95,6 @@ async def get_current_superuser(
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough privileges. Superuser access required."
|
||||
detail="Not enough privileges. Superuser access required.",
|
||||
)
|
||||
return current_user
|
||||
|
||||
@@ -32,5 +32,6 @@ class Settings(BaseSettings):
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建 JWT Access Token
|
||||
|
||||
@@ -18,19 +22,24 @@ def create_access_token(subject: Union[str, Any], expires_delta: Optional[timede
|
||||
JWT token 字符串
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
expire = datetime.now() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
to_encode = {
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "access",
|
||||
"iat": datetime.utcnow()
|
||||
"iat": datetime.now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(subject: Union[str, Any]) -> str:
|
||||
"""
|
||||
创建 JWT Refresh Token(长期有效)
|
||||
@@ -41,17 +50,20 @@ def create_refresh_token(subject: Union[str, Any]) -> str:
|
||||
Returns:
|
||||
JWT refresh token 字符串
|
||||
"""
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "refresh",
|
||||
"iat": datetime.utcnow()
|
||||
"iat": datetime.now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
验证密码
|
||||
@@ -65,6 +77,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""
|
||||
生成密码哈希
|
||||
|
||||
@@ -13,8 +13,8 @@ from typing import List, Dict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS
|
||||
from dateutil import parser
|
||||
import get_realValue
|
||||
import get_data
|
||||
# import get_realValue
|
||||
# import get_data
|
||||
import psycopg
|
||||
import time
|
||||
import app.services.simulation as simulation
|
||||
|
||||
@@ -71,7 +71,7 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# 添加审计中间件(可选,记录关键操作)
|
||||
# 如果需要启用审计日志,取消下面的注释
|
||||
# app.add_middleware(AuditMiddleware)
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# Include Routers
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
@@ -7,18 +7,19 @@ import csv
|
||||
# get_data 是用来获取 历史数据,也就是非实时数据的接口
|
||||
# get_realtime 是用来获取 实时数据
|
||||
|
||||
|
||||
def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
|
||||
# 将毫秒级时间戳转换为秒级时间戳
|
||||
timestamp_seconds = timestamp / 1000
|
||||
|
||||
# 将时间戳转换为datetime对象
|
||||
utc_time = datetime.utcfromtimestamp(timestamp_seconds)
|
||||
utc_time = datetime.fromtimestamp(timestamp_seconds)
|
||||
|
||||
# 设定UTC时区
|
||||
utc_timezone = pytz.timezone('UTC')
|
||||
utc_timezone = pytz.timezone("UTC")
|
||||
|
||||
# 转换为北京时间
|
||||
beijing_timezone = pytz.timezone('Asia/Shanghai')
|
||||
beijing_timezone = pytz.timezone("Asia/Shanghai")
|
||||
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
|
||||
|
||||
return beijing_time
|
||||
@@ -26,10 +27,10 @@ def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
|
||||
|
||||
def beijing_time_to_utc(beijing_time_str: str) -> str:
|
||||
# 定义北京时区
|
||||
beijing_timezone = pytz.timezone('Asia/Shanghai')
|
||||
beijing_timezone = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
# 将字符串转换为datetime对象
|
||||
beijing_time = datetime.strptime(beijing_time_str, '%Y-%m-%d %H:%M:%S')
|
||||
beijing_time = datetime.strptime(beijing_time_str, "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 本地化时间对象
|
||||
beijing_time = beijing_timezone.localize(beijing_time)
|
||||
@@ -38,29 +39,31 @@ def beijing_time_to_utc(beijing_time_str: str) -> str:
|
||||
utc_time = beijing_time.astimezone(pytz.utc)
|
||||
|
||||
# 转换为ISO 8601格式的字符串
|
||||
return utc_time.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
return utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> List[Dict[str, Union[str, datetime, int, float]]]:
|
||||
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None:
|
||||
def get_history_data(
|
||||
ids: str, begin_date: str, end_date: str, downsample: Optional[str]
|
||||
) -> List[Dict[str, Union[str, datetime, int, float]]]:
|
||||
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None:
|
||||
# 转换输入的北京时间为UTC时间
|
||||
begin_date_utc = beijing_time_to_utc(begin_date)
|
||||
end_date_utc = beijing_time_to_utc(end_date)
|
||||
|
||||
# 数据接口的地址
|
||||
url = 'http://183.64.62.100:9057/loong/api/curves/data'
|
||||
url = "http://183.64.62.100:9057/loong/api/curves/data"
|
||||
# url = 'http://10.101.15.16:9000/loong/api/curves/data'
|
||||
# url_path = 'http://10.101.15.16:9000/loong' # 内网
|
||||
|
||||
# 设置 GET 请求的参数
|
||||
params = {
|
||||
'ids': ids,
|
||||
'beginDate': begin_date_utc,
|
||||
'endDate': end_date_utc,
|
||||
'downsample': downsample
|
||||
"ids": ids,
|
||||
"beginDate": begin_date_utc,
|
||||
"endDate": end_date_utc,
|
||||
"downsample": downsample,
|
||||
}
|
||||
|
||||
history_data_list =[]
|
||||
history_data_list = []
|
||||
|
||||
try:
|
||||
# 发送 GET 请求获取数据
|
||||
@@ -73,24 +76,26 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio
|
||||
# 这里可以对获取到的数据进行进一步处理
|
||||
|
||||
# 打印 'mpointId' 和 'mpointName'
|
||||
for item in data['items']:
|
||||
mpoint_id = str(item['mpointId'])
|
||||
mpoint_name = item['mpointName']
|
||||
for item in data["items"]:
|
||||
mpoint_id = str(item["mpointId"])
|
||||
mpoint_name = item["mpointName"]
|
||||
# print("mpointId:", item['mpointId'])
|
||||
# print("mpointName:", item['mpointName'])
|
||||
|
||||
# 打印 'dataDate' 和 'dataValue'
|
||||
for item_data in item['data']:
|
||||
for item_data in item["data"]:
|
||||
# 将时间戳转换为北京时间
|
||||
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate'])
|
||||
data_value = item_data['dataValue']
|
||||
beijing_time = convert_timestamp_to_beijing_time(
|
||||
item_data["dataDate"]
|
||||
)
|
||||
data_value = item_data["dataValue"]
|
||||
# 创建一个字典存储每条数据
|
||||
data_dict = {
|
||||
'time': beijing_time,
|
||||
'device_ID': str(mpoint_id),
|
||||
'description': mpoint_name,
|
||||
"time": beijing_time,
|
||||
"device_ID": str(mpoint_id),
|
||||
"description": mpoint_name,
|
||||
# 'dataDate (Beijing Time)': beijing_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'monitored_value': data_value # 保留原有类型
|
||||
"monitored_value": data_value, # 保留原有类型
|
||||
}
|
||||
|
||||
history_data_list.append(data_dict)
|
||||
|
||||
@@ -8,35 +8,33 @@ def convert_timestamp_to_beijing_time(timestamp):
|
||||
timestamp_seconds = timestamp / 1000
|
||||
|
||||
# 将时间戳转换为datetime对象
|
||||
utc_time = datetime.utcfromtimestamp(timestamp_seconds)
|
||||
utc_time = datetime.fromtimestamp(timestamp_seconds)
|
||||
|
||||
# 设定UTC时区
|
||||
utc_timezone = pytz.timezone('UTC')
|
||||
utc_timezone = pytz.timezone("UTC")
|
||||
|
||||
# 转换为北京时间
|
||||
beijing_timezone = pytz.timezone('Asia/Shanghai')
|
||||
beijing_timezone = pytz.timezone("Asia/Shanghai")
|
||||
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
|
||||
|
||||
return beijing_time
|
||||
|
||||
def conver_beingtime_to_ucttime(timestr:str):
|
||||
beijing_time=datetime.strptime(timestr,'%Y-%m-%d %H:%M:%S')
|
||||
utc_time=beijing_time.astimezone(pytz.utc)
|
||||
str_utc=utc_time.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
#print(str_utc)
|
||||
|
||||
def conver_beingtime_to_ucttime(timestr: str):
|
||||
beijing_time = datetime.strptime(timestr, "%Y-%m-%d %H:%M:%S")
|
||||
utc_time = beijing_time.astimezone(pytz.utc)
|
||||
str_utc = utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
# print(str_utc)
|
||||
return str_utc
|
||||
|
||||
def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
|
||||
|
||||
def get_hist_data(ids, begin_date, end_date) -> dict[str, dict[datetime, float]]:
|
||||
# 数据接口的地址
|
||||
url = 'http://183.64.62.100:9057/loong/api/curves/data'
|
||||
url = "http://183.64.62.100:9057/loong/api/curves/data"
|
||||
|
||||
# 设置 GET 请求的参数
|
||||
params = {
|
||||
'ids': ids,
|
||||
'beginDate': begin_date,
|
||||
'endDate': end_date
|
||||
}
|
||||
lst_data={}
|
||||
params = {"ids": ids, "beginDate": begin_date, "endDate": end_date}
|
||||
lst_data = {}
|
||||
try:
|
||||
# 发送 GET 请求获取数据
|
||||
response = requests.get(url, params=params)
|
||||
@@ -48,22 +46,27 @@ def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
|
||||
# 这里可以对获取到的数据进行进一步处理
|
||||
|
||||
# 打印 'mpointId' 和 'mpointName'
|
||||
for item in data['items']:
|
||||
#print("mpointId:", item['mpointId'])
|
||||
#print("mpointName:", item['mpointName'])
|
||||
for item in data["items"]:
|
||||
# print("mpointId:", item['mpointId'])
|
||||
# print("mpointName:", item['mpointName'])
|
||||
|
||||
# 打印 'dataDate' 和 'dataValue'
|
||||
data_seriers={}
|
||||
for item_data in item['data']:
|
||||
data_seriers = {}
|
||||
for item_data in item["data"]:
|
||||
# print("dataDate:", item_data['dataDate'])
|
||||
# 将时间戳转换为北京时间
|
||||
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate'])
|
||||
print("dataDate (Beijing Time):", beijing_time.strftime('%Y-%m-%d %H:%M:%S'))
|
||||
print("dataValue:", item_data['dataValue'])
|
||||
beijing_time = convert_timestamp_to_beijing_time(
|
||||
item_data["dataDate"]
|
||||
)
|
||||
print(
|
||||
"dataDate (Beijing Time):",
|
||||
beijing_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
print("dataValue:", item_data["dataValue"])
|
||||
print() # 打印空行分隔不同条目
|
||||
r=float(item_data['dataValue'])
|
||||
data_seriers[beijing_time]=r
|
||||
lst_data[item['mpointId']]=data_seriers
|
||||
r = float(item_data["dataValue"])
|
||||
data_seriers[beijing_time] = r
|
||||
lst_data[item["mpointId"]] = data_seriers
|
||||
return lst_data
|
||||
else:
|
||||
# 如果请求不成功,打印错误信息
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
测试新增 API 集成
|
||||
|
||||
验证新的认证、用户管理和审计日志接口是否正确集成
|
||||
"""
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
def check_imports():
|
||||
"""检查关键模块是否可以导入"""
|
||||
print("=" * 60)
|
||||
print("步骤 1: 检查模块导入")
|
||||
print("=" * 60)
|
||||
|
||||
modules = [
|
||||
("app.core.encryption", "加密模块"),
|
||||
("app.core.security", "安全模块"),
|
||||
("app.core.audit", "审计模块"),
|
||||
("app.domain.models.role", "角色模型"),
|
||||
("app.domain.schemas.user", "用户Schema"),
|
||||
("app.domain.schemas.audit", "审计Schema"),
|
||||
("app.auth.permissions", "权限控制"),
|
||||
("app.api.v1.endpoints.auth", "认证接口"),
|
||||
("app.api.v1.endpoints.user_management", "用户管理接口"),
|
||||
("app.api.v1.endpoints.audit", "审计日志接口"),
|
||||
("app.infra.repositories.user_repository", "用户仓储"),
|
||||
("app.infra.repositories.audit_repository", "审计仓储"),
|
||||
("app.infra.audit.middleware", "审计中间件"),
|
||||
]
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for module_name, desc in modules:
|
||||
try:
|
||||
__import__(module_name)
|
||||
print(f"✓ {desc:20s} ({module_name})")
|
||||
success += 1
|
||||
except Exception as e:
|
||||
print(f"✗ {desc:20s} ({module_name})")
|
||||
print(f" 错误: {e}")
|
||||
failed += 1
|
||||
|
||||
print(f"\n结果: {success} 成功, {failed} 失败")
|
||||
print()
|
||||
return failed == 0
|
||||
|
||||
def check_router():
|
||||
"""检查路由配置"""
|
||||
print("=" * 60)
|
||||
print("步骤 2: 检查路由配置")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from app.api.v1 import router
|
||||
from app.api.v1.endpoints import auth, user_management, audit
|
||||
|
||||
print("✓ router 模块已导入")
|
||||
print("✓ auth 端点已导入")
|
||||
print("✓ user_management 端点已导入")
|
||||
print("✓ audit 端点已导入")
|
||||
|
||||
# 检查 router 中是否包含新增的路由
|
||||
api_router = router.api_router
|
||||
print(f"\n已注册的路由数量: {len(api_router.routes)}")
|
||||
|
||||
# 查找新增的路由
|
||||
auth_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/auth' in r.path]
|
||||
user_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/users' in r.path]
|
||||
audit_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/audit' in r.path]
|
||||
|
||||
print(f"认证相关路由: {len(auth_routes)} 个")
|
||||
print(f"用户管理路由: {len(user_routes)} 个")
|
||||
print(f"审计日志路由: {len(audit_routes)} 个")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ 路由配置检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def check_main_app():
|
||||
"""检查 main.py 配置"""
|
||||
print("\n" + "=" * 60)
|
||||
print("步骤 3: 检查 main.py 配置")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from app.main import app
|
||||
|
||||
print("✓ FastAPI app 已创建")
|
||||
print(f" 标题: {app.title}")
|
||||
print(f" 版本: {app.version}")
|
||||
|
||||
# 检查中间件
|
||||
middleware_names = [m.__class__.__name__ for m in app.user_middleware]
|
||||
print(f"\n已注册的中间件: {len(middleware_names)} 个")
|
||||
for name in middleware_names:
|
||||
print(f" - {name}")
|
||||
|
||||
# 检查路由
|
||||
print(f"\n已注册的路由: {len(app.routes)} 个")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ main.py 配置检查失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def main():
|
||||
print("\n🔍 TJWater Server API 集成测试\n")
|
||||
|
||||
results = []
|
||||
|
||||
# 测试 1: 模块导入
|
||||
results.append(("模块导入", check_imports()))
|
||||
|
||||
# 测试 2: 路由配置
|
||||
results.append(("路由配置", check_router()))
|
||||
|
||||
# 测试 3: main.py
|
||||
results.append(("main.py配置", check_main_app()))
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 60)
|
||||
print("测试总结")
|
||||
print("=" * 60)
|
||||
|
||||
for name, success in results:
|
||||
status = "✓ 通过" if success else "✗ 失败"
|
||||
print(f"{status:8s} - {name}")
|
||||
|
||||
all_passed = all(success for _, success in results)
|
||||
|
||||
if all_passed:
|
||||
print("\n✅ 所有测试通过!")
|
||||
print("\n下一步:")
|
||||
print(" 1. 确保数据库迁移已执行")
|
||||
print(" 2. 配置 .env 文件")
|
||||
print(" 3. 启动服务: uvicorn app.main:app --reload")
|
||||
print(" 4. 访问文档: http://localhost:8000/docs")
|
||||
return 0
|
||||
else:
|
||||
print("\n❌ 部分测试失败,请检查错误信息")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
77
tests/api/test_api_integration.py
Executable file
77
tests/api/test_api_integration.py
Executable file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
测试新增 API 集成
|
||||
|
||||
验证新的认证、用户管理和审计日志接口是否正确集成
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import pytest
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"module_name, desc",
|
||||
[
|
||||
("app.core.encryption", "加密模块"),
|
||||
("app.core.security", "安全模块"),
|
||||
("app.core.audit", "审计模块"),
|
||||
("app.domain.models.role", "角色模型"),
|
||||
("app.domain.schemas.user", "用户Schema"),
|
||||
("app.domain.schemas.audit", "审计Schema"),
|
||||
("app.auth.permissions", "权限控制"),
|
||||
("app.api.v1.endpoints.auth", "认证接口"),
|
||||
("app.api.v1.endpoints.user_management", "用户管理接口"),
|
||||
("app.api.v1.endpoints.audit", "审计日志接口"),
|
||||
("app.infra.repositories.user_repository", "用户仓储"),
|
||||
("app.infra.repositories.audit_repository", "审计仓储"),
|
||||
("app.infra.audit.middleware", "审计中间件"),
|
||||
],
|
||||
)
|
||||
def test_module_imports(module_name, desc):
|
||||
"""检查关键模块是否可以导入"""
|
||||
try:
|
||||
__import__(module_name)
|
||||
except ImportError as e:
|
||||
pytest.fail(f"无法导入 {desc} ({module_name}): {e}")
|
||||
|
||||
|
||||
def test_router_configuration():
|
||||
"""检查路由配置"""
|
||||
try:
|
||||
from app.api.v1 import router
|
||||
|
||||
# 检查 router 中是否包含新增的路由
|
||||
api_router = router.api_router
|
||||
routes = [r.path for r in api_router.routes if hasattr(r, "path")]
|
||||
|
||||
# 验证基础路径是否存在
|
||||
assert any("/auth" in r for r in routes), "缺少认证相关路由 (/auth)"
|
||||
assert any("/users" in r for r in routes), "缺少用户管理路由 (/users)"
|
||||
assert any("/audit" in r for r in routes), "缺少审计日志路由 (/audit)"
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"路由配置检查失败: {e}")
|
||||
|
||||
|
||||
def test_main_app_initialization():
|
||||
"""检查 main.py 配置"""
|
||||
try:
|
||||
from app.main import app
|
||||
|
||||
assert app is not None
|
||||
assert app.title != ""
|
||||
|
||||
# 检查中间件 (简单检查是否存在)
|
||||
middleware_names = [m.cls.__name__ for m in app.user_middleware]
|
||||
# 检查是否包含审计中间件或其他关键中间件(根据实际类名修改)
|
||||
assert "AuditMiddleware" in middleware_names, "缺少审计中间件"
|
||||
|
||||
# 检查路由总数
|
||||
assert len(app.routes) > 0
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"main.py 配置检查失败: {e}")
|
||||
@@ -1,9 +1,12 @@
|
||||
"""
|
||||
测试加密功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
|
||||
def test_encryption():
|
||||
"""测试加密和解密功能"""
|
||||
@@ -33,5 +36,6 @@ def test_encryption():
|
||||
|
||||
print("\n✅ 所有加密测试通过!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_encryption()
|
||||
@@ -1,7 +1,11 @@
|
||||
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
"""
|
||||
tests.unit.test_pipeline_health_analyzer 的 Docstring
|
||||
"""
|
||||
|
||||
|
||||
def test_pipeline_health_analyzer():
|
||||
from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
|
||||
analyzer = PipelineHealthAnalyzer()
|
||||
# 创建示例输入数据(9个样本)
|
||||
@@ -51,7 +55,7 @@ def test_pipeline_health_analyzer():
|
||||
), "每个生存函数应包含x和y属性"
|
||||
|
||||
# 可选:测试绘图功能(不显示图表)
|
||||
analyzer.plot_survival(survival_functions, show_plot=True)
|
||||
analyzer.plot_survival(survival_functions, show_plot=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user