初步实现数据加密、权限管理、日志审计等功能
This commit is contained in:
189
app/infra/audit/middleware.py
Normal file
189
app/infra/audit/middleware.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
审计日志中间件
|
||||
|
||||
自动记录关键HTTP请求到审计日志
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
审计中间件
|
||||
|
||||
自动记录以下操作:
|
||||
- 所有 POST/PUT/DELETE 请求
|
||||
- 登录/登出
|
||||
- 关键资源访问
|
||||
"""
|
||||
|
||||
# 需要审计的路径前缀
|
||||
AUDIT_PATHS = [
|
||||
# "/api/v1/auth/",
|
||||
# "/api/v1/users/",
|
||||
# "/api/v1/projects/",
|
||||
# "/api/v1/networks/",
|
||||
]
|
||||
|
||||
# [新增] 需要审计的 API Tags (在 Router 或 api 函数中定义 tags=["Audit"])
|
||||
AUDIT_TAGS = [
|
||||
"Audit",
|
||||
"Users",
|
||||
"Project",
|
||||
"Network General",
|
||||
"Junctions",
|
||||
"Pipes",
|
||||
"Reservoirs",
|
||||
"Tanks",
|
||||
"Pumps",
|
||||
"Valves",
|
||||
]
|
||||
|
||||
# 需要审计的HTTP方法
|
||||
AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 提取开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 预判是否需要读取Body (针对写操作)
|
||||
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
|
||||
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
|
||||
|
||||
request_data = None
|
||||
if should_capture_body:
|
||||
try:
|
||||
# 注意:读取 body 后需要重新设置,避免影响后续处理
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_data = json.loads(body.decode())
|
||||
|
||||
# 重新构造请求以供后续使用
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
|
||||
request._receive = receive
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read request body for audit: {e}")
|
||||
|
||||
# 2. 执行请求 (FastAPI在此过程中进行路由匹配)
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 决定是否审计
|
||||
# 检查方法
|
||||
is_audit_method = request.method in self.AUDIT_METHODS
|
||||
# 检查路径
|
||||
is_audit_path = any(
|
||||
request.url.path.startswith(path) for path in self.AUDIT_PATHS
|
||||
)
|
||||
# [新增] 检查 Tags (从 request.scope 中获取匹配的路由信息)
|
||||
is_audit_tag = False
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "tags"):
|
||||
is_audit_tag = any(tag in self.AUDIT_TAGS for tag in route.tags)
|
||||
|
||||
should_audit = is_audit_method or is_audit_path or is_audit_tag
|
||||
|
||||
if not should_audit:
|
||||
# 即便不审计,也要处理响应头中的时间(保持原有逻辑一致性)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = None
|
||||
username = None
|
||||
|
||||
# 尝试从请求状态获取当前用户
|
||||
if hasattr(request.state, "user"):
|
||||
user = request.state.user
|
||||
user_id = getattr(user, "id", None)
|
||||
username = getattr(user, "username", None)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
resource_type, resource_id = self._extract_resource_info(request)
|
||||
|
||||
# 记录审计日志
|
||||
try:
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
error_message=(
|
||||
None
|
||||
if response.status_code < 400
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
logger.error(f"Failed to log audit event: {e}", exc_info=True)
|
||||
|
||||
# 添加处理时间到响应头
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
method = request.method
|
||||
|
||||
# 认证相关
|
||||
if "login" in path:
|
||||
return AuditAction.LOGIN
|
||||
elif "logout" in path:
|
||||
return AuditAction.LOGOUT
|
||||
elif "register" in path:
|
||||
return AuditAction.REGISTER
|
||||
|
||||
# CRUD 操作
|
||||
if method == "POST":
|
||||
return AuditAction.CREATE
|
||||
elif method == "PUT" or method == "PATCH":
|
||||
return AuditAction.UPDATE
|
||||
elif method == "DELETE":
|
||||
return AuditAction.DELETE
|
||||
elif method == "GET":
|
||||
return AuditAction.READ
|
||||
|
||||
return f"{method}_REQUEST"
|
||||
|
||||
def _extract_resource_info(self, request: Request) -> tuple:
|
||||
"""从请求路径提取资源类型和ID"""
|
||||
path_parts = request.url.path.strip("/").split("/")
|
||||
|
||||
resource_type = None
|
||||
resource_id = None
|
||||
|
||||
# 尝试从路径中提取资源信息
|
||||
# 例如: /api/v1/users/123 -> resource_type=user, resource_id=123
|
||||
if len(path_parts) >= 4:
|
||||
resource_type = path_parts[3].rstrip("s") # 移除复数s
|
||||
|
||||
if len(path_parts) >= 5 and path_parts[4].isdigit():
|
||||
resource_id = path_parts[4]
|
||||
|
||||
return resource_type, resource_id
|
||||
220
app/infra/repositories/audit_repository.py
Normal file
220
app/infra/repositories/audit_repository.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: str = "",
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None
|
||||
) -> Optional[AuditLogResponse]:
|
||||
"""
|
||||
创建审计日志
|
||||
|
||||
Args:
|
||||
参数说明见 AuditLogCreate
|
||||
|
||||
Returns:
|
||||
创建的审计日志对象
|
||||
"""
|
||||
query = """
|
||||
INSERT INTO audit_logs (
|
||||
user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message
|
||||
)
|
||||
VALUES (
|
||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
||||
%(request_data)s, %(response_status)s, %(error_message)s
|
||||
)
|
||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'action': action,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'ip_address': ip_address,
|
||||
'user_agent': user_agent,
|
||||
'request_method': request_method,
|
||||
'request_path': request_path,
|
||||
'request_data': json.dumps(request_data) if request_data else None,
|
||||
'response_status': response_status,
|
||||
'error_message': error_message
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return AuditLogResponse(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating audit log: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志
|
||||
|
||||
Args:
|
||||
user_id: 用户ID过滤
|
||||
username: 用户名过滤
|
||||
action: 操作类型过滤
|
||||
resource_type: 资源类型过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
skip: 跳过记录数
|
||||
limit: 限制记录数
|
||||
|
||||
Returns:
|
||||
审计日志列表
|
||||
"""
|
||||
# 构建动态查询
|
||||
conditions = []
|
||||
params = {'skip': skip, 'limit': limit}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
rows = await cur.fetchall()
|
||||
return [AuditLogResponse(**row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying audit logs: {e}")
|
||||
raise
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
) -> int:
|
||||
"""
|
||||
获取审计日志数量
|
||||
|
||||
Args:
|
||||
参数同 get_logs
|
||||
|
||||
Returns:
|
||||
日志总数
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['count'] if result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting audit logs: {e}")
|
||||
return 0
|
||||
235
app/infra/repositories/user_repository.py
Normal file
235
app/infra/repositories/user_repository.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.user import UserCreate, UserUpdate, UserInDB
|
||||
from app.domain.models.role import UserRole
|
||||
from app.core.security import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserRepository:
|
||||
"""用户数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
async def create_user(self, user: UserCreate) -> Optional[UserInDB]:
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
Args:
|
||||
user: 用户创建数据
|
||||
|
||||
Returns:
|
||||
创建的用户对象
|
||||
"""
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
query = """
|
||||
INSERT INTO users (username, email, hashed_password, role, is_active, is_superuser)
|
||||
VALUES (%(username)s, %(email)s, %(hashed_password)s, %(role)s, TRUE, FALSE)
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'hashed_password': hashed_password,
|
||||
'role': user.role.value
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> Optional[UserInDB]:
|
||||
"""根据ID获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = %(user_id)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
||||
"""根据用户名获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = %(username)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'username': username})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
"""根据邮箱获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE email = %(email)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'email': email})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_all_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
|
||||
"""获取所有用户(分页)"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'skip': skip, 'limit': limit})
|
||||
rows = await cur.fetchall()
|
||||
return [UserInDB(**row) for row in rows]
|
||||
|
||||
async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[UserInDB]:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
user_update: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的用户对象
|
||||
"""
|
||||
# 构建动态更新语句
|
||||
update_fields = []
|
||||
params = {'user_id': user_id}
|
||||
|
||||
if user_update.email is not None:
|
||||
update_fields.append("email = %(email)s")
|
||||
params['email'] = user_update.email
|
||||
|
||||
if user_update.password is not None:
|
||||
update_fields.append("hashed_password = %(hashed_password)s")
|
||||
params['hashed_password'] = get_password_hash(user_update.password)
|
||||
|
||||
if user_update.role is not None:
|
||||
update_fields.append("role = %(role)s")
|
||||
params['role'] = user_update.role.value
|
||||
|
||||
if user_update.is_active is not None:
|
||||
update_fields.append("is_active = %(is_active)s")
|
||||
params['is_active'] = user_update.is_active
|
||||
|
||||
if not update_fields:
|
||||
return await self.get_user_by_id(user_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE users
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %(user_id)s
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def delete_user(self, user_id: int) -> bool:
|
||||
"""
|
||||
删除用户
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
query = "DELETE FROM users WHERE id = %(user_id)s"
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
return cur.rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def user_exists(self, username: str = None, email: str = None) -> bool:
|
||||
"""
|
||||
检查用户是否存在
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
email: 邮箱
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
if email:
|
||||
conditions.append("email = %(email)s")
|
||||
params['email'] = email
|
||||
|
||||
if not conditions:
|
||||
return False
|
||||
|
||||
query = f"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM users WHERE {' OR '.join(conditions)}
|
||||
)
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['exists'] if result else False
|
||||
Reference in New Issue
Block a user