测试并修复api导入路径错误

This commit is contained in:
2026-02-02 11:09:43 +08:00
parent 807e634318
commit 35abaa1ebb
13 changed files with 211 additions and 250 deletions

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Request, Depends
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.api.v1.endpoints.auth import verify_token
from app.auth.dependencies import get_current_user as verify_token
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
import msgpack

View File

@@ -43,7 +43,7 @@ from app.infra.db.timescaledb import router as timescaledb_router
api_router = APIRouter()
# Core Services
api_router.include_router(auth.router, tags=["Auth"])
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
api_router.include_router(project.router, tags=["Project"])

View File

@@ -9,31 +9,34 @@ from app.infra.db.postgresql.database import Database
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
# 数据库依赖
async def get_db(request: Request) -> Database:
"""
获取数据库实例
从 FastAPI app.state 中获取在启动时初始化的数据库连接
"""
if not hasattr(request.app.state, "db"):
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database not initialized"
detail="Database not initialized",
)
return request.app.state.db
async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository:
"""获取用户仓储实例"""
return UserRepository(db)
async def get_current_user(
token: str = Depends(oauth2_scheme),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> UserInDB:
"""
获取当前登录用户
从 JWT Token 中解析用户信息,并从数据库验证
"""
credentials_exception = HTTPException(
@@ -41,32 +44,35 @@ async def get_current_user(
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
token_type: str = payload.get("type", "access")
if username is None:
raise credentials_exception
if token_type != "access":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type. Access token required.",
headers={"WWW-Authenticate": "Bearer"},
)
except JWTError:
raise credentials_exception
# 从数据库获取用户
user = await user_repo.get_user_by_username(username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: UserInDB = Depends(get_current_user),
) -> UserInDB:
@@ -75,11 +81,11 @@ async def get_current_active_user(
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return current_user
async def get_current_superuser(
current_user: UserInDB = Depends(get_current_user),
) -> UserInDB:
@@ -89,6 +95,6 @@ async def get_current_superuser(
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough privileges. Superuser access required."
detail="Not enough privileges. Superuser access required.",
)
return current_user

View File

@@ -32,5 +32,6 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"
extra = "ignore"
settings = Settings()

View File

@@ -1,77 +1,90 @@
from datetime import datetime, timedelta
from typing import Optional, Union, Any
from jose import jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
def create_access_token(
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
"""
创建 JWT Access Token
Args:
subject: 用户标识通常是用户名或用户ID
expires_delta: 过期时间增量
Returns:
JWT token 字符串
"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
expire = datetime.now() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {
"exp": expire,
"sub": str(subject),
"type": "access",
"iat": datetime.utcnow()
"iat": datetime.now(),
}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(subject: Union[str, Any]) -> str:
"""
创建 JWT Refresh Token长期有效
Args:
subject: 用户标识
Returns:
JWT refresh token 字符串
"""
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"exp": expire,
"sub": str(subject),
"type": "refresh",
"iat": datetime.utcnow()
"iat": datetime.now(),
}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
验证密码
Args:
plain_password: 明文密码
hashed_password: 密码哈希
Returns:
是否匹配
"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""
生成密码哈希
Args:
password: 明文密码
Returns:
bcrypt 哈希字符串
"""

View File

@@ -13,8 +13,8 @@ from typing import List, Dict
from datetime import datetime, timedelta, timezone
from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS
from dateutil import parser
import get_realValue
import get_data
# import get_realValue
# import get_data
import psycopg
import time
import app.services.simulation as simulation

View File

@@ -71,7 +71,7 @@ app.add_middleware(GZipMiddleware, minimum_size=1000)
# 添加审计中间件(可选,记录关键操作)
# 如果需要启用审计日志,取消下面的注释
# app.add_middleware(AuditMiddleware)
app.add_middleware(AuditMiddleware)
# Include Routers
app.include_router(api_router, prefix="/api/v1")

View File

@@ -7,18 +7,19 @@ import csv
# get_data 是用来获取 历史数据,也就是非实时数据的接口
# get_realtime 是用来获取 实时数据
def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
# 将毫秒级时间戳转换为秒级时间戳
timestamp_seconds = timestamp / 1000
# 将时间戳转换为datetime对象
utc_time = datetime.utcfromtimestamp(timestamp_seconds)
utc_time = datetime.fromtimestamp(timestamp_seconds)
# 设定UTC时区
utc_timezone = pytz.timezone('UTC')
utc_timezone = pytz.timezone("UTC")
# 转换为北京时间
beijing_timezone = pytz.timezone('Asia/Shanghai')
beijing_timezone = pytz.timezone("Asia/Shanghai")
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
return beijing_time
@@ -26,10 +27,10 @@ def convert_timestamp_to_beijing_time(timestamp: Union[int, float]) -> datetime:
def beijing_time_to_utc(beijing_time_str: str) -> str:
# 定义北京时区
beijing_timezone = pytz.timezone('Asia/Shanghai')
beijing_timezone = pytz.timezone("Asia/Shanghai")
# 将字符串转换为datetime对象
beijing_time = datetime.strptime(beijing_time_str, '%Y-%m-%d %H:%M:%S')
beijing_time = datetime.strptime(beijing_time_str, "%Y-%m-%d %H:%M:%S")
# 本地化时间对象
beijing_time = beijing_timezone.localize(beijing_time)
@@ -38,29 +39,31 @@ def beijing_time_to_utc(beijing_time_str: str) -> str:
utc_time = beijing_time.astimezone(pytz.utc)
# 转换为ISO 8601格式的字符串
return utc_time.strftime('%Y-%m-%dT%H:%M:%SZ')
return utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> List[Dict[str, Union[str, datetime, int, float]]]:
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None:
def get_history_data(
ids: str, begin_date: str, end_date: str, downsample: Optional[str]
) -> List[Dict[str, Union[str, datetime, int, float]]]:
# def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optional[str]) -> None:
# 转换输入的北京时间为UTC时间
begin_date_utc = beijing_time_to_utc(begin_date)
end_date_utc = beijing_time_to_utc(end_date)
# 数据接口的地址
url = 'http://183.64.62.100:9057/loong/api/curves/data'
url = "http://183.64.62.100:9057/loong/api/curves/data"
# url = 'http://10.101.15.16:9000/loong/api/curves/data'
# url_path = 'http://10.101.15.16:9000/loong' # 内网
# 设置 GET 请求的参数
params = {
'ids': ids,
'beginDate': begin_date_utc,
'endDate': end_date_utc,
'downsample': downsample
"ids": ids,
"beginDate": begin_date_utc,
"endDate": end_date_utc,
"downsample": downsample,
}
history_data_list =[]
history_data_list = []
try:
# 发送 GET 请求获取数据
@@ -73,24 +76,26 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio
# 这里可以对获取到的数据进行进一步处理
# 打印 'mpointId' 和 'mpointName'
for item in data['items']:
mpoint_id = str(item['mpointId'])
mpoint_name = item['mpointName']
for item in data["items"]:
mpoint_id = str(item["mpointId"])
mpoint_name = item["mpointName"]
# print("mpointId:", item['mpointId'])
# print("mpointName:", item['mpointName'])
# 打印 'dataDate' 和 'dataValue'
for item_data in item['data']:
for item_data in item["data"]:
# 将时间戳转换为北京时间
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate'])
data_value = item_data['dataValue']
beijing_time = convert_timestamp_to_beijing_time(
item_data["dataDate"]
)
data_value = item_data["dataValue"]
# 创建一个字典存储每条数据
data_dict = {
'time': beijing_time,
'device_ID': str(mpoint_id),
'description': mpoint_name,
"time": beijing_time,
"device_ID": str(mpoint_id),
"description": mpoint_name,
# 'dataDate (Beijing Time)': beijing_time.strftime('%Y-%m-%d %H:%M:%S'),
'monitored_value': data_value # 保留原有类型
"monitored_value": data_value, # 保留原有类型
}
history_data_list.append(data_dict)
@@ -164,4 +169,4 @@ def get_history_data(ids: str, begin_date: str, end_date: str, downsample: Optio
# for data in data_list1:
# writer.writerow([data['measurement'], data['mpointId'], data['date'], data['dataValue'], data['datetime']])
#
# print(f"筛选后的数据已保存到 {filtered_csv_file_path}")
# print(f"筛选后的数据已保存到 {filtered_csv_file_path}")

View File

@@ -8,35 +8,33 @@ def convert_timestamp_to_beijing_time(timestamp):
timestamp_seconds = timestamp / 1000
# 将时间戳转换为datetime对象
utc_time = datetime.utcfromtimestamp(timestamp_seconds)
utc_time = datetime.fromtimestamp(timestamp_seconds)
# 设定UTC时区
utc_timezone = pytz.timezone('UTC')
utc_timezone = pytz.timezone("UTC")
# 转换为北京时间
beijing_timezone = pytz.timezone('Asia/Shanghai')
beijing_timezone = pytz.timezone("Asia/Shanghai")
beijing_time = utc_time.replace(tzinfo=utc_timezone).astimezone(beijing_timezone)
return beijing_time
def conver_beingtime_to_ucttime(timestr:str):
beijing_time=datetime.strptime(timestr,'%Y-%m-%d %H:%M:%S')
utc_time=beijing_time.astimezone(pytz.utc)
str_utc=utc_time.strftime('%Y-%m-%dT%H:%M:%SZ')
#print(str_utc)
def conver_beingtime_to_ucttime(timestr: str):
beijing_time = datetime.strptime(timestr, "%Y-%m-%d %H:%M:%S")
utc_time = beijing_time.astimezone(pytz.utc)
str_utc = utc_time.strftime("%Y-%m-%dT%H:%M:%SZ")
# print(str_utc)
return str_utc
def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
def get_hist_data(ids, begin_date, end_date) -> dict[str, dict[datetime, float]]:
# 数据接口的地址
url = 'http://183.64.62.100:9057/loong/api/curves/data'
url = "http://183.64.62.100:9057/loong/api/curves/data"
# 设置 GET 请求的参数
params = {
'ids': ids,
'beginDate': begin_date,
'endDate': end_date
}
lst_data={}
params = {"ids": ids, "beginDate": begin_date, "endDate": end_date}
lst_data = {}
try:
# 发送 GET 请求获取数据
response = requests.get(url, params=params)
@@ -48,22 +46,27 @@ def get_hist_data(ids, begin_date,end_date)->dict[str,dict[datetime,float]]:
# 这里可以对获取到的数据进行进一步处理
# 打印 'mpointId' 和 'mpointName'
for item in data['items']:
#print("mpointId:", item['mpointId'])
#print("mpointName:", item['mpointName'])
for item in data["items"]:
# print("mpointId:", item['mpointId'])
# print("mpointName:", item['mpointName'])
# 打印 'dataDate' 和 'dataValue'
data_seriers={}
for item_data in item['data']:
data_seriers = {}
for item_data in item["data"]:
# print("dataDate:", item_data['dataDate'])
# 将时间戳转换为北京时间
beijing_time = convert_timestamp_to_beijing_time(item_data['dataDate'])
print("dataDate (Beijing Time):", beijing_time.strftime('%Y-%m-%d %H:%M:%S'))
print("dataValue:", item_data['dataValue'])
beijing_time = convert_timestamp_to_beijing_time(
item_data["dataDate"]
)
print(
"dataDate (Beijing Time):",
beijing_time.strftime("%Y-%m-%d %H:%M:%S"),
)
print("dataValue:", item_data["dataValue"])
print() # 打印空行分隔不同条目
r=float(item_data['dataValue'])
data_seriers[beijing_time]=r
lst_data[item['mpointId']]=data_seriers
r = float(item_data["dataValue"])
data_seriers[beijing_time] = r
lst_data[item["mpointId"]] = data_seriers
return lst_data
else:
# 如果请求不成功,打印错误信息

View File

@@ -1,152 +0,0 @@
#!/usr/bin/env python
"""
测试新增 API 集成
验证新的认证、用户管理和审计日志接口是否正确集成
"""
import sys
import subprocess
import time
def check_imports():
"""检查关键模块是否可以导入"""
print("=" * 60)
print("步骤 1: 检查模块导入")
print("=" * 60)
modules = [
("app.core.encryption", "加密模块"),
("app.core.security", "安全模块"),
("app.core.audit", "审计模块"),
("app.domain.models.role", "角色模型"),
("app.domain.schemas.user", "用户Schema"),
("app.domain.schemas.audit", "审计Schema"),
("app.auth.permissions", "权限控制"),
("app.api.v1.endpoints.auth", "认证接口"),
("app.api.v1.endpoints.user_management", "用户管理接口"),
("app.api.v1.endpoints.audit", "审计日志接口"),
("app.infra.repositories.user_repository", "用户仓储"),
("app.infra.repositories.audit_repository", "审计仓储"),
("app.infra.audit.middleware", "审计中间件"),
]
success = 0
failed = 0
for module_name, desc in modules:
try:
__import__(module_name)
print(f"{desc:20s} ({module_name})")
success += 1
except Exception as e:
print(f"{desc:20s} ({module_name})")
print(f" 错误: {e}")
failed += 1
print(f"\n结果: {success} 成功, {failed} 失败")
print()
return failed == 0
def check_router():
"""检查路由配置"""
print("=" * 60)
print("步骤 2: 检查路由配置")
print("=" * 60)
try:
from app.api.v1 import router
from app.api.v1.endpoints import auth, user_management, audit
print("✓ router 模块已导入")
print("✓ auth 端点已导入")
print("✓ user_management 端点已导入")
print("✓ audit 端点已导入")
# 检查 router 中是否包含新增的路由
api_router = router.api_router
print(f"\n已注册的路由数量: {len(api_router.routes)}")
# 查找新增的路由
auth_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/auth' in r.path]
user_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/users' in r.path]
audit_routes = [r for r in api_router.routes if hasattr(r, 'path') and '/audit' in r.path]
print(f"认证相关路由: {len(auth_routes)}")
print(f"用户管理路由: {len(user_routes)}")
print(f"审计日志路由: {len(audit_routes)}")
return True
except Exception as e:
print(f"✗ 路由配置检查失败: {e}")
import traceback
traceback.print_exc()
return False
def check_main_app():
"""检查 main.py 配置"""
print("\n" + "=" * 60)
print("步骤 3: 检查 main.py 配置")
print("=" * 60)
try:
from app.main import app
print("✓ FastAPI app 已创建")
print(f" 标题: {app.title}")
print(f" 版本: {app.version}")
# 检查中间件
middleware_names = [m.__class__.__name__ for m in app.user_middleware]
print(f"\n已注册的中间件: {len(middleware_names)}")
for name in middleware_names:
print(f" - {name}")
# 检查路由
print(f"\n已注册的路由: {len(app.routes)}")
return True
except Exception as e:
print(f"✗ main.py 配置检查失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
print("\n🔍 TJWater Server API 集成测试\n")
results = []
# 测试 1: 模块导入
results.append(("模块导入", check_imports()))
# 测试 2: 路由配置
results.append(("路由配置", check_router()))
# 测试 3: main.py
results.append(("main.py配置", check_main_app()))
# 总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
for name, success in results:
status = "✓ 通过" if success else "✗ 失败"
print(f"{status:8s} - {name}")
all_passed = all(success for _, success in results)
if all_passed:
print("\n✅ 所有测试通过!")
print("\n下一步:")
print(" 1. 确保数据库迁移已执行")
print(" 2. 配置 .env 文件")
print(" 3. 启动服务: uvicorn app.main:app --reload")
print(" 4. 访问文档: http://localhost:8000/docs")
return 0
else:
print("\n❌ 部分测试失败,请检查错误信息")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python
"""
测试新增 API 集成
验证新的认证、用户管理和审计日志接口是否正确集成
"""
import sys
import os
import pytest
# 将项目根目录添加到 sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
@pytest.mark.parametrize(
"module_name, desc",
[
("app.core.encryption", "加密模块"),
("app.core.security", "安全模块"),
("app.core.audit", "审计模块"),
("app.domain.models.role", "角色模型"),
("app.domain.schemas.user", "用户Schema"),
("app.domain.schemas.audit", "审计Schema"),
("app.auth.permissions", "权限控制"),
("app.api.v1.endpoints.auth", "认证接口"),
("app.api.v1.endpoints.user_management", "用户管理接口"),
("app.api.v1.endpoints.audit", "审计日志接口"),
("app.infra.repositories.user_repository", "用户仓储"),
("app.infra.repositories.audit_repository", "审计仓储"),
("app.infra.audit.middleware", "审计中间件"),
],
)
def test_module_imports(module_name, desc):
"""检查关键模块是否可以导入"""
try:
__import__(module_name)
except ImportError as e:
pytest.fail(f"无法导入 {desc} ({module_name}): {e}")
def test_router_configuration():
"""检查路由配置"""
try:
from app.api.v1 import router
# 检查 router 中是否包含新增的路由
api_router = router.api_router
routes = [r.path for r in api_router.routes if hasattr(r, "path")]
# 验证基础路径是否存在
assert any("/auth" in r for r in routes), "缺少认证相关路由 (/auth)"
assert any("/users" in r for r in routes), "缺少用户管理路由 (/users)"
assert any("/audit" in r for r in routes), "缺少审计日志路由 (/audit)"
except Exception as e:
pytest.fail(f"路由配置检查失败: {e}")
def test_main_app_initialization():
"""检查 main.py 配置"""
try:
from app.main import app
assert app is not None
assert app.title != ""
# 检查中间件 (简单检查是否存在)
middleware_names = [m.cls.__name__ for m in app.user_middleware]
# 检查是否包含审计中间件或其他关键中间件(根据实际类名修改)
assert "AuditMiddleware" in middleware_names, "缺少审计中间件"
# 检查路由总数
assert len(app.routes) > 0
except Exception as e:
pytest.fail(f"main.py 配置检查失败: {e}")

View File

@@ -1,37 +1,41 @@
"""
测试加密功能
"""
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
def test_encryption():
"""测试加密和解密功能"""
from app.core.encryption import Encryptor
# 生成测试密钥
key = Encryptor.generate_key()
print(f"✓ 生成密钥: {key}")
# 创建加密器
encryptor = Encryptor(key=key.encode())
# 测试加密
test_data = "这是敏感数据 - 数据库密码: password123"
encrypted = encryptor.encrypt(test_data)
print(f"✓ 加密成功: {encrypted[:50]}...")
# 测试解密
decrypted = encryptor.decrypt(encrypted)
assert decrypted == test_data, "解密数据不匹配!"
print(f"✓ 解密成功: {decrypted}")
# 测试空数据
assert encryptor.encrypt("") == ""
assert encryptor.decrypt("") == ""
print("✓ 空数据处理正确")
print("\n✅ 所有加密测试通过!")
if __name__ == "__main__":
test_encryption()

View File

@@ -1,7 +1,11 @@
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
"""
tests.unit.test_pipeline_health_analyzer 的 Docstring
"""
def test_pipeline_health_analyzer():
from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
analyzer = PipelineHealthAnalyzer()
# 创建示例输入数据9个样本
@@ -51,7 +55,7 @@ def test_pipeline_health_analyzer():
), "每个生存函数应包含x和y属性"
# 可选:测试绘图功能(不显示图表)
analyzer.plot_survival(survival_functions, show_plot=True)
analyzer.plot_survival(survival_functions, show_plot=False)
if __name__ == "__main__":