225 lines
7.3 KiB
Python
225 lines
7.3 KiB
Python
"""
|
|
审计日志中间件
|
|
|
|
自动记录关键HTTP请求到审计日志
|
|
"""
|
|
|
|
import time
|
|
import json
|
|
from uuid import UUID
|
|
from typing import Callable
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from app.core.audit import log_audit_event, AuditAction
|
|
import logging
|
|
from jose import JWTError, jwt
|
|
|
|
from app.core.config import settings
|
|
from app.infra.db.metadata.database import SessionLocal
|
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
|
|
|
# Module-level logger; AuditMiddleware reports body-capture and audit-write
# failures through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AuditMiddleware(BaseHTTPMiddleware):
    """
    Audit middleware.

    Writes an audit-log entry for a request when any of these holds:

    - the HTTP method is in ``AUDIT_METHODS`` (POST/PUT/DELETE/PATCH);
    - the request path starts with one of ``AUDIT_PATHS``;
    - the matched route carries one of ``AUDIT_TAGS``.

    Login / logout / register requests are recognised by path keywords and
    logged with the matching ``AuditAction``. Every response, audited or not,
    gets an ``X-Process-Time`` header (seconds spent in the downstream app).

    Auditing is best-effort: any failure while building or writing the entry
    is logged and never changes the response returned to the client.
    """

    # Path prefixes that are always audited.
    AUDIT_PATHS = [
        # "/api/v1/auth/",
        # "/api/v1/users/",
        # "/api/v1/projects/",
        # "/api/v1/networks/",
    ]

    # Routes whose tags (declared on the router or endpoint, e.g.
    # ``tags=["Audit"]``) include one of these are audited for every method,
    # GET included.
    AUDIT_TAGS = [
        "Audit",
        "Users",
        "Project",
        "Network General",
        "Junctions",
        "Pipes",
        "Reservoirs",
        "Tanks",
        "Pumps",
        "Valves",
    ]

    # HTTP methods that are always audited.
    AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

    # Methods whose (JSON) request body is captured into the audit entry.
    BODY_CAPTURE_METHODS = ("POST", "PUT", "PATCH")

    # A body key containing any of these fragments (case-insensitive) has its
    # value masked, so credentials are never persisted in the audit log.
    SENSITIVE_KEY_FRAGMENTS = (
        "password",
        "passwd",
        "secret",
        "token",
        "api_key",
        "apikey",
        "authorization",
        "credential",
    )
    REDACTED_VALUE = "***"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request, then audit it if it matches the audit rules.

        The body of write requests is captured *before* ``call_next`` because
        the request stream can only be consumed once. Whether to audit is
        decided *after* ``call_next`` because the matched route (and its
        tags) is only known once routing has run.
        """
        start_time = time.perf_counter()

        request_data: object | None = None
        if request.method in self.BODY_CAPTURE_METHODS:
            request_data = await self._capture_request_body(request)

        response = await call_next(request)

        # Time the downstream app only, not the audit bookkeeping below.
        response.headers["X-Process-Time"] = str(time.perf_counter() - start_time)

        if self._should_audit(request):
            await self._write_audit_entry(request, response, request_data)

        return response

    async def _capture_request_body(self, request: Request) -> object | None:
        """Read, re-arm and parse the request body for the audit entry.

        Returns the JSON-decoded body with sensitive values masked, or
        ``None`` when the body is empty, not JSON, or could not be read.
        Bodies declared as non-JSON (multipart uploads, form posts, ...) are
        not read at all, so large uploads are not buffered here.
        """
        content_type = request.headers.get("content-type", "")
        if content_type and "json" not in content_type.lower():
            return None

        try:
            body = await request.body()
        except Exception as e:
            logger.warning("Failed to read request body for audit: %s", e)
            return None

        # The body stream can be consumed only once; make sure the downstream
        # app can still read it - regardless of whether it parses as JSON.
        self._replay_body(request, body)

        if not body:
            return None
        try:
            # json.loads accepts bytes and detects UTF-8/16/32 itself.
            return self._redact(json.loads(body))
        except (ValueError, RecursionError) as e:
            logger.warning("Failed to parse request body for audit: %s", e)
            return None

    @staticmethod
    def _replay_body(request: Request, body: bytes) -> None:
        """Let the downstream app receive *body* even though it was read here.

        The patched ``receive`` hands out the cached body once, then defers to
        the original channel so later calls still observe ``http.disconnect``.

        NOTE(review): newer Starlette releases are assumed to wrap the request
        passed to ``dispatch`` in a body-caching subclass (exposing
        ``wrapped_receive``) that already replays a body read via
        ``request.body()``; patching ``_receive`` there would make later
        ``receive()`` calls yield an extra ``http.request``. Patching is
        skipped when that attribute is present - confirm against the
        Starlette version in use.
        """
        if hasattr(request, "wrapped_receive"):
            return

        original_receive = request._receive
        body_delivered = False

        async def receive():
            nonlocal body_delivered
            if body_delivered:
                return await original_receive()
            body_delivered = True
            return {"type": "http.request", "body": body, "more_body": False}

        request._receive = receive

    @classmethod
    def _redact(cls, data: object) -> object:
        """Return a copy of *data* with values under sensitive keys masked.

        Dicts and lists are walked recursively; other values are returned
        unchanged.
        """
        if isinstance(data, dict):
            return {
                key: (
                    cls.REDACTED_VALUE
                    if cls._is_sensitive_key(key)
                    else cls._redact(value)
                )
                for key, value in data.items()
            }
        if isinstance(data, list):
            return [cls._redact(item) for item in data]
        return data

    @classmethod
    def _is_sensitive_key(cls, key: object) -> bool:
        """True if *key* is a string containing a sensitive fragment."""
        if not isinstance(key, str):
            return False
        lowered = key.lower()
        return any(fragment in lowered for fragment in cls.SENSITIVE_KEY_FRAGMENTS)

    def _should_audit(self, request: Request) -> bool:
        """Decide whether *request* must be audited (call after ``call_next``)."""
        if request.method in self.AUDIT_METHODS:
            return True

        path = request.url.path
        if any(path.startswith(prefix) for prefix in self.AUDIT_PATHS):
            return True

        # The matched route is stored in the scope during routing, so it is
        # only available once call_next has run.
        route = request.scope.get("route")
        route_tags = getattr(route, "tags", None) or ()
        return any(tag in self.AUDIT_TAGS for tag in route_tags)

    async def _write_audit_entry(
        self,
        request: Request,
        response: Response,
        request_data: object | None,
    ) -> None:
        """Build and persist one audit entry; never raises."""
        try:
            user_id = await self._resolve_user_id(request)
            project_id = self._resolve_project_id(request)
            ip_address = request.client.host if request.client else None
            action = self._determine_action(request)
            resource_type, resource_id = self._extract_resource_info(request)

            await log_audit_event(
                action=action,
                user_id=user_id,
                project_id=project_id,
                resource_type=resource_type,
                resource_id=resource_id,
                ip_address=ip_address,
                request_method=request.method,
                request_path=str(request.url.path),
                request_data=request_data,
                response_status=response.status_code,
            )
        except Exception as e:
            # Auditing is best-effort: resolving the user (a DB query) or
            # writing the entry must not turn a successful request into a 500.
            logger.error(f"Failed to log audit event: {e}", exc_info=True)

    def _resolve_project_id(self, request: Request) -> UUID | None:
        """Return the ``X-Project-Id`` header as a UUID, or None if absent/invalid."""
        project_header = request.headers.get("X-Project-Id")
        if not project_header:
            return None
        try:
            return UUID(project_header)
        except ValueError:
            return None

    async def _resolve_user_id(self, request: Request) -> UUID | None:
        """Map the request's bearer token to the local user's id.

        Decodes the JWT (with the Keycloak public key when configured,
        otherwise with ``settings.SECRET_KEY``), reads ``sub`` as the
        Keycloak user id and looks up the matching user. Returns ``None``
        when there is no usable token or the user is missing or inactive.
        Database errors propagate to the caller.

        NOTE(review): no ``audience=`` is passed to ``jwt.decode``; python-jose
        is believed to verify the ``aud`` claim by default when present, so
        tokens carrying one may be rejected here and the entry logged without
        a user - confirm against how the main auth dependency decodes tokens.
        """
        auth_header = request.headers.get("authorization")
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None
        token = auth_header.split(" ", 1)[1].strip()
        if not token:
            return None

        try:
            if settings.KEYCLOAK_PUBLIC_KEY:
                # The configured key may contain literal "\n" sequences;
                # turn them back into real newlines.
                key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
                algorithms = [settings.KEYCLOAK_ALGORITHM]
            else:
                key = settings.SECRET_KEY
                algorithms = [settings.ALGORITHM]
            payload = jwt.decode(token, key, algorithms=algorithms)
            sub = payload.get("sub")
            if not sub:
                return None
            # str() so a non-string claim fails as ValueError, not TypeError.
            keycloak_id = UUID(str(sub))
        except (JWTError, ValueError):
            return None

        async with SessionLocal() as session:
            repo = MetadataRepository(session)
            user = await repo.get_user_by_keycloak_id(keycloak_id)
            if user and user.is_active:
                return user.id
        return None

    def _determine_action(self, request: Request) -> str:
        """Derive the audit action from the request path and method.

        Auth keywords in the (lower-cased) path take precedence; otherwise
        the HTTP method maps to a CRUD action, and any other method yields
        ``"<METHOD>_REQUEST"``.
        """
        path = request.url.path.lower()
        method = request.method

        # Authentication endpoints
        if "login" in path:
            return AuditAction.LOGIN
        elif "logout" in path:
            return AuditAction.LOGOUT
        elif "register" in path:
            return AuditAction.REGISTER

        # CRUD operations
        if method == "POST":
            return AuditAction.CREATE
        elif method == "PUT" or method == "PATCH":
            return AuditAction.UPDATE
        elif method == "DELETE":
            return AuditAction.DELETE
        elif method == "GET":
            return AuditAction.READ

        return f"{method}_REQUEST"

    def _extract_resource_info(self, request: Request) -> tuple[str | None, str | None]:
        """Extract ``(resource_type, resource_id)`` from an ``/api/<ver>/...`` path.

        Examples:
            /api/v1/users/123       -> ("user", "123")
            /api/v1/projects/<uuid> -> ("project", "<uuid>")
            /api/v1/users           -> ("user", None)

        The resource type is the segment after the API version with a single
        trailing plural "s" removed; the id is the following segment when it
        is all digits or a valid UUID.
        """
        parts = request.url.path.strip("/").split("/")

        resource_type = None
        resource_id = None

        # parts == ["api", "v1", "<resources>", "<id>", ...]
        if len(parts) >= 3:
            # removesuffix drops exactly one "s" ("status" -> "statu" with
            # rstrip was wrong); an empty segment yields None.
            resource_type = parts[2].removesuffix("s") or None

        if len(parts) >= 4 and self._looks_like_id(parts[3]):
            resource_id = parts[3]

        return resource_type, resource_id

    @staticmethod
    def _looks_like_id(segment: str) -> bool:
        """True if *segment* is an all-digit id or parses as a UUID."""
        if segment.isdigit():
            return True
        try:
            UUID(segment)
        except ValueError:
            return False
        return True
|