225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
"""
|
||
审计日志中间件
|
||
|
||
自动记录关键HTTP请求到审计日志
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
from uuid import UUID
|
||
from typing import Callable
|
||
from fastapi import Request, Response
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from app.core.audit import log_audit_event, AuditAction
|
||
import logging
|
||
from jose import JWTError, jwt
|
||
|
||
from app.core.config import settings
|
||
from app.infra.db.metadata.database import SessionLocal
|
||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||
|
||
# Logger named after this module; used for audit-middleware diagnostics.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class AuditMiddleware(BaseHTTPMiddleware):
    """Audit middleware (审计中间件).

    Records an audit-log entry for:

    - every request whose method is in ``AUDIT_METHODS`` (POST/PUT/DELETE/PATCH);
    - every request whose path starts with a prefix in ``AUDIT_PATHS``;
    - every request whose matched route carries a tag in ``AUDIT_TAGS``,
      whatever its method (this is how reads of key resources get audited).

    JSON bodies of POST/PUT/PATCH requests are attached to the entry with
    credential-like fields masked. Every response, audited or not, receives an
    ``X-Process-Time`` header holding the handling time in seconds.
    """

    # Path prefixes that are always audited (currently none enabled).
    AUDIT_PATHS = [
        # "/api/v1/auth/",
        # "/api/v1/users/",
        # "/api/v1/projects/",
        # "/api/v1/networks/",
    ]

    # Route tags (declared via ``tags=[...]`` on a router or endpoint) whose
    # routes are audited for every HTTP method, including GET.
    AUDIT_TAGS = [
        "Audit",
        "Users",
        "Project",
        "Network General",
        "Junctions",
        "Pipes",
        "Reservoirs",
        "Tanks",
        "Pumps",
        "Valves",
    ]

    # HTTP methods that are always audited, whatever the route.
    AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

    # Methods whose JSON body is captured into the audit record.
    _BODY_METHODS = ("POST", "PUT", "PATCH")

    # Case-insensitive substrings of JSON keys whose values are masked before
    # the body is stored, so credentials (e.g. login passwords) never reach
    # the audit log.
    _SENSITIVE_KEY_PARTS = ("password", "secret", "token")
    _REDACTED = "***"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request, audit it if it qualifies, and stamp X-Process-Time."""
        start_time = time.perf_counter()

        # The body has to be captured before call_next: afterwards the
        # downstream app has consumed it. Whether to audit can only be decided
        # after routing (route tags are unknown until then), so capture first.
        request_data = None
        if request.method in self._BODY_METHODS:
            request_data = await self._capture_json_body(request)

        response = await call_next(request)

        if self._should_audit(request):
            await self._record(request, response, request_data)

        response.headers["X-Process-Time"] = str(time.perf_counter() - start_time)
        return response

    def _should_audit(self, request: Request) -> bool:
        """Return True if the request matches an audited method, path prefix or route tag."""
        if request.method in self.AUDIT_METHODS:
            return True
        path = request.url.path
        if any(path.startswith(prefix) for prefix in self.AUDIT_PATHS):
            return True
        # The router stores the matched route in the (shared) ASGI scope.
        route = request.scope.get("route")
        tags = getattr(route, "tags", None) or ()
        return any(tag in self.AUDIT_TAGS for tag in tags)

    async def _record(self, request: Request, response: Response, request_data) -> None:
        """Write one audit entry; failures are logged and never affect the response."""
        # User resolution decodes the JWT and queries the database. An error
        # there must not turn an already-produced response into a 500, and the
        # event is still worth recording without a user.
        try:
            user_id = await self._resolve_user_id(request)
        except Exception:
            logger.exception("Failed to resolve user for audit event")
            user_id = None

        project_id = self._resolve_project_id(request)
        ip_address = request.client.host if request.client else None
        action = self._determine_action(request)
        resource_type, resource_id = self._extract_resource_info(request)

        try:
            await log_audit_event(
                action=action,
                user_id=user_id,
                project_id=project_id,
                resource_type=resource_type,
                resource_id=resource_id,
                ip_address=ip_address,
                request_method=request.method,
                request_path=str(request.url.path),
                request_data=request_data,
                response_status=response.status_code,
            )
        except Exception as e:
            # Auditing is best-effort: never let it break the response.
            logger.exception("Failed to log audit event: %s", e)

    async def _capture_json_body(self, request: Request):
        """Read the body, make it readable again downstream, and decode it as JSON.

        Returns the decoded value with sensitive fields masked, or None when
        the body is not JSON, is empty, cannot be read, or fails to decode.
        """
        content_type = request.headers.get("content-type", "")
        # Only JSON payloads are captured; requests without a Content-Type are
        # still attempted. Other types (forms, file uploads) are left untouched
        # so they are neither buffered here nor at risk of not being replayed.
        if content_type and "json" not in content_type.lower():
            return None

        try:
            body = await request.body()
        except Exception as e:
            logger.warning("Failed to read request body for audit: %s", e)
            return None

        # Replay before parsing so a decode failure can never leave the
        # downstream app without its body.
        self._replay_body(request, body)

        if not body:
            return None
        try:
            data = json.loads(body)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            logger.debug("Request body is not valid JSON; not captured for audit")
            return None
        return self._redact(data)

    @staticmethod
    def _replay_body(request: Request, body: bytes) -> None:
        """Let the downstream app read a body this middleware already consumed."""
        # NOTE(review): Starlette versions whose BaseHTTPMiddleware hands
        # dispatch a request with ``wrapped_receive`` replay a body read via
        # ``request.body()`` themselves; patching ``_receive`` there would feed
        # the app a second ``http.request`` message. Confirm against the
        # pinned Starlette version.
        if hasattr(request, "wrapped_receive"):
            return

        original_receive = request._receive
        body_sent = False

        async def receive():
            nonlocal body_sent
            if not body_sent:
                body_sent = True
                return {"type": "http.request", "body": body, "more_body": False}
            # After the body, defer to the real channel so the app eventually
            # sees ``http.disconnect`` instead of an endless stream of bodies.
            return await original_receive()

        request._receive = receive

    @classmethod
    def _redact(cls, value):
        """Return a copy of a decoded JSON value with sensitive fields masked."""
        if isinstance(value, dict):
            return {
                key: cls._REDACTED if cls._is_sensitive_key(key) else cls._redact(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [cls._redact(item) for item in value]
        return value

    @classmethod
    def _is_sensitive_key(cls, key) -> bool:
        """Return True if a JSON key looks like it holds a credential."""
        lowered = str(key).lower()
        return any(part in lowered for part in cls._SENSITIVE_KEY_PARTS)

    def _resolve_project_id(self, request: Request) -> UUID | None:
        """Return the UUID from the ``X-Project-Id`` header, or None if absent/invalid."""
        project_header = request.headers.get("X-Project-Id")
        if not project_header:
            return None
        try:
            return UUID(project_header)
        except ValueError:
            return None

    async def _resolve_user_id(self, request: Request) -> UUID | None:
        """Return the local user id for the request's bearer token, if any.

        The token is verified with the Keycloak public key when one is
        configured, otherwise with the application secret key. Returns None
        for a missing or malformed Authorization header, an invalid token, a
        token without ``sub``, or a user that is unknown or inactive.
        Database errors propagate to the caller.
        """
        auth_header = request.headers.get("authorization")
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None
        token = auth_header.split(" ", 1)[1].strip()
        if not token:
            return None

        if settings.KEYCLOAK_PUBLIC_KEY:
            # Turn literal "\n" sequences (e.g. from a single-line config
            # value) into real newlines.
            key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
            algorithms = [settings.KEYCLOAK_ALGORITHM]
        else:
            key = settings.SECRET_KEY
            algorithms = [settings.ALGORITHM]

        try:
            # NOTE(review): no ``audience`` is passed here; if tokens carry an
            # ``aud`` claim the decode may be rejected and the event recorded
            # without a user. Confirm against the main auth dependency.
            payload = jwt.decode(token, key, algorithms=algorithms)
            sub = payload.get("sub")
            if not sub:
                return None
            # str() so a non-string claim yields ValueError, not AttributeError.
            keycloak_id = UUID(str(sub))
        except (JWTError, ValueError):
            return None

        async with SessionLocal() as session:
            repo = MetadataRepository(session)
            user = await repo.get_user_by_keycloak_id(keycloak_id)
            if user and user.is_active:
                return user.id
        return None

    def _determine_action(self, request: Request) -> str:
        """Map the request to an audit action.

        Authentication endpoints are recognised by a substring of the
        lower-cased path; everything else is classified by HTTP method.
        Unrecognised methods yield ``"<METHOD>_REQUEST"``.
        """
        path = request.url.path.lower()

        # Authentication. NOTE(review): substring matching also fires on
        # paths such as ".../login_history"; consider whole-segment matching.
        if "login" in path:
            return AuditAction.LOGIN
        if "logout" in path:
            return AuditAction.LOGOUT
        if "register" in path:
            return AuditAction.REGISTER

        # CRUD by method.
        method = request.method
        if method == "POST":
            return AuditAction.CREATE
        if method in ("PUT", "PATCH"):
            return AuditAction.UPDATE
        if method == "DELETE":
            return AuditAction.DELETE
        if method == "GET":
            return AuditAction.READ
        return f"{method}_REQUEST"

    @staticmethod
    def _looks_like_id(segment: str) -> bool:
        """Return True for a numeric or UUID path segment."""
        if segment.isdigit():
            return True
        try:
            UUID(segment)
        except ValueError:
            return False
        return True

    def _extract_resource_info(self, request: Request) -> tuple[str | None, str | None]:
        """Extract (resource_type, resource_id) from an ``/api/<ver>/<resource>/<id>`` path.

        Example: ``/api/v1/users/123`` -> ``("user", "123")``. The resource
        type is the fourth path segment with one plural "s" removed; the id is
        the fifth segment when it is numeric or a UUID. Missing parts are None.
        """
        path_parts = request.url.path.strip("/").split("/")

        resource_type = None
        resource_id = None

        if len(path_parts) >= 4:
            # removesuffix drops a single trailing "s"; rstrip("s") would
            # mangle names such as "status" or "addresses".
            resource_type = path_parts[3].removesuffix("s")

            if len(path_parts) >= 5 and self._looks_like_id(path_parts[4]):
                resource_id = path_parts[4]

        return resource_type, resource_id
|