"""
Audit log middleware.

Automatically records key HTTP requests to the audit log.
"""

import json
import logging
import time
from typing import Callable
from uuid import UUID

from fastapi import Request, Response
from jose import JWTError, jwt
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.audit import AuditAction, log_audit_event
from app.core.config import settings
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.metadata_repository import MetadataRepository

logger = logging.getLogger(__name__)


class AuditMiddleware(BaseHTTPMiddleware):
    """
    Audit middleware.

    Automatically records the following operations:
    - all POST/PUT/PATCH/DELETE requests
    - login/logout
    - access to key resources (routes tagged with one of AUDIT_TAGS)
    """

    # Path prefixes to audit (all entries are currently commented out)
    AUDIT_PATHS = [
        # "/api/v1/auth/",
        # "/api/v1/users/",
        # "/api/v1/projects/",
        # "/api/v1/networks/",
    ]

    # API tags to audit (declare them on the router or on the endpoint
    # function, e.g. tags=["Audit"])
    AUDIT_TAGS = [
        "Audit",
        "Users",
        "Project",
        "Network General",
        "Junctions",
        "Pipes",
        "Reservoirs",
        "Tanks",
        "Pumps",
        "Valves",
    ]

    # HTTP methods to audit
    AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Record the start time
        start_time = time.perf_counter()

        # 1. Decide up front whether to capture the body (write operations).
        # There is deliberately no early return here: route tags can only be
        # checked after the route has been matched.
        should_capture_body = request.method in ["POST", "PUT", "PATCH"]
        request_data = None

        if should_capture_body:
            try:
                # Reading the body consumes the ASGI receive channel, so the
                # body is replayed for the downstream handler.
                body = await request.body()

                async def receive():
                    return {"type": "http.request", "body": body, "more_body": False}

                # Restore receive *before* parsing, so a non-JSON body
                # (form data, file uploads) still reaches the endpoint intact.
                request._receive = receive

                if body:
                    request_data = json.loads(body.decode())
            except Exception as e:
                logger.warning(f"Failed to capture request body for audit: {e}")

        # 2. Run the request (FastAPI matches the route during this call)
        response = await call_next(request)

        # 3. Decide whether to audit
        # Check the method
        is_audit_method = request.method in self.AUDIT_METHODS
        # Check the path
        is_audit_path = any(
            request.url.path.startswith(path) for path in self.AUDIT_PATHS
        )
        # Check the tags (FastAPI stores the matched route in request.scope)
        route = request.scope.get("route")
        route_tags = getattr(route, "tags", None) or []
        is_audit_tag = any(tag in self.AUDIT_TAGS for tag in route_tags)

        should_audit = is_audit_method or is_audit_path or is_audit_tag

        if not should_audit:
            # Set the processing-time header even when not auditing, so every
            # response carries it
            process_time = time.perf_counter() - start_time
            response.headers["X-Process-Time"] = str(process_time)
            return response

        try:
            # 4. Collect the information for the audit record. This runs inside
            # the try block because resolving the user queries the database,
            # and an audit failure must never affect the response.
            user_id = await self._resolve_user_id(request)
            project_id = self._resolve_project_id(request)

            # Client information
            ip_address = request.client.host if request.client else None

            # Determine the action type and the affected resource
            action = self._determine_action(request)
            resource_type, resource_id = self._extract_resource_info(request)

            # Write the audit record
            await log_audit_event(
                action=action,
                user_id=user_id,
                project_id=project_id,
                resource_type=resource_type,
                resource_id=resource_id,
                ip_address=ip_address,
                request_method=request.method,
                request_path=str(request.url.path),
                request_data=request_data,
                response_status=response.status_code,
            )
        except Exception:
            logger.exception("Failed to log audit event")

        # Add the processing time to the response headers
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

    def _resolve_project_id(self, request: Request) -> UUID | None:
        """Parse the X-Project-Id header; return None if it is missing or invalid."""
        project_header = request.headers.get("X-Project-Id")
        if not project_header:
            return None
        try:
            return UUID(project_header)
        except ValueError:
            return None

    async def _resolve_user_id(self, request: Request) -> UUID | None:
        """Map the Bearer token's subject (Keycloak ID) to an active local user ID."""
        auth_header = request.headers.get("authorization")
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None
        token = auth_header.split(" ", 1)[1].strip()
        if not token:
            return None

        # Prefer the Keycloak public key; fall back to the local secret key
        if settings.KEYCLOAK_PUBLIC_KEY:
            key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
            algorithms = [settings.KEYCLOAK_ALGORITHM]
        else:
            key = settings.SECRET_KEY
            algorithms = [settings.ALGORITHM]

        try:
            # Note: python-jose rejects a token that carries an "aud" claim
            # when no audience= argument is passed, so such tokens resolve
            # to None here.
            payload = jwt.decode(token, key, algorithms=algorithms)
            sub = payload.get("sub")
            if not sub:
                return None
            keycloak_id = UUID(sub)
        except (JWTError, ValueError):
            return None

        async with SessionLocal() as session:
            repo = MetadataRepository(session)
            user = await repo.get_user_by_keycloak_id(keycloak_id)
            if user and user.is_active:
                return user.id
        return None

    def _determine_action(self, request: Request) -> str:
        """Determine the action type from the request path and method."""
        path = request.url.path.lower()
        method = request.method

        # Authentication
        if "login" in path:
            return AuditAction.LOGIN
        if "logout" in path:
            return AuditAction.LOGOUT
        if "register" in path:
            return AuditAction.REGISTER

        # CRUD operations
        if method == "POST":
            return AuditAction.CREATE
        if method in ("PUT", "PATCH"):
            return AuditAction.UPDATE
        if method == "DELETE":
            return AuditAction.DELETE
        if method == "GET":
            return AuditAction.READ

        return f"{method}_REQUEST"

    @staticmethod
    def _is_resource_id(segment: str) -> bool:
        """Treat numeric path segments and UUIDs as resource IDs."""
        if segment.isdigit():
            return True
        try:
            UUID(segment)
        except ValueError:
            return False
        return True

    def _extract_resource_info(self, request: Request) -> tuple[str | None, str | None]:
        """Extract the resource type and ID from the request path."""
        # /api/v1/users/123 -> ["api", "v1", "users", "123"]
        #                   -> resource_type="user", resource_id="123"
        path_parts = request.url.path.strip("/").split("/")
        resource_type = None
        resource_id = None

        if len(path_parts) >= 3:
            # Strip one plural "s" (rstrip("s") would remove every trailing "s")
            resource_type = path_parts[2].removesuffix("s")
        if len(path_parts) >= 4 and self._is_resource_id(path_parts[3]):
            resource_id = path_parts[3]

        return resource_type, resource_id
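

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the import path, router prefix and endpoint
# below are hypothetical, not part of this module). Register the middleware on
# the FastAPI app and opt a router into auditing through one of AUDIT_TAGS:
#
#     from fastapi import APIRouter, FastAPI
#     from app.middleware.audit import AuditMiddleware  # hypothetical path
#
#     app = FastAPI()
#     app.add_middleware(AuditMiddleware)
#
#     router = APIRouter(prefix="/api/v1/pipes", tags=["Pipes"])
#
#     @router.get("/{pipe_id}")
#     async def get_pipe(pipe_id: int):
#         # GET is not in AUDIT_METHODS, but the "Pipes" tag makes the
#         # middleware audit this call once routing has matched the request.
#         ...
#
#     # Include the router after its routes are declared
#     app.include_router(router)
# ---------------------------------------------------------------------------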