重构数据库连接管理,添加元数据支持
This commit is contained in:
@@ -6,12 +6,17 @@
|
||||
|
||||
import time
|
||||
import json
|
||||
from uuid import UUID
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.infra.db.postgresql.database import db as default_db
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = None
|
||||
username = None
|
||||
|
||||
# 尝试从请求状态获取当前用户
|
||||
if hasattr(request.state, "user"):
|
||||
user = request.state.user
|
||||
user_id = getattr(user, "id", None)
|
||||
username = getattr(user, "username", None)
|
||||
user_id = await self._resolve_user_id(request)
|
||||
project_id = self._resolve_project_id(request)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
error_message=(
|
||||
None
|
||||
if response.status_code < 400
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
db=default_db,
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return response
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
return None
|
||||
try:
|
||||
return UUID(project_header)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
key = (
|
||||
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else settings.SECRET_KEY
|
||||
)
|
||||
algorithms = (
|
||||
[settings.KEYCLOAK_ALGORITHM]
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else [settings.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
return None
|
||||
keycloak_id = UUID(sub)
|
||||
except (JWTError, ValueError):
|
||||
return None
|
||||
|
||||
async with SessionLocal() as session:
|
||||
repo = MetadataRepository(session)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
if user and user.is_active:
|
||||
return user.id
|
||||
return None
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
|
||||
Reference in New Issue
Block a user