Files
TJWaterServerBinary/app/infra/audit/middleware.py
T
2026-03-27 13:52:12 +08:00

288 lines
9.4 KiB
Python

"""
Audit log middleware.

Automatically records key HTTP requests to the audit log.
"""
import time
import json
from uuid import UUID
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.audit import log_audit_event, AuditAction
import logging
from jose import JWTError, jwt
from app.core.config import settings
from app.infra.db.metadb.database import SessionLocal
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
# Module-level logger; used in this file to report body-capture and audit-logging failures.
logger = logging.getLogger(__name__)
class AuditMiddleware(BaseHTTPMiddleware):
    """
    Audit middleware.

    After the downstream handler has produced a response, an audit event is
    written through ``log_audit_event`` when at least one of these holds:

    - the HTTP method is listed in ``AUDIT_METHODS``;
    - the request path starts with a prefix listed in ``AUDIT_PATHS``;
    - the matched route carries a tag listed in ``AUDIT_TAGS``.

    Paths matched by ``EXCLUDED_PATHS`` / ``EXCLUDED_PATH_PREFIXES`` bypass
    body capture and auditing entirely. Every response, audited or not, gets
    an ``X-Process-Time`` header holding the elapsed seconds.
    """

    # Path prefixes that force auditing (all currently disabled).
    AUDIT_PATHS = [
        # "/api/v1/auth/",
        # "/api/v1/users/",
        # "/api/v1/projects/",
        # "/api/v1/networks/",
    ]

    # API tags that force auditing (declared via tags=[...] on a router or
    # an endpoint function).
    AUDIT_TAGS = [
        "Audit",
        "Users",
        "Project",
        "Network General",
        "Junctions",
        "Pipes",
        "Reservoirs",
        "Tanks",
        "Pumps",
        "Valves",
    ]

    # HTTP methods that are always audited.
    AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

    # Exact paths that skip the middleware logic altogether.
    EXCLUDED_PATHS = {
        "/api/v1/meta/projects",
        "/meta/projects",
        "/api/v1/openproject/",
        "/openproject/",
    }

    # Path prefixes that skip the middleware logic altogether (none yet).
    EXCLUDED_PATH_PREFIXES = ()

    # Methods whose request body is captured for the audit record.
    _BODY_CAPTURE_METHODS = ("POST", "PUT", "PATCH")

    # Query parameters consulted, in priority order, when the path itself
    # carries no resource identifier.
    _RESOURCE_ID_QUERY_KEYS = (
        "id",
        "resource_id",
        "device_id",
        "device_ids",
        "element_id",
        "user_id",
        "project_id",
        "network",
        "name",
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Run the request, then record an audit event when it qualifies.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response with an ``X-Process-Time`` header added.
        """
        # perf_counter is monotonic, so the measured duration cannot go
        # negative when the wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        # Excluded paths are short-circuited *before* the body is touched so
        # that body capture cannot interfere with the lifecycle of streaming
        # (SSE) responses.
        if self._is_excluded_path(request.url.path):
            response = await call_next(request)
            return self._stamp_process_time(response, start_time)

        # 1. Capture the body up front for write operations. Whether the
        #    request is audited can only be decided after routing (the tag
        #    check needs the matched route), so there is no early return here.
        request_data = None
        if request.method in self._BODY_CAPTURE_METHODS:
            request_data = await self._capture_request_body(request)

        # 2. Execute the request (FastAPI performs route matching here).
        response = await call_next(request)

        # 3. Decide whether this request is audited.
        if not self._should_audit(request):
            return self._stamp_process_time(response, start_time)

        # 4. Collect the audit fields and write the event. Everything,
        #    including the user lookup (JWT decode + database query), lives
        #    inside the guard: an audit failure must never affect a response
        #    that has already been produced.
        # NOTE(review): request_data is forwarded verbatim, so login bodies may
        # contain credentials - confirm log_audit_event redacts them.
        try:
            user_id = await self._resolve_user_id(request)
            project_id = self._resolve_project_id(request)
            ip_address = request.client.host if request.client else None
            action = self._determine_action(request)
            resource_type, resource_id = self._extract_resource_info(request)
            await log_audit_event(
                action=action,
                user_id=user_id,
                project_id=project_id,
                resource_type=resource_type,
                resource_id=resource_id,
                ip_address=ip_address,
                request_method=request.method,
                request_path=str(request.url.path),
                request_data=request_data,
                response_status=response.status_code,
            )
        except Exception as e:
            logger.error(f"Failed to log audit event: {e}", exc_info=True)

        return self._stamp_process_time(response, start_time)

    @staticmethod
    def _stamp_process_time(response: Response, start_time: float) -> Response:
        """Attach the elapsed time since ``start_time`` as ``X-Process-Time``."""
        process_time = time.perf_counter() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

    async def _capture_request_body(self, request: Request) -> object | None:
        """
        Read the request body for auditing and make it replayable downstream.

        Returns:
            The JSON-decoded body, or ``None`` when the body is empty, cannot
            be read, or is not valid JSON.
        """
        try:
            # NOTE(review): relies on the private ``Request._receive``
            # attribute - confirm it is still required with the Starlette
            # version in use.
            original_receive = request._receive
            body = await request.body()
        except Exception as e:
            logger.warning(f"Failed to read request body for audit: {e}")
            return None

        # Install the replay *before* parsing: the stream is consumed at this
        # point, so downstream handlers must get the bytes back even when the
        # payload turns out not to be JSON (form data, uploads, ...).
        body_sent = False

        async def receive():
            # Replay the captured body exactly once, then fall back to the
            # original channel (e.g. for disconnect messages).
            nonlocal body_sent
            if not body_sent:
                body_sent = True
                return {
                    "type": "http.request",
                    "body": body,
                    "more_body": False,
                }
            return await original_receive()

        request._receive = receive

        if not body:
            return None
        try:
            return json.loads(body.decode())
        except (ValueError, RecursionError) as e:
            # UnicodeDecodeError and JSONDecodeError both derive from
            # ValueError; a non-JSON body is simply not recorded.
            logger.warning(f"Failed to read request body for audit: {e}")
            return None

    def _should_audit(self, request: Request) -> bool:
        """Return True when the method, path prefix or route tag requires auditing."""
        if request.method in self.AUDIT_METHODS:
            return True

        path = request.url.path
        if any(path.startswith(prefix) for prefix in self.AUDIT_PATHS):
            return True

        # The matched route is published in the ASGI scope by the router.
        route = request.scope.get("route")
        route_tags = getattr(route, "tags", None) or ()
        return any(tag in self.AUDIT_TAGS for tag in route_tags)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` is excluded exactly or by prefix."""
        if path in self.EXCLUDED_PATHS:
            return True
        return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)

    def _resolve_project_id(self, request: Request) -> UUID | None:
        """Parse the ``X-Project-Id`` header; ``None`` when absent or not a UUID."""
        project_header = request.headers.get("X-Project-Id")
        if not project_header:
            return None
        try:
            return UUID(project_header)
        except ValueError:
            return None

    async def _resolve_user_id(self, request: Request) -> UUID | None:
        """
        Resolve the internal user id from the bearer token.

        The token is verified with the Keycloak public key when one is
        configured, otherwise with the application secret. Its ``sub`` claim
        is looked up as a Keycloak id when it parses as a UUID, else as a
        username.

        Returns:
            The id of the matching active user, or ``None`` when there is no
            usable token, the token is invalid, or no active user matches.
        """
        auth_header = request.headers.get("authorization")
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None

        token = auth_header.split(" ", 1)[1].strip()
        if not token:
            return None

        use_keycloak = bool(settings.KEYCLOAK_PUBLIC_KEY)
        if use_keycloak:
            # The key is stored with escaped newlines; restore real ones.
            key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
            algorithms = [settings.KEYCLOAK_ALGORITHM]
        else:
            key = settings.SECRET_KEY
            algorithms = [settings.ALGORITHM]

        try:
            payload = jwt.decode(
                token,
                key,
                algorithms=algorithms,
                audience=settings.KEYCLOAK_AUDIENCE or None,
            )
        except JWTError:
            return None

        sub = payload.get("sub")
        if not sub:
            return None

        # Only the UUID parse is guarded, so a ValueError raised by the
        # repository is not mistaken for "subject is not a UUID".
        try:
            keycloak_id = UUID(sub)
        except ValueError:
            keycloak_id = None

        async with SessionLocal() as session:
            repo = MetadataRepository(session)
            if keycloak_id is not None:
                user = await repo.get_user_by_keycloak_id(keycloak_id)
            else:
                user = await repo.get_user_by_username(sub)
            if user and user.is_active:
                return user.id
        return None

    def _determine_action(self, request: Request) -> str:
        """Derive the audit action from the request path and method."""
        path = request.url.path.lower()
        method = request.method

        # Authentication endpoints take precedence over the CRUD mapping.
        if "login" in path:
            return AuditAction.LOGIN
        if "logout" in path:
            return AuditAction.LOGOUT
        if "register" in path:
            return AuditAction.REGISTER

        # CRUD operations.
        if method == "POST":
            return AuditAction.CREATE
        if method in ("PUT", "PATCH"):
            return AuditAction.UPDATE
        if method == "DELETE":
            return AuditAction.DELETE
        if method == "GET":
            return AuditAction.READ
        return f"{method}_REQUEST"

    def _extract_resource_info(self, request: Request) -> tuple[str | None, str | None]:
        """
        Extract the resource type and id from the request.

        A leading ``api`` segment and a version segment (``v1``, ``v2``, ...)
        are skipped; the next segment is the resource type (one trailing
        plural "s" removed) and the one after it the resource id, e.g.
        ``/api/v1/users/123`` -> ``("user", "123")``. When the path has no id
        segment, the first non-empty query parameter among
        ``_RESOURCE_ID_QUERY_KEYS`` is used instead.

        Returns:
            ``(resource_type, resource_id)``; either element may be ``None``.
        """
        segments = [part for part in request.url.path.split("/") if part]

        # Skip the mount prefix so prefixed and unprefixed routes agree.
        if segments and segments[0].lower() == "api":
            segments = segments[1:]
        if (
            segments
            and segments[0][:1].lower() == "v"
            and segments[0][1:].isdigit()
        ):
            segments = segments[1:]

        resource_type = None
        resource_id = None
        if segments:
            resource_type = segments[0]
            # Drop a single plural "s" only ("users" -> "user"); names that
            # end in "ss" ("access") are left intact.
            if (
                len(resource_type) > 1
                and resource_type.endswith("s")
                and not resource_type.endswith("ss")
            ):
                resource_type = resource_type[:-1]
            if len(segments) >= 2:
                resource_id = segments[1]

        # No id in the path: fall back to a business id in the query string.
        if not resource_id:
            for key in self._RESOURCE_ID_QUERY_KEYS:
                value = request.query_params.get(key)
                if value:
                    resource_id = value
                    break

        return resource_type, resource_id