初步实现数据加密、权限管理、日志审计等功能
This commit is contained in:
189
app/infra/audit/middleware.py
Normal file
189
app/infra/audit/middleware.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
审计日志中间件
|
||||
|
||||
自动记录关键HTTP请求到审计日志
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
审计中间件
|
||||
|
||||
自动记录以下操作:
|
||||
- 所有 POST/PUT/DELETE 请求
|
||||
- 登录/登出
|
||||
- 关键资源访问
|
||||
"""
|
||||
|
||||
# 需要审计的路径前缀
|
||||
AUDIT_PATHS = [
|
||||
# "/api/v1/auth/",
|
||||
# "/api/v1/users/",
|
||||
# "/api/v1/projects/",
|
||||
# "/api/v1/networks/",
|
||||
]
|
||||
|
||||
# [新增] 需要审计的 API Tags (在 Router 或 api 函数中定义 tags=["Audit"])
|
||||
AUDIT_TAGS = [
|
||||
"Audit",
|
||||
"Users",
|
||||
"Project",
|
||||
"Network General",
|
||||
"Junctions",
|
||||
"Pipes",
|
||||
"Reservoirs",
|
||||
"Tanks",
|
||||
"Pumps",
|
||||
"Valves",
|
||||
]
|
||||
|
||||
# 需要审计的HTTP方法
|
||||
AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 提取开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 预判是否需要读取Body (针对写操作)
|
||||
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
|
||||
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
|
||||
|
||||
request_data = None
|
||||
if should_capture_body:
|
||||
try:
|
||||
# 注意:读取 body 后需要重新设置,避免影响后续处理
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_data = json.loads(body.decode())
|
||||
|
||||
# 重新构造请求以供后续使用
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
|
||||
request._receive = receive
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read request body for audit: {e}")
|
||||
|
||||
# 2. 执行请求 (FastAPI在此过程中进行路由匹配)
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 决定是否审计
|
||||
# 检查方法
|
||||
is_audit_method = request.method in self.AUDIT_METHODS
|
||||
# 检查路径
|
||||
is_audit_path = any(
|
||||
request.url.path.startswith(path) for path in self.AUDIT_PATHS
|
||||
)
|
||||
# [新增] 检查 Tags (从 request.scope 中获取匹配的路由信息)
|
||||
is_audit_tag = False
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "tags"):
|
||||
is_audit_tag = any(tag in self.AUDIT_TAGS for tag in route.tags)
|
||||
|
||||
should_audit = is_audit_method or is_audit_path or is_audit_tag
|
||||
|
||||
if not should_audit:
|
||||
# 即便不审计,也要处理响应头中的时间(保持原有逻辑一致性)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = None
|
||||
username = None
|
||||
|
||||
# 尝试从请求状态获取当前用户
|
||||
if hasattr(request.state, "user"):
|
||||
user = request.state.user
|
||||
user_id = getattr(user, "id", None)
|
||||
username = getattr(user, "username", None)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
resource_type, resource_id = self._extract_resource_info(request)
|
||||
|
||||
# 记录审计日志
|
||||
try:
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
error_message=(
|
||||
None
|
||||
if response.status_code < 400
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
logger.error(f"Failed to log audit event: {e}", exc_info=True)
|
||||
|
||||
# 添加处理时间到响应头
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
method = request.method
|
||||
|
||||
# 认证相关
|
||||
if "login" in path:
|
||||
return AuditAction.LOGIN
|
||||
elif "logout" in path:
|
||||
return AuditAction.LOGOUT
|
||||
elif "register" in path:
|
||||
return AuditAction.REGISTER
|
||||
|
||||
# CRUD 操作
|
||||
if method == "POST":
|
||||
return AuditAction.CREATE
|
||||
elif method == "PUT" or method == "PATCH":
|
||||
return AuditAction.UPDATE
|
||||
elif method == "DELETE":
|
||||
return AuditAction.DELETE
|
||||
elif method == "GET":
|
||||
return AuditAction.READ
|
||||
|
||||
return f"{method}_REQUEST"
|
||||
|
||||
def _extract_resource_info(self, request: Request) -> tuple:
|
||||
"""从请求路径提取资源类型和ID"""
|
||||
path_parts = request.url.path.strip("/").split("/")
|
||||
|
||||
resource_type = None
|
||||
resource_id = None
|
||||
|
||||
# 尝试从路径中提取资源信息
|
||||
# 例如: /api/v1/users/123 -> resource_type=user, resource_id=123
|
||||
if len(path_parts) >= 4:
|
||||
resource_type = path_parts[3].rstrip("s") # 移除复数s
|
||||
|
||||
if len(path_parts) >= 5 and path_parts[4].isdigit():
|
||||
resource_id = path_parts[4]
|
||||
|
||||
return resource_type, resource_id
|
||||
Reference in New Issue
Block a user