重构数据库连接管理,添加元数据支持
This commit is contained in:
28
.env.example
28
.env.example
@@ -31,6 +31,25 @@ TIMESCALEDB_DB_PORT="5433"
|
|||||||
TIMESCALEDB_DB_USER="tjwater"
|
TIMESCALEDB_DB_USER="tjwater"
|
||||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 元数据数据库配置 (Metadata DB)
|
||||||
|
# ============================================
|
||||||
|
METADATA_DB_NAME="system_hub"
|
||||||
|
METADATA_DB_HOST="localhost"
|
||||||
|
METADATA_DB_PORT="5432"
|
||||||
|
METADATA_DB_USER="tjwater"
|
||||||
|
METADATA_DB_PASSWORD="password"
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 项目连接缓存与连接池配置
|
||||||
|
# ============================================
|
||||||
|
PROJECT_PG_CACHE_SIZE=50
|
||||||
|
PROJECT_TS_CACHE_SIZE=50
|
||||||
|
PROJECT_PG_POOL_SIZE=5
|
||||||
|
PROJECT_PG_MAX_OVERFLOW=10
|
||||||
|
PROJECT_TS_POOL_MIN_SIZE=1
|
||||||
|
PROJECT_TS_POOL_MAX_SIZE=10
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# InfluxDB 配置 (时序数据)
|
# InfluxDB 配置 (时序数据)
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -46,6 +65,15 @@ TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
|||||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||||
# ALGORITHM=HS256
|
# ALGORITHM=HS256
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# Keycloak JWT (可选)
|
||||||
|
# ============================================
|
||||||
|
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||||
|
# KEYCLOAK_ALGORITHM=RS256
|
||||||
|
|
||||||
|
# 临时禁用鉴权(调试用)
|
||||||
|
# AUTH_DISABLED=false
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 其他配置
|
# 其他配置
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|||||||
@@ -4,33 +4,38 @@
|
|||||||
仅管理员可访问
|
仅管理员可访问
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query
|
||||||
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
|
from app.domain.schemas.audit import AuditLogResponse
|
||||||
from app.domain.schemas.user import UserInDB
|
|
||||||
from app.infra.repositories.audit_repository import AuditRepository
|
from app.infra.repositories.audit_repository import AuditRepository
|
||||||
from app.auth.dependencies import get_user_repository, get_db
|
from app.auth.metadata_dependencies import (
|
||||||
from app.auth.permissions import get_current_admin
|
get_current_metadata_admin,
|
||||||
from app.infra.db.postgresql.database import Database
|
get_current_metadata_user,
|
||||||
|
)
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
|
async def get_audit_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> AuditRepository:
|
||||||
"""获取审计日志仓储"""
|
"""获取审计日志仓储"""
|
||||||
return AuditRepository(db)
|
return AuditRepository(session)
|
||||||
|
|
||||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_admin),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
"""
|
||||||
查询审计日志(仅管理员)
|
查询审计日志(仅管理员)
|
||||||
@@ -39,7 +44,7 @@ async def get_audit_logs(
|
|||||||
"""
|
"""
|
||||||
logs = await audit_repo.get_logs(
|
logs = await audit_repo.get_logs(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@@ -51,21 +56,21 @@ async def get_audit_logs(
|
|||||||
|
|
||||||
@router.get("/logs/count")
|
@router.get("/logs/count")
|
||||||
async def get_audit_logs_count(
|
async def get_audit_logs_count(
|
||||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_admin),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取审计日志总数(仅管理员)
|
获取审计日志总数(仅管理员)
|
||||||
"""
|
"""
|
||||||
count = await audit_repo.get_log_count(
|
count = await audit_repo.get_log_count(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
|
|||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_user),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
"""
|
||||||
查询当前用户的审计日志
|
查询当前用户的审计日志
|
||||||
|
|||||||
90
app/api/v1/endpoints/meta.py
Normal file
90
app/api/v1/endpoints/meta.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from psycopg import AsyncConnection
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.project_dependencies import (
|
||||||
|
ProjectContext,
|
||||||
|
get_project_context,
|
||||||
|
get_project_pg_session,
|
||||||
|
get_project_timescale_connection,
|
||||||
|
get_metadata_repository,
|
||||||
|
)
|
||||||
|
from app.auth.metadata_dependencies import get_current_metadata_user
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.domain.schemas.metadata import (
|
||||||
|
GeoServerConfigResponse,
|
||||||
|
ProjectMetaResponse,
|
||||||
|
ProjectSummaryResponse,
|
||||||
|
)
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||||
|
async def get_project_metadata(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||||
|
)
|
||||||
|
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
|
||||||
|
geoserver_payload = (
|
||||||
|
GeoServerConfigResponse(
|
||||||
|
gs_base_url=geoserver.gs_base_url,
|
||||||
|
gs_admin_user=geoserver.gs_admin_user,
|
||||||
|
gs_datastore_name=geoserver.gs_datastore_name,
|
||||||
|
default_extent=geoserver.default_extent,
|
||||||
|
srid=geoserver.srid,
|
||||||
|
)
|
||||||
|
if geoserver
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return ProjectMetaResponse(
|
||||||
|
project_id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role=ctx.project_role,
|
||||||
|
geoserver=geoserver_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||||
|
async def list_user_projects(
|
||||||
|
current_user=Depends(get_current_metadata_user),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
if settings.AUTH_DISABLED:
|
||||||
|
projects = await metadata_repo.list_all_projects()
|
||||||
|
else:
|
||||||
|
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||||
|
return [
|
||||||
|
ProjectSummaryResponse(
|
||||||
|
project_id=project.project_id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role=project.project_role,
|
||||||
|
)
|
||||||
|
for project in projects
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/db/health")
|
||||||
|
async def project_db_health(
|
||||||
|
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||||
|
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||||
|
):
|
||||||
|
await pg_session.execute(text("SELECT 1"))
|
||||||
|
async with ts_conn.cursor() as cur:
|
||||||
|
await cur.execute("SELECT 1")
|
||||||
|
return {"postgres": "ok", "timescale": "ok"}
|
||||||
@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
|
|||||||
cache,
|
cache,
|
||||||
user_management, # 新增:用户管理
|
user_management, # 新增:用户管理
|
||||||
audit, # 新增:审计日志
|
audit, # 新增:审计日志
|
||||||
|
meta,
|
||||||
)
|
)
|
||||||
from app.api.v1.endpoints.network import (
|
from app.api.v1.endpoints.network import (
|
||||||
general,
|
general,
|
||||||
@@ -46,6 +47,7 @@ api_router = APIRouter()
|
|||||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||||
|
api_router.include_router(meta.router, tags=["Metadata"])
|
||||||
api_router.include_router(project.router, tags=["Project"])
|
api_router.include_router(project.router, tags=["Project"])
|
||||||
|
|
||||||
# Network Elements (Node/Link Types)
|
# Network Elements (Node/Link Types)
|
||||||
|
|||||||
56
app/auth/keycloak_dependencies.py
Normal file
56
app/auth/keycloak_dependencies.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
oauth2_optional = OAuth2PasswordBearer(
|
||||||
|
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_keycloak_sub(
|
||||||
|
token: str | None = Depends(oauth2_optional),
|
||||||
|
) -> UUID:
|
||||||
|
if settings.AUTH_DISABLED:
|
||||||
|
return UUID(int=0)
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||||
|
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||||
|
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||||
|
else:
|
||||||
|
key = settings.SECRET_KEY
|
||||||
|
algorithms = [settings.ALGORITHM]
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||||
|
except JWTError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if not sub:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing subject claim",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return UUID(sub)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid subject claim",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
) from exc
|
||||||
50
app/auth/metadata_dependencies.py
Normal file
50
app/auth/metadata_dependencies.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> MetadataRepository:
|
||||||
|
return MetadataRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_metadata_user(
|
||||||
|
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
if settings.AUTH_DISABLED:
|
||||||
|
return _AuthBypassUser()
|
||||||
|
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_metadata_admin(
|
||||||
|
user=Depends(get_current_metadata_user),
|
||||||
|
):
|
||||||
|
if settings.AUTH_DISABLED:
|
||||||
|
return user
|
||||||
|
if user.is_superuser or user.role == "admin":
|
||||||
|
return user
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _AuthBypassUser:
|
||||||
|
id: UUID = UUID(int=0)
|
||||||
|
role: str = "admin"
|
||||||
|
is_superuser: bool = True
|
||||||
|
is_active: bool = True
|
||||||
176
app/auth/project_dependencies.py
Normal file
176
app/auth/project_dependencies.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
|
from psycopg import AsyncConnection
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.dynamic_manager import project_connection_manager
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
DB_ROLE_BIZ_PG = "biz_pg"
|
||||||
|
DB_ROLE_IOT_TS = "iot_ts"
|
||||||
|
DB_TYPE_POSTGRES = "postgresql"
|
||||||
|
DB_TYPE_TIMESCALE = "timescaledb"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProjectContext:
|
||||||
|
project_id: UUID
|
||||||
|
user_id: UUID
|
||||||
|
project_role: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> MetadataRepository:
|
||||||
|
return MetadataRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_context(
|
||||||
|
x_project_id: str = Header(..., alias="X-Project-Id"),
|
||||||
|
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> ProjectContext:
|
||||||
|
try:
|
||||||
|
project_uuid = UUID(x_project_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
project = await metadata_repo.get_project_by_id(project_uuid)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||||
|
)
|
||||||
|
if project.status != "active":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
|
||||||
|
)
|
||||||
|
|
||||||
|
if settings.AUTH_DISABLED:
|
||||||
|
return ProjectContext(
|
||||||
|
project_id=project.id,
|
||||||
|
user_id=UUID(int=0),
|
||||||
|
project_role="owner",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
|
||||||
|
)
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
|
)
|
||||||
|
|
||||||
|
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
|
||||||
|
if not membership_role:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ProjectContext(
|
||||||
|
project_id=project.id,
|
||||||
|
user_id=user.id,
|
||||||
|
project_role=membership_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_pg_session(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_BIZ_PG
|
||||||
|
)
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_POSTGRES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_BIZ_PG,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with sessionmaker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_pg_connection(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncConnection, None]:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_BIZ_PG
|
||||||
|
)
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_POSTGRES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool = await project_connection_manager.get_pg_pool(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_BIZ_PG,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_timescale_connection(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncConnection, None]:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_IOT_TS
|
||||||
|
)
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project TimescaleDB not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_TIMESCALE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project TimescaleDB type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
|
||||||
|
pool = await project_connection_manager.get_timescale_pool(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_IOT_TS,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
yield conn
|
||||||
@@ -7,6 +7,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,18 +39,16 @@ class AuditAction:
|
|||||||
|
|
||||||
async def log_audit_event(
|
async def log_audit_event(
|
||||||
action: str,
|
action: str,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
user_agent: Optional[str] = None,
|
|
||||||
request_method: Optional[str] = None,
|
request_method: Optional[str] = None,
|
||||||
request_path: Optional[str] = None,
|
request_path: Optional[str] = None,
|
||||||
request_data: Optional[dict] = None,
|
request_data: Optional[dict] = None,
|
||||||
response_status: Optional[int] = None,
|
response_status: Optional[int] = None,
|
||||||
error_message: Optional[str] = None,
|
session=None,
|
||||||
db=None, # 新增:可选的数据库实例
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
记录审计日志
|
记录审计日志
|
||||||
@@ -57,68 +56,61 @@ async def log_audit_event(
|
|||||||
Args:
|
Args:
|
||||||
action: 操作类型
|
action: 操作类型
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
username: 用户名
|
project_id: 项目ID
|
||||||
resource_type: 资源类型
|
resource_type: 资源类型
|
||||||
resource_id: 资源ID
|
resource_id: 资源ID
|
||||||
ip_address: IP地址
|
ip_address: IP地址
|
||||||
user_agent: User-Agent
|
|
||||||
request_method: 请求方法
|
request_method: 请求方法
|
||||||
request_path: 请求路径
|
request_path: 请求路径
|
||||||
request_data: 请求数据(敏感字段需脱敏)
|
request_data: 请求数据(敏感字段需脱敏)
|
||||||
response_status: 响应状态码
|
response_status: 响应状态码
|
||||||
error_message: 错误消息
|
session: 元数据库会话(可选)
|
||||||
db: 数据库实例(可选,如果不提供则尝试获取)
|
|
||||||
"""
|
"""
|
||||||
|
from app.infra.db.metadata.database import SessionLocal
|
||||||
from app.infra.repositories.audit_repository import AuditRepository
|
from app.infra.repositories.audit_repository import AuditRepository
|
||||||
|
|
||||||
try:
|
|
||||||
# 脱敏敏感数据
|
|
||||||
if request_data:
|
if request_data:
|
||||||
request_data = sanitize_sensitive_data(request_data)
|
request_data = sanitize_sensitive_data(request_data)
|
||||||
|
|
||||||
# 如果没有提供数据库实例,尝试从全局获取
|
if session is None:
|
||||||
if db is None:
|
async with SessionLocal() as session:
|
||||||
try:
|
audit_repo = AuditRepository(session)
|
||||||
from app.infra.db.postgresql.database import db as default_db
|
|
||||||
|
|
||||||
# 仅当连接池已初始化时使用
|
|
||||||
if default_db.pool:
|
|
||||||
db = default_db
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 如果仍然没有数据库实例
|
|
||||||
if db is None:
|
|
||||||
# 在某些上下文中可能无法获取,此时静默失败
|
|
||||||
logger.warning("No database instance provided for audit logging")
|
|
||||||
return
|
|
||||||
|
|
||||||
audit_repo = AuditRepository(db)
|
|
||||||
|
|
||||||
await audit_repo.create_log(
|
await audit_repo.create_log(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
|
action=action,
|
||||||
|
resource_type=resource_type,
|
||||||
|
resource_id=resource_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
request_method=request_method,
|
||||||
|
request_path=request_path,
|
||||||
|
request_data=request_data,
|
||||||
|
response_status=response_status,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
audit_repo = AuditRepository(session)
|
||||||
|
await audit_repo.create_log(
|
||||||
|
user_id=user_id,
|
||||||
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
|
||||||
request_method=request_method,
|
request_method=request_method,
|
||||||
request_path=request_path,
|
request_path=request_path,
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
response_status=response_status,
|
response_status=response_status,
|
||||||
error_message=error_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Audit log created: action={action}, user={username or user_id}, "
|
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
|
||||||
f"resource={resource_type}:{resource_id}"
|
action,
|
||||||
|
user_id,
|
||||||
|
project_id,
|
||||||
|
resource_type,
|
||||||
|
resource_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 审计日志失败不应影响业务流程
|
|
||||||
logger.error(f"Failed to create audit log: {e}", exc_info=True)
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_sensitive_data(data: dict) -> dict:
|
def sanitize_sensitive_data(data: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -35,10 +35,41 @@ class Settings(BaseSettings):
|
|||||||
INFLUXDB_ORG: str = "org"
|
INFLUXDB_ORG: str = "org"
|
||||||
INFLUXDB_BUCKET: str = "bucket"
|
INFLUXDB_BUCKET: str = "bucket"
|
||||||
|
|
||||||
|
# Metadata Database Config (PostgreSQL)
|
||||||
|
METADATA_DB_NAME: str = "system_hub"
|
||||||
|
METADATA_DB_HOST: str = "localhost"
|
||||||
|
METADATA_DB_PORT: str = "5432"
|
||||||
|
METADATA_DB_USER: str = "postgres"
|
||||||
|
METADATA_DB_PASSWORD: str = "password"
|
||||||
|
|
||||||
|
METADATA_DB_POOL_SIZE: int = 5
|
||||||
|
METADATA_DB_MAX_OVERFLOW: int = 10
|
||||||
|
|
||||||
|
PROJECT_PG_CACHE_SIZE: int = 50
|
||||||
|
PROJECT_TS_CACHE_SIZE: int = 50
|
||||||
|
PROJECT_PG_POOL_SIZE: int = 5
|
||||||
|
PROJECT_PG_MAX_OVERFLOW: int = 10
|
||||||
|
PROJECT_TS_POOL_MIN_SIZE: int = 1
|
||||||
|
PROJECT_TS_POOL_MAX_SIZE: int = 10
|
||||||
|
|
||||||
|
# Keycloak JWT (optional override)
|
||||||
|
KEYCLOAK_PUBLIC_KEY: str = ""
|
||||||
|
KEYCLOAK_ALGORITHM: str = "RS256"
|
||||||
|
|
||||||
|
# Auth bypass (temporary)
|
||||||
|
AUTH_DISABLED: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def METADATA_DATABASE_URI(self) -> str:
|
||||||
|
return (
|
||||||
|
f"postgresql+psycopg://{self.METADATA_DB_USER}:{self.METADATA_DB_PASSWORD}"
|
||||||
|
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
extra = "ignore"
|
extra = "ignore"
|
||||||
|
|||||||
@@ -1,45 +1,42 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Any
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
class AuditLogCreate(BaseModel):
|
class AuditLogCreate(BaseModel):
|
||||||
"""创建审计日志"""
|
"""创建审计日志"""
|
||||||
user_id: Optional[int] = None
|
user_id: Optional[UUID] = None
|
||||||
username: Optional[str] = None
|
project_id: Optional[UUID] = None
|
||||||
action: str
|
action: str
|
||||||
resource_type: Optional[str] = None
|
resource_type: Optional[str] = None
|
||||||
resource_id: Optional[str] = None
|
resource_id: Optional[str] = None
|
||||||
ip_address: Optional[str] = None
|
ip_address: Optional[str] = None
|
||||||
user_agent: Optional[str] = None
|
|
||||||
request_method: Optional[str] = None
|
request_method: Optional[str] = None
|
||||||
request_path: Optional[str] = None
|
request_path: Optional[str] = None
|
||||||
request_data: Optional[dict] = None
|
request_data: Optional[dict] = None
|
||||||
response_status: Optional[int] = None
|
response_status: Optional[int] = None
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
class AuditLogResponse(BaseModel):
|
class AuditLogResponse(BaseModel):
|
||||||
"""审计日志响应"""
|
"""审计日志响应"""
|
||||||
id: int
|
id: UUID
|
||||||
user_id: Optional[int]
|
user_id: Optional[UUID]
|
||||||
username: Optional[str]
|
project_id: Optional[UUID]
|
||||||
action: str
|
action: str
|
||||||
resource_type: Optional[str]
|
resource_type: Optional[str]
|
||||||
resource_id: Optional[str]
|
resource_id: Optional[str]
|
||||||
ip_address: Optional[str]
|
ip_address: Optional[str]
|
||||||
user_agent: Optional[str]
|
|
||||||
request_method: Optional[str]
|
request_method: Optional[str]
|
||||||
request_path: Optional[str]
|
request_path: Optional[str]
|
||||||
request_data: Optional[dict]
|
request_data: Optional[dict]
|
||||||
response_status: Optional[int]
|
response_status: Optional[int]
|
||||||
error_message: Optional[str]
|
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
class AuditLogQuery(BaseModel):
|
class AuditLogQuery(BaseModel):
|
||||||
"""审计日志查询参数"""
|
"""审计日志查询参数"""
|
||||||
user_id: Optional[int] = None
|
user_id: Optional[UUID] = None
|
||||||
username: Optional[str] = None
|
project_id: Optional[UUID] = None
|
||||||
action: Optional[str] = None
|
action: Optional[str] = None
|
||||||
resource_type: Optional[str] = None
|
resource_type: Optional[str] = None
|
||||||
start_time: Optional[datetime] = None
|
start_time: Optional[datetime] = None
|
||||||
|
|||||||
33
app/domain/schemas/metadata.py
Normal file
33
app/domain/schemas/metadata.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GeoServerConfigResponse(BaseModel):
|
||||||
|
gs_base_url: Optional[str]
|
||||||
|
gs_admin_user: Optional[str]
|
||||||
|
gs_datastore_name: str
|
||||||
|
default_extent: Optional[dict]
|
||||||
|
srid: int
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectMetaResponse(BaseModel):
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str]
|
||||||
|
gs_workspace: str
|
||||||
|
status: str
|
||||||
|
project_role: str
|
||||||
|
geoserver: Optional[GeoServerConfigResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectSummaryResponse(BaseModel):
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str]
|
||||||
|
gs_workspace: str
|
||||||
|
status: str
|
||||||
|
project_role: str
|
||||||
@@ -6,12 +6,17 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
from uuid import UUID
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from app.infra.db.postgresql.database import db as default_db
|
|
||||||
from app.core.audit import log_audit_event, AuditAction
|
from app.core.audit import log_audit_event, AuditAction
|
||||||
import logging
|
import logging
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.metadata.database import SessionLocal
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
# 4. 提取审计所需信息
|
# 4. 提取审计所需信息
|
||||||
user_id = None
|
user_id = await self._resolve_user_id(request)
|
||||||
username = None
|
project_id = self._resolve_project_id(request)
|
||||||
|
|
||||||
# 尝试从请求状态获取当前用户
|
|
||||||
if hasattr(request.state, "user"):
|
|
||||||
user = request.state.user
|
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
username = getattr(user, "username", None)
|
|
||||||
|
|
||||||
# 获取客户端信息
|
# 获取客户端信息
|
||||||
ip_address = request.client.host if request.client else None
|
ip_address = request.client.host if request.client else None
|
||||||
user_agent = request.headers.get("user-agent")
|
|
||||||
|
|
||||||
# 确定操作类型
|
# 确定操作类型
|
||||||
action = self._determine_action(request)
|
action = self._determine_action(request)
|
||||||
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
action=action,
|
action=action,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
|
||||||
request_method=request.method,
|
request_method=request.method,
|
||||||
request_path=str(request.url.path),
|
request_path=str(request.url.path),
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
response_status=response.status_code,
|
response_status=response.status_code,
|
||||||
error_message=(
|
|
||||||
None
|
|
||||||
if response.status_code < 400
|
|
||||||
else f"HTTP {response.status_code}"
|
|
||||||
),
|
|
||||||
db=default_db,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 审计失败不应影响响应
|
# 审计失败不应影响响应
|
||||||
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||||
|
project_header = request.headers.get("X-Project-Id")
|
||||||
|
if not project_header:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return UUID(project_header)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||||
|
auth_header = request.headers.get("authorization")
|
||||||
|
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||||
|
return None
|
||||||
|
token = auth_header.split(" ", 1)[1].strip()
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
key = (
|
||||||
|
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY
|
||||||
|
else settings.SECRET_KEY
|
||||||
|
)
|
||||||
|
algorithms = (
|
||||||
|
[settings.KEYCLOAK_ALGORITHM]
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY
|
||||||
|
else [settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if not sub:
|
||||||
|
return None
|
||||||
|
keycloak_id = UUID(sub)
|
||||||
|
except (JWTError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async with SessionLocal() as session:
|
||||||
|
repo = MetadataRepository(session)
|
||||||
|
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||||
|
if user and user.is_active:
|
||||||
|
return user.id
|
||||||
|
return None
|
||||||
|
|
||||||
def _determine_action(self, request: Request) -> str:
|
def _determine_action(self, request: Request) -> str:
|
||||||
"""根据请求路径和方法确定操作类型"""
|
"""根据请求路径和方法确定操作类型"""
|
||||||
path = request.url.path.lower()
|
path = request.url.path.lower()
|
||||||
|
|||||||
208
app/infra/db/dynamic_manager.py
Normal file
208
app/infra/db/dynamic_manager.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from sqlalchemy.engine.url import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncEngine,
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PgEngineEntry:
|
||||||
|
engine: AsyncEngine
|
||||||
|
sessionmaker: async_sessionmaker[AsyncSession]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CacheKey:
|
||||||
|
project_id: UUID
|
||||||
|
db_role: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectConnectionManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
|
||||||
|
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||||
|
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||||
|
self._pg_lock = asyncio.Lock()
|
||||||
|
self._ts_lock = asyncio.Lock()
|
||||||
|
self._pg_raw_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def _normalize_pg_url(self, url: str) -> str:
|
||||||
|
parsed = make_url(url)
|
||||||
|
if parsed.drivername == "postgresql":
|
||||||
|
parsed = parsed.set(drivername="postgresql+psycopg")
|
||||||
|
return str(parsed)
|
||||||
|
|
||||||
|
async def get_pg_sessionmaker(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
db_role: str,
|
||||||
|
connection_url: str,
|
||||||
|
pool_min_size: int,
|
||||||
|
pool_max_size: int,
|
||||||
|
) -> async_sessionmaker[AsyncSession]:
|
||||||
|
async with self._pg_lock:
|
||||||
|
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||||
|
entry = self._pg_cache.get(key)
|
||||||
|
if entry:
|
||||||
|
self._pg_cache.move_to_end(key)
|
||||||
|
return entry.sessionmaker
|
||||||
|
|
||||||
|
normalized_url = self._normalize_pg_url(connection_url)
|
||||||
|
pool_min_size = max(1, pool_min_size)
|
||||||
|
pool_max_size = max(pool_min_size, pool_max_size)
|
||||||
|
engine = create_async_engine(
|
||||||
|
normalized_url,
|
||||||
|
pool_size=pool_min_size,
|
||||||
|
max_overflow=max(0, pool_max_size - pool_min_size),
|
||||||
|
pool_pre_ping=True,
|
||||||
|
)
|
||||||
|
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
self._pg_cache[key] = PgEngineEntry(
|
||||||
|
engine=engine,
|
||||||
|
sessionmaker=sessionmaker,
|
||||||
|
)
|
||||||
|
await self._evict_pg_if_needed()
|
||||||
|
logger.info(
|
||||||
|
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
|
||||||
|
)
|
||||||
|
return sessionmaker
|
||||||
|
|
||||||
|
async def get_timescale_pool(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
db_role: str,
|
||||||
|
connection_url: str,
|
||||||
|
pool_min_size: int,
|
||||||
|
pool_max_size: int,
|
||||||
|
) -> AsyncConnectionPool:
|
||||||
|
async with self._ts_lock:
|
||||||
|
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||||
|
pool = self._ts_cache.get(key)
|
||||||
|
if pool:
|
||||||
|
self._ts_cache.move_to_end(key)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
pool_min_size = max(1, pool_min_size)
|
||||||
|
pool_max_size = max(pool_min_size, pool_max_size)
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=connection_url,
|
||||||
|
min_size=pool_min_size,
|
||||||
|
max_size=pool_max_size,
|
||||||
|
open=False,
|
||||||
|
)
|
||||||
|
await pool.open()
|
||||||
|
self._ts_cache[key] = pool
|
||||||
|
await self._evict_ts_if_needed()
|
||||||
|
logger.info(
|
||||||
|
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
|
||||||
|
)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
async def get_pg_pool(
|
||||||
|
self,
|
||||||
|
project_id: UUID,
|
||||||
|
db_role: str,
|
||||||
|
connection_url: str,
|
||||||
|
pool_min_size: int,
|
||||||
|
pool_max_size: int,
|
||||||
|
) -> AsyncConnectionPool:
|
||||||
|
async with self._pg_raw_lock:
|
||||||
|
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||||
|
pool = self._pg_raw_cache.get(key)
|
||||||
|
if pool:
|
||||||
|
self._pg_raw_cache.move_to_end(key)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
pool_min_size = max(1, pool_min_size)
|
||||||
|
pool_max_size = max(pool_min_size, pool_max_size)
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=connection_url,
|
||||||
|
min_size=pool_min_size,
|
||||||
|
max_size=pool_max_size,
|
||||||
|
open=False,
|
||||||
|
)
|
||||||
|
await pool.open()
|
||||||
|
self._pg_raw_cache[key] = pool
|
||||||
|
await self._evict_pg_raw_if_needed()
|
||||||
|
logger.info(
|
||||||
|
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
|
||||||
|
)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
async def _evict_pg_if_needed(self) -> None:
|
||||||
|
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||||
|
key, entry = self._pg_cache.popitem(last=False)
|
||||||
|
await entry.engine.dispose()
|
||||||
|
logger.info(
|
||||||
|
"Evicted PostgreSQL engine for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _evict_ts_if_needed(self) -> None:
|
||||||
|
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
|
||||||
|
key, pool = self._ts_cache.popitem(last=False)
|
||||||
|
await pool.close()
|
||||||
|
logger.info(
|
||||||
|
"Evicted TimescaleDB pool for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _evict_pg_raw_if_needed(self) -> None:
|
||||||
|
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||||
|
key, pool = self._pg_raw_cache.popitem(last=False)
|
||||||
|
await pool.close()
|
||||||
|
logger.info(
|
||||||
|
"Evicted PostgreSQL pool for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close_all(self) -> None:
|
||||||
|
async with self._pg_lock:
|
||||||
|
for key, entry in list(self._pg_cache.items()):
|
||||||
|
await entry.engine.dispose()
|
||||||
|
logger.info(
|
||||||
|
"Closed PostgreSQL engine for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
self._pg_cache.clear()
|
||||||
|
|
||||||
|
async with self._ts_lock:
|
||||||
|
for key, pool in list(self._ts_cache.items()):
|
||||||
|
await pool.close()
|
||||||
|
logger.info(
|
||||||
|
"Closed TimescaleDB pool for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
self._ts_cache.clear()
|
||||||
|
|
||||||
|
async with self._pg_raw_lock:
|
||||||
|
for key, pool in list(self._pg_raw_cache.items()):
|
||||||
|
await pool.close()
|
||||||
|
logger.info(
|
||||||
|
"Closed PostgreSQL pool for project %s (%s)",
|
||||||
|
key.project_id,
|
||||||
|
key.db_role,
|
||||||
|
)
|
||||||
|
self._pg_raw_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
project_connection_manager = ProjectConnectionManager()
|
||||||
3
app/infra/db/metadata/__init__.py
Normal file
3
app/infra/db/metadata/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .database import get_metadata_session, close_metadata_engine
|
||||||
|
|
||||||
|
__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||||
27
app/infra/db/metadata/database.py
Normal file
27
app/infra/db/metadata/database.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.METADATA_DATABASE_URI,
|
||||||
|
pool_size=settings.METADATA_DB_POOL_SIZE,
|
||||||
|
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
async with SessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def close_metadata_engine() -> None:
|
||||||
|
await engine.dispose()
|
||||||
|
logger.info("Metadata database engine disposed.")
|
||||||
115
app/infra/db/metadata/models.py
Normal file
115
app/infra/db/metadata/models.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
keycloak_id: Mapped[UUID] = mapped_column(
|
||||||
|
PGUUID(as_uuid=True), unique=True, index=True
|
||||||
|
)
|
||||||
|
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
|
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
|
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Project(Base):
|
||||||
|
__tablename__ = "projects"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(100))
|
||||||
|
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||||
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||||
|
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDatabase(Base):
|
||||||
|
__tablename__ = "project_databases"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||||
|
db_role: Mapped[str] = mapped_column(String(20))
|
||||||
|
db_type: Mapped[str] = mapped_column(String(20))
|
||||||
|
dsn_encrypted: Mapped[str] = mapped_column(Text)
|
||||||
|
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
|
||||||
|
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectGeoServerConfig(Base):
|
||||||
|
__tablename__ = "project_geoserver_configs"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
project_id: Mapped[UUID] = mapped_column(
|
||||||
|
PGUUID(as_uuid=True), unique=True, index=True
|
||||||
|
)
|
||||||
|
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||||
|
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
|
||||||
|
Text, nullable=True
|
||||||
|
)
|
||||||
|
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
|
||||||
|
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
srid: Mapped[int] = mapped_column(Integer, default=4326)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserProjectMembership(Base):
|
||||||
|
__tablename__ = "user_project_membership"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||||
|
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||||
|
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(Base):
|
||||||
|
__tablename__ = "audit_logs"
|
||||||
|
|
||||||
|
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||||
|
user_id: Mapped[UUID | None] = mapped_column(
|
||||||
|
PGUUID(as_uuid=True), nullable=True, index=True
|
||||||
|
)
|
||||||
|
project_id: Mapped[UUID | None] = mapped_column(
|
||||||
|
PGUUID(as_uuid=True), nullable=True, index=True
|
||||||
|
)
|
||||||
|
action: Mapped[str] = mapped_column(String(50))
|
||||||
|
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||||
|
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||||
|
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||||
|
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||||
|
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
timestamp: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=datetime.utcnow
|
||||||
|
)
|
||||||
@@ -1,23 +1,17 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from typing import Optional
|
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
|
|
||||||
from .database import get_database_instance
|
|
||||||
from .scada_info import ScadaRepository
|
from .scada_info import ScadaRepository
|
||||||
from .scheme import SchemeRepository
|
from .scheme import SchemeRepository
|
||||||
|
from app.auth.project_dependencies import get_project_pg_connection
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# 创建支持数据库选择的连接依赖函数
|
# 动态项目 PostgreSQL 连接依赖
|
||||||
async def get_database_connection(
|
async def get_database_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
|
||||||
instance = await get_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +1,31 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
|
|
||||||
from .database import get_database_instance
|
|
||||||
from .schemas.realtime import RealtimeRepository
|
from .schemas.realtime import RealtimeRepository
|
||||||
from .schemas.scheme import SchemeRepository
|
from .schemas.scheme import SchemeRepository
|
||||||
from .schemas.scada import ScadaRepository
|
from .schemas.scada import ScadaRepository
|
||||||
from .composite_queries import CompositeQueries
|
from .composite_queries import CompositeQueries
|
||||||
from app.infra.db.postgresql.database import (
|
from app.auth.project_dependencies import (
|
||||||
get_database_instance as get_postgres_database_instance,
|
get_project_pg_connection,
|
||||||
|
get_project_timescale_connection,
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# 创建支持数据库选择的连接依赖函数
|
# 动态项目 TimescaleDB 连接依赖
|
||||||
async def get_database_connection(
|
async def get_database_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
|
||||||
instance = await get_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
# PostgreSQL 数据库连接依赖函数
|
# 动态项目 PostgreSQL 连接依赖
|
||||||
async def get_postgres_connection(
|
async def get_postgres_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
|
||||||
instance = await get_postgres_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,220 +1,112 @@
|
|||||||
from typing import Optional, List
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
from typing import Optional, List
|
||||||
from app.infra.db.postgresql.database import Database
|
from uuid import UUID
|
||||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
|
||||||
import logging
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.domain.schemas.audit import AuditLogResponse
|
||||||
|
from app.infra.db.metadata import models
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class AuditRepository:
|
class AuditRepository:
|
||||||
"""审计日志数据访问层"""
|
"""审计日志数据访问层(system_hub)"""
|
||||||
|
|
||||||
def __init__(self, db: Database):
|
def __init__(self, session: AsyncSession):
|
||||||
self.db = db
|
self.session = session
|
||||||
|
|
||||||
async def create_log(
|
async def create_log(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
action: str,
|
||||||
username: Optional[str] = None,
|
user_id: Optional[UUID] = None,
|
||||||
action: str = "",
|
project_id: Optional[UUID] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
user_agent: Optional[str] = None,
|
|
||||||
request_method: Optional[str] = None,
|
request_method: Optional[str] = None,
|
||||||
request_path: Optional[str] = None,
|
request_path: Optional[str] = None,
|
||||||
request_data: Optional[dict] = None,
|
request_data: Optional[dict] = None,
|
||||||
response_status: Optional[int] = None,
|
response_status: Optional[int] = None,
|
||||||
error_message: Optional[str] = None
|
) -> AuditLogResponse:
|
||||||
) -> Optional[AuditLogResponse]:
|
log = models.AuditLog(
|
||||||
"""
|
user_id=user_id,
|
||||||
创建审计日志
|
project_id=project_id,
|
||||||
|
action=action,
|
||||||
Args:
|
resource_type=resource_type,
|
||||||
参数说明见 AuditLogCreate
|
resource_id=resource_id,
|
||||||
|
ip_address=ip_address,
|
||||||
Returns:
|
request_method=request_method,
|
||||||
创建的审计日志对象
|
request_path=request_path,
|
||||||
"""
|
request_data=request_data,
|
||||||
query = """
|
response_status=response_status,
|
||||||
INSERT INTO audit_logs (
|
timestamp=datetime.utcnow(),
|
||||||
user_id, username, action, resource_type, resource_id,
|
|
||||||
ip_address, user_agent, request_method, request_path,
|
|
||||||
request_data, response_status, error_message
|
|
||||||
)
|
)
|
||||||
VALUES (
|
self.session.add(log)
|
||||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
await self.session.commit()
|
||||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
await self.session.refresh(log)
|
||||||
%(request_data)s, %(response_status)s, %(error_message)s
|
return AuditLogResponse.model_validate(log)
|
||||||
)
|
|
||||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
|
||||||
ip_address, user_agent, request_method, request_path,
|
|
||||||
request_data, response_status, error_message, timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, {
|
|
||||||
'user_id': user_id,
|
|
||||||
'username': username,
|
|
||||||
'action': action,
|
|
||||||
'resource_type': resource_type,
|
|
||||||
'resource_id': resource_id,
|
|
||||||
'ip_address': ip_address,
|
|
||||||
'user_agent': user_agent,
|
|
||||||
'request_method': request_method,
|
|
||||||
'request_path': request_path,
|
|
||||||
'request_data': json.dumps(request_data) if request_data else None,
|
|
||||||
'response_status': response_status,
|
|
||||||
'error_message': error_message
|
|
||||||
})
|
|
||||||
row = await cur.fetchone()
|
|
||||||
if row:
|
|
||||||
return AuditLogResponse(**row)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating audit log: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_logs(
|
async def get_logs(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
action: Optional[str] = None,
|
action: Optional[str] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
start_time: Optional[datetime] = None,
|
start_time: Optional[datetime] = None,
|
||||||
end_time: Optional[datetime] = None,
|
end_time: Optional[datetime] = None,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100
|
limit: int = 100,
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
|
||||||
查询审计日志
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID过滤
|
|
||||||
username: 用户名过滤
|
|
||||||
action: 操作类型过滤
|
|
||||||
resource_type: 资源类型过滤
|
|
||||||
start_time: 开始时间
|
|
||||||
end_time: 结束时间
|
|
||||||
skip: 跳过记录数
|
|
||||||
limit: 限制记录数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
审计日志列表
|
|
||||||
"""
|
|
||||||
# 构建动态查询
|
|
||||||
conditions = []
|
conditions = []
|
||||||
params = {'skip': skip, 'limit': limit}
|
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
conditions.append("user_id = %(user_id)s")
|
conditions.append(models.AuditLog.user_id == user_id)
|
||||||
params['user_id'] = user_id
|
if project_id is not None:
|
||||||
|
conditions.append(models.AuditLog.project_id == project_id)
|
||||||
if username:
|
|
||||||
conditions.append("username = %(username)s")
|
|
||||||
params['username'] = username
|
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
conditions.append("action = %(action)s")
|
conditions.append(models.AuditLog.action == action)
|
||||||
params['action'] = action
|
|
||||||
|
|
||||||
if resource_type:
|
if resource_type:
|
||||||
conditions.append("resource_type = %(resource_type)s")
|
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||||
params['resource_type'] = resource_type
|
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
conditions.append("timestamp >= %(start_time)s")
|
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||||
params['start_time'] = start_time
|
|
||||||
|
|
||||||
if end_time:
|
if end_time:
|
||||||
conditions.append("timestamp <= %(end_time)s")
|
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||||
params['end_time'] = end_time
|
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
stmt = (
|
||||||
|
select(models.AuditLog)
|
||||||
query = f"""
|
.where(*conditions)
|
||||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
.order_by(models.AuditLog.timestamp.desc())
|
||||||
ip_address, user_agent, request_method, request_path,
|
.offset(skip)
|
||||||
request_data, response_status, error_message, timestamp
|
.limit(limit)
|
||||||
FROM audit_logs
|
)
|
||||||
{where_clause}
|
result = await self.session.execute(stmt)
|
||||||
ORDER BY timestamp DESC
|
return [
|
||||||
LIMIT %(limit)s OFFSET %(skip)s
|
AuditLogResponse.model_validate(log)
|
||||||
"""
|
for log in result.scalars().all()
|
||||||
|
]
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, params)
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
return [AuditLogResponse(**row) for row in rows]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error querying audit logs: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_log_count(
|
async def get_log_count(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
action: Optional[str] = None,
|
action: Optional[str] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
start_time: Optional[datetime] = None,
|
start_time: Optional[datetime] = None,
|
||||||
end_time: Optional[datetime] = None
|
end_time: Optional[datetime] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
|
||||||
获取审计日志数量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
参数同 get_logs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
日志总数
|
|
||||||
"""
|
|
||||||
conditions = []
|
conditions = []
|
||||||
params = {}
|
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
conditions.append("user_id = %(user_id)s")
|
conditions.append(models.AuditLog.user_id == user_id)
|
||||||
params['user_id'] = user_id
|
if project_id is not None:
|
||||||
|
conditions.append(models.AuditLog.project_id == project_id)
|
||||||
if username:
|
|
||||||
conditions.append("username = %(username)s")
|
|
||||||
params['username'] = username
|
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
conditions.append("action = %(action)s")
|
conditions.append(models.AuditLog.action == action)
|
||||||
params['action'] = action
|
|
||||||
|
|
||||||
if resource_type:
|
if resource_type:
|
||||||
conditions.append("resource_type = %(resource_type)s")
|
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||||
params['resource_type'] = resource_type
|
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
conditions.append("timestamp >= %(start_time)s")
|
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||||
params['start_time'] = start_time
|
|
||||||
|
|
||||||
if end_time:
|
if end_time:
|
||||||
conditions.append("timestamp <= %(end_time)s")
|
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||||
params['end_time'] = end_time
|
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
query = f"""
|
return int(result.scalar() or 0)
|
||||||
SELECT COUNT(*) as count
|
|
||||||
FROM audit_logs
|
|
||||||
{where_clause}
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, params)
|
|
||||||
result = await cur.fetchone()
|
|
||||||
return result['count'] if result else 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error counting audit logs: {e}")
|
|
||||||
return 0
|
|
||||||
|
|||||||
164
app/infra/repositories/metadata_repository.py
Normal file
164
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.encryption import get_encryptor
|
||||||
|
from app.infra.db.metadata import models
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProjectDbRouting:
|
||||||
|
project_id: UUID
|
||||||
|
db_role: str
|
||||||
|
db_type: str
|
||||||
|
dsn: str
|
||||||
|
pool_min_size: int
|
||||||
|
pool_max_size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProjectGeoServerInfo:
|
||||||
|
project_id: UUID
|
||||||
|
gs_base_url: Optional[str]
|
||||||
|
gs_admin_user: Optional[str]
|
||||||
|
gs_admin_password: Optional[str]
|
||||||
|
gs_datastore_name: str
|
||||||
|
default_extent: Optional[dict]
|
||||||
|
srid: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProjectSummary:
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str]
|
||||||
|
gs_workspace: str
|
||||||
|
status: str
|
||||||
|
project_role: str
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataRepository:
|
||||||
|
"""元数据访问层(system_hub)"""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def get_user_by_keycloak_id(
|
||||||
|
self, keycloak_id: UUID
|
||||||
|
) -> Optional[models.User]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.User).where(models.User.keycloak_id == keycloak_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.Project).where(models.Project.id == project_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_membership_role(
|
||||||
|
self, project_id: UUID, user_id: UUID
|
||||||
|
) -> Optional[str]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.UserProjectMembership.project_role).where(
|
||||||
|
models.UserProjectMembership.project_id == project_id,
|
||||||
|
models.UserProjectMembership.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_project_db_routing(
|
||||||
|
self, project_id: UUID, db_role: str
|
||||||
|
) -> Optional[ProjectDbRouting]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.ProjectDatabase).where(
|
||||||
|
models.ProjectDatabase.project_id == project_id,
|
||||||
|
models.ProjectDatabase.db_role == db_role,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if not record:
|
||||||
|
return None
|
||||||
|
encryptor = get_encryptor()
|
||||||
|
dsn = encryptor.decrypt(record.dsn_encrypted)
|
||||||
|
return ProjectDbRouting(
|
||||||
|
project_id=record.project_id,
|
||||||
|
db_role=record.db_role,
|
||||||
|
db_type=record.db_type,
|
||||||
|
dsn=dsn,
|
||||||
|
pool_min_size=record.pool_min_size,
|
||||||
|
pool_max_size=record.pool_max_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_geoserver_config(
|
||||||
|
self, project_id: UUID
|
||||||
|
) -> Optional[ProjectGeoServerInfo]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.ProjectGeoServerConfig).where(
|
||||||
|
models.ProjectGeoServerConfig.project_id == project_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if not record:
|
||||||
|
return None
|
||||||
|
encryptor = get_encryptor()
|
||||||
|
password = (
|
||||||
|
encryptor.decrypt(record.gs_admin_password_encrypted)
|
||||||
|
if record.gs_admin_password_encrypted
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return ProjectGeoServerInfo(
|
||||||
|
project_id=record.project_id,
|
||||||
|
gs_base_url=record.gs_base_url,
|
||||||
|
gs_admin_user=record.gs_admin_user,
|
||||||
|
gs_admin_password=password,
|
||||||
|
gs_datastore_name=record.gs_datastore_name,
|
||||||
|
default_extent=record.default_extent,
|
||||||
|
srid=record.srid,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
|
||||||
|
stmt = (
|
||||||
|
select(models.Project, models.UserProjectMembership.project_role)
|
||||||
|
.join(
|
||||||
|
models.UserProjectMembership,
|
||||||
|
models.UserProjectMembership.project_id == models.Project.id,
|
||||||
|
)
|
||||||
|
.where(models.UserProjectMembership.user_id == user_id)
|
||||||
|
.order_by(models.Project.name)
|
||||||
|
)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
|
return [
|
||||||
|
ProjectSummary(
|
||||||
|
project_id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role=role,
|
||||||
|
)
|
||||||
|
for project, role in result.all()
|
||||||
|
]
|
||||||
|
|
||||||
|
async def list_all_projects(self) -> List[ProjectSummary]:
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(models.Project).order_by(models.Project.name)
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
ProjectSummary(
|
||||||
|
project_id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role="owner",
|
||||||
|
)
|
||||||
|
for project in result.scalars().all()
|
||||||
|
]
|
||||||
@@ -9,6 +9,8 @@ import app.services.project_info as project_info
|
|||||||
from app.api.v1.router import api_router
|
from app.api.v1.router import api_router
|
||||||
from app.infra.db.timescaledb.database import db as tsdb
|
from app.infra.db.timescaledb.database import db as tsdb
|
||||||
from app.infra.db.postgresql.database import db as pgdb
|
from app.infra.db.postgresql.database import db as pgdb
|
||||||
|
from app.infra.db.dynamic_manager import project_connection_manager
|
||||||
|
from app.infra.db.metadata.database import close_metadata_engine
|
||||||
from app.services.tjnetwork import open_project
|
from app.services.tjnetwork import open_project
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
|
|||||||
# 清理资源
|
# 清理资源
|
||||||
await tsdb.close()
|
await tsdb.close()
|
||||||
await pgdb.close()
|
await pgdb.close()
|
||||||
|
await project_connection_manager.close_all()
|
||||||
|
await close_metadata_engine()
|
||||||
logger.info("Database connections closed")
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user