重构数据库连接管理,添加元数据支持

This commit is contained in:
2026-02-11 18:57:47 +08:00
parent ff2011ae24
commit 780a48d927
21 changed files with 1195 additions and 305 deletions

View File

@@ -31,6 +31,25 @@ TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
# ============================================
# 元数据数据库配置 (Metadata DB)
# ============================================
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="localhost"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="password"
# ============================================
# 项目连接缓存与连接池配置
# ============================================
PROJECT_PG_CACHE_SIZE=50
PROJECT_TS_CACHE_SIZE=50
PROJECT_PG_POOL_SIZE=5
PROJECT_PG_MAX_OVERFLOW=10
PROJECT_TS_POOL_MIN_SIZE=1
PROJECT_TS_POOL_MAX_SIZE=10
# ============================================
# InfluxDB 配置 (时序数据)
# ============================================
@@ -46,6 +65,15 @@ TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
# REFRESH_TOKEN_EXPIRE_DAYS=7
# ALGORITHM=HS256
# ============================================
# Keycloak JWT (可选)
# ============================================
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
# KEYCLOAK_ALGORITHM=RS256
# 临时禁用鉴权(调试用)
# AUTH_DISABLED=false
# ============================================
# 其他配置
# ============================================

View File

@@ -4,33 +4,38 @@
仅管理员可访问
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from fastapi import APIRouter, Depends, Query, Request
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
from app.domain.schemas.user import UserInDB
from fastapi import APIRouter, Depends, Query
from app.domain.schemas.audit import AuditLogResponse
from app.infra.repositories.audit_repository import AuditRepository
from app.auth.dependencies import get_user_repository, get_db
from app.auth.permissions import get_current_admin
from app.infra.db.postgresql.database import Database
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from app.infra.db.metadata.database import get_metadata_session
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
async def get_audit_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> AuditRepository:
"""获取审计日志仓储"""
return AuditRepository(db)
return AuditRepository(session)
@router.get("/logs", response_model=List[AuditLogResponse])
async def get_audit_logs(
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
username: Optional[str] = Query(None, description="用户名过滤"),
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="项目ID过滤"),
action: Optional[str] = Query(None, description="按操作类型过滤"),
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
start_time: Optional[datetime] = Query(None, description="开始时间"),
end_time: Optional[datetime] = Query(None, description="结束时间"),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询审计日志(仅管理员)
@@ -39,7 +44,7 @@ async def get_audit_logs(
"""
logs = await audit_repo.get_logs(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -51,21 +56,21 @@ async def get_audit_logs(
@router.get("/logs/count")
async def get_audit_logs_count(
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
username: Optional[str] = Query(None, description="用户名过滤"),
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="项目ID过滤"),
action: Optional[str] = Query(None, description="按操作类型过滤"),
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
start_time: Optional[datetime] = Query(None, description="开始时间"),
end_time: Optional[datetime] = Query(None, description="结束时间"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> dict:
"""
获取审计日志总数(仅管理员)
"""
count = await audit_repo.get_log_count(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
end_time: Optional[datetime] = Query(None, description="结束时间"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_user),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询当前用户的审计日志

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.project_dependencies import (
ProjectContext,
get_project_context,
get_project_pg_session,
get_project_timescale_connection,
get_metadata_repository,
)
from app.auth.metadata_dependencies import get_current_metadata_user
from app.core.config import settings
from app.domain.schemas.metadata import (
GeoServerConfigResponse,
ProjectMetaResponse,
ProjectSummaryResponse,
)
from app.infra.repositories.metadata_repository import MetadataRepository
router = APIRouter()
@router.get("/meta/project", response_model=ProjectMetaResponse)
async def get_project_metadata(
    ctx: ProjectContext = Depends(get_project_context),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
    """Return full metadata for the project named by the X-Project-Id header.

    Responds 404 when the project row is missing; the ``geoserver`` section
    is null when no GeoServer config row exists for the project.
    """
    project = await metadata_repo.get_project_by_id(ctx.project_id)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
        )

    gs_config = await metadata_repo.get_geoserver_config(ctx.project_id)
    gs_payload = None
    if gs_config:
        gs_payload = GeoServerConfigResponse(
            gs_base_url=gs_config.gs_base_url,
            gs_admin_user=gs_config.gs_admin_user,
            gs_datastore_name=gs_config.gs_datastore_name,
            default_extent=gs_config.default_extent,
            srid=gs_config.srid,
        )

    return ProjectMetaResponse(
        project_id=project.id,
        name=project.name,
        code=project.code,
        description=project.description,
        gs_workspace=project.gs_workspace,
        status=project.status,
        project_role=ctx.project_role,
        geoserver=gs_payload,
    )
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
async def list_user_projects(
    current_user=Depends(get_current_metadata_user),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
    """List the projects the caller belongs to (all projects when auth is bypassed)."""
    rows = (
        await metadata_repo.list_all_projects()
        if settings.AUTH_DISABLED
        else await metadata_repo.list_projects_for_user(current_user.id)
    )
    return [
        ProjectSummaryResponse(
            project_id=row.project_id,
            name=row.name,
            code=row.code,
            description=row.description,
            gs_workspace=row.gs_workspace,
            status=row.status,
            project_role=row.project_role,
        )
        for row in rows
    ]
@router.get("/meta/db/health")
async def project_db_health(
    pg_session: AsyncSession = Depends(get_project_pg_session),
    ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
):
    """Probe both project databases with ``SELECT 1``; failures bubble up as 5xx."""
    # PostgreSQL first (ORM session), then TimescaleDB (raw psycopg connection).
    await pg_session.execute(text("SELECT 1"))
    async with ts_conn.cursor() as cursor:
        await cursor.execute("SELECT 1")
    return {"postgres": "ok", "timescale": "ok"}

View File

@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
cache,
user_management, # 新增:用户管理
audit, # 新增:审计日志
meta,
)
from app.api.v1.endpoints.network import (
general,
@@ -46,6 +47,7 @@ api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
api_router.include_router(meta.router, tags=["Metadata"])
api_router.include_router(project.router, tags=["Project"])
# Network Elements (Node/Link Types)

View File

@@ -0,0 +1,56 @@
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.core.config import settings
oauth2_optional = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
)


async def get_current_keycloak_sub(
    token: str | None = Depends(oauth2_optional),
) -> UUID:
    """Validate the bearer token and return its ``sub`` claim as a UUID.

    With AUTH_DISABLED the nil UUID is returned without inspecting the token.
    Verification uses the Keycloak public key when configured, otherwise the
    local SECRET_KEY/ALGORITHM pair. Raises 401 on a missing, undecodable,
    or subject-less token.
    """
    if settings.AUTH_DISABLED:
        return UUID(int=0)

    def unauthorized(reason: str) -> HTTPException:
        # All failures share the same 401 shape; only the detail differs.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not token:
        raise unauthorized("Not authenticated")

    if settings.KEYCLOAK_PUBLIC_KEY:
        # .env files store the PEM with literal "\n" sequences; restore newlines.
        key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
        algorithms = [settings.KEYCLOAK_ALGORITHM]
    else:
        key = settings.SECRET_KEY
        algorithms = [settings.ALGORITHM]

    try:
        claims = jwt.decode(token, key, algorithms=algorithms)
    except JWTError as exc:
        raise unauthorized("Invalid token") from exc

    subject = claims.get("sub")
    if not subject:
        raise unauthorized("Missing subject claim")
    try:
        return UUID(subject)
    except ValueError as exc:
        raise unauthorized("Invalid subject claim") from exc

View File

@@ -0,0 +1,50 @@
from dataclasses import dataclass
from uuid import UUID
from fastapi import Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
async def get_metadata_repository(
    session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
    # FastAPI dependency: wrap a request-scoped metadata-DB session in the
    # repository facade.
    return MetadataRepository(session)
async def get_current_metadata_user(
    keycloak_sub: UUID = Depends(get_current_keycloak_sub),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
    """Resolve the authenticated Keycloak subject to a local metadata user.

    Returns a synthetic admin stand-in when AUTH_DISABLED is set; otherwise
    raises 403 when the subject has no local user row or the row is inactive.
    """
    if settings.AUTH_DISABLED:
        return _AuthBypassUser()
    user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )
    return user
async def get_current_metadata_admin(
    user=Depends(get_current_metadata_user),
):
    """Gate a dependency chain on admin privileges (no-op when AUTH_DISABLED)."""
    # Short-circuit: with auth bypassed the user attrs are never inspected.
    if settings.AUTH_DISABLED or user.is_superuser or user.role == "admin":
        return user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
    )
@dataclass(frozen=True)
class _AuthBypassUser:
    """Synthetic superuser returned while AUTH_DISABLED is on (debug only).

    Mirrors the User-model attributes that dependency code reads.
    """

    id: UUID = UUID(int=0)  # nil UUID, matches the get_current_keycloak_sub bypass
    role: str = "admin"
    is_superuser: bool = True
    is_active: bool = True

View File

@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Depends, Header, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
# Keys matched against project_databases.db_role / db_type in routing lookups.
DB_ROLE_BIZ_PG = "biz_pg"
DB_ROLE_IOT_TS = "iot_ts"
DB_TYPE_POSTGRES = "postgresql"
DB_TYPE_TIMESCALE = "timescaledb"


@dataclass(frozen=True)
class ProjectContext:
    """Resolved per-request project scope: who is acting on which project."""

    project_id: UUID
    user_id: UUID  # nil UUID when AUTH_DISABLED
    project_role: str
async def get_metadata_repository(
    session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
    # FastAPI dependency: repository facade over a metadata-DB session.
    return MetadataRepository(session)
async def get_project_context(
    x_project_id: str = Header(..., alias="X-Project-Id"),
    keycloak_sub: UUID = Depends(get_current_keycloak_sub),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> ProjectContext:
    """Resolve and authorize the project scope for the current request.

    Validation order: header parses as UUID (400) -> project exists (404) ->
    project active (403) -> user registered and active (403) -> user is a
    member of the project (403). With AUTH_DISABLED a synthetic "owner"
    context is returned right after the project checks.
    """
    try:
        project_uuid = UUID(x_project_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
        ) from exc
    project = await metadata_repo.get_project_by_id(project_uuid)
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
        )
    if project.status != "active":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
        )
    if settings.AUTH_DISABLED:
        # Debug bypass: act as a nil-UUID owner of any active project.
        return ProjectContext(
            project_id=project.id,
            user_id=UUID(int=0),
            project_role="owner",
        )
    user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
        )
    membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
    if not membership_role:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
        )
    return ProjectContext(
        project_id=project.id,
        user_id=user.id,
        project_role=membership_role,
    )
async def get_project_pg_session(
    ctx: ProjectContext = Depends(get_project_context),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ORM session bound to the current project's business PostgreSQL.

    Looks up the 'biz_pg' routing row, then fetches (creating lazily) the
    cached engine/sessionmaker from the shared connection manager. Raises 503
    when routing is missing or the row has the wrong db_type.
    """
    routing = await metadata_repo.get_project_db_routing(
        ctx.project_id, DB_ROLE_BIZ_PG
    )
    if not routing:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project PostgreSQL not configured",
        )
    if routing.db_type != DB_TYPE_POSTGRES:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project PostgreSQL type mismatch",
        )
    # NOTE(review): both bounds fall back to PROJECT_PG_POOL_SIZE, which makes
    # the effective max_overflow 0 by default. PROJECT_PG_MAX_OVERFLOW exists
    # in settings and looks intended as the max-size fallback — confirm.
    pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
    pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
    sessionmaker = await project_connection_manager.get_pg_sessionmaker(
        ctx.project_id,
        DB_ROLE_BIZ_PG,
        routing.dsn,
        pool_min_size,
        pool_max_size,
    )
    async with sessionmaker() as session:
        yield session
async def get_project_pg_connection(
    ctx: ProjectContext = Depends(get_project_context),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
    """Yield a raw psycopg connection to the project's business PostgreSQL.

    Same routing and validation as the ORM-session dependency, but hands out
    a plain connection from the raw pool. Raises 503 when routing is missing
    or of the wrong db_type.
    """
    routing = await metadata_repo.get_project_db_routing(
        ctx.project_id, DB_ROLE_BIZ_PG
    )
    if not routing:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project PostgreSQL not configured",
        )
    if routing.db_type != DB_TYPE_POSTGRES:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project PostgreSQL type mismatch",
        )
    # Fall back to shared defaults when the routing row leaves sizes unset.
    pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
    pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
    pool = await project_connection_manager.get_pg_pool(
        ctx.project_id,
        DB_ROLE_BIZ_PG,
        routing.dsn,
        pool_min_size,
        pool_max_size,
    )
    async with pool.connection() as conn:
        yield conn
async def get_project_timescale_connection(
    ctx: ProjectContext = Depends(get_project_context),
    metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
    """Yield a raw psycopg connection to the project's TimescaleDB.

    Looks up the 'iot_ts' routing row and draws a connection from the cached
    pool. Raises 503 when routing is missing or of the wrong db_type.
    """
    routing = await metadata_repo.get_project_db_routing(
        ctx.project_id, DB_ROLE_IOT_TS
    )
    if not routing:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project TimescaleDB not configured",
        )
    if routing.db_type != DB_TYPE_TIMESCALE:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project TimescaleDB type mismatch",
        )
    # Fall back to shared defaults when the routing row leaves sizes unset.
    pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
    pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
    pool = await project_connection_manager.get_timescale_pool(
        ctx.project_id,
        DB_ROLE_IOT_TS,
        routing.dsn,
        pool_min_size,
        pool_max_size,
    )
    async with pool.connection() as conn:
        yield conn

View File

@@ -7,6 +7,7 @@
from typing import Optional
from datetime import datetime
import logging
from uuid import UUID
logger = logging.getLogger(__name__)
@@ -38,18 +39,16 @@ class AuditAction:
async def log_audit_event(
action: str,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None,
db=None, # 新增:可选的数据库实例
session=None,
):
"""
记录审计日志
@@ -57,67 +56,60 @@ async def log_audit_event(
Args:
action: 操作类型
user_id: 用户ID
username: 用户名
project_id: 项目ID
resource_type: 资源类型
resource_id: 资源ID
ip_address: IP地址
user_agent: User-Agent
request_method: 请求方法
request_path: 请求路径
request_data: 请求数据(敏感字段需脱敏)
response_status: 响应状态码
error_message: 错误消息
db: 数据库实例(可选,如果不提供则尝试获取)
session: 元数据库会话(可选)
"""
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.audit_repository import AuditRepository
try:
# 脱敏敏感数据
if request_data:
request_data = sanitize_sensitive_data(request_data)
# 如果没有提供数据库实例,尝试从全局获取
if db is None:
try:
from app.infra.db.postgresql.database import db as default_db
# 仅当连接池已初始化时使用
if default_db.pool:
db = default_db
except ImportError:
pass
# 如果仍然没有数据库实例
if db is None:
# 在某些上下文中可能无法获取,此时静默失败
logger.warning("No database instance provided for audit logging")
return
audit_repo = AuditRepository(db)
if request_data:
request_data = sanitize_sensitive_data(request_data)
if session is None:
async with SessionLocal() as session:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
)
else:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
error_message=error_message,
)
logger.info(
f"Audit log created: action={action}, user={username or user_id}, "
f"resource={resource_type}:{resource_id}"
)
except Exception as e:
# 审计日志失败不应影响业务流程
logger.error(f"Failed to create audit log: {e}", exc_info=True)
logger.info(
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
action,
user_id,
project_id,
resource_type,
resource_id,
)
def sanitize_sensitive_data(data: dict) -> dict:

View File

@@ -35,10 +35,41 @@ class Settings(BaseSettings):
INFLUXDB_ORG: str = "org"
INFLUXDB_BUCKET: str = "bucket"
# Metadata Database Config (PostgreSQL) — the central "system_hub" DB that
# stores users, projects, and per-project DB routing rows.
METADATA_DB_NAME: str = "system_hub"
METADATA_DB_HOST: str = "localhost"
METADATA_DB_PORT: str = "5432"
METADATA_DB_USER: str = "postgres"
METADATA_DB_PASSWORD: str = "password"
METADATA_DB_POOL_SIZE: int = 5
METADATA_DB_MAX_OVERFLOW: int = 10
# LRU cache capacities and pool bounds for per-project connections
# (consumed by app.infra.db.dynamic_manager and the project dependencies).
PROJECT_PG_CACHE_SIZE: int = 50
PROJECT_TS_CACHE_SIZE: int = 50
PROJECT_PG_POOL_SIZE: int = 5
PROJECT_PG_MAX_OVERFLOW: int = 10
PROJECT_TS_POOL_MIN_SIZE: int = 1
PROJECT_TS_POOL_MAX_SIZE: int = 10
# Keycloak JWT (optional override); the PEM may contain literal "\n" sequences.
KEYCLOAK_PUBLIC_KEY: str = ""
KEYCLOAK_ALGORITHM: str = "RS256"
# Auth bypass (temporary, debug only) — must never be enabled in production.
AUTH_DISABLED: bool = False
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
    """SQLAlchemy URL for the primary business database.

    Credentials are URL-quoted so passwords containing reserved characters
    (e.g. '@', '/', ':') still yield a parseable DSN.
    """
    from urllib.parse import quote  # local: file's import block is outside this hunk

    user = quote(self.DB_USER, safe="")
    password = quote(self.DB_PASSWORD, safe="")
    return f"postgresql://{user}:{password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
@property
def METADATA_DATABASE_URI(self) -> str:
    """Async SQLAlchemy (psycopg driver) URL for the metadata database.

    Credentials are URL-quoted: the documented sample password contains '@',
    which would otherwise split the authority part of the DSN.
    """
    from urllib.parse import quote  # local: file's import block is outside this hunk

    user = quote(self.METADATA_DB_USER, safe="")
    password = quote(self.METADATA_DB_PASSWORD, safe="")
    return (
        f"postgresql+psycopg://{user}:{password}"
        f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
    )
class Config:
    # Pydantic settings behavior: read overrides from .env, ignore unknown keys.
    env_file = ".env"
    extra = "ignore"

View File

@@ -1,45 +1,42 @@
from datetime import datetime
from typing import Optional, Any
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class AuditLogCreate(BaseModel):
"""创建审计日志"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: str
resource_type: Optional[str] = None
resource_id: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
request_method: Optional[str] = None
request_path: Optional[str] = None
request_data: Optional[dict] = None
response_status: Optional[int] = None
error_message: Optional[str] = None
class AuditLogResponse(BaseModel):
"""审计日志响应"""
id: int
user_id: Optional[int]
username: Optional[str]
id: UUID
user_id: Optional[UUID]
project_id: Optional[UUID]
action: str
resource_type: Optional[str]
resource_id: Optional[str]
ip_address: Optional[str]
user_agent: Optional[str]
request_method: Optional[str]
request_path: Optional[str]
request_data: Optional[dict]
response_status: Optional[int]
error_message: Optional[str]
timestamp: datetime
model_config = ConfigDict(from_attributes=True)
class AuditLogQuery(BaseModel):
"""审计日志查询参数"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: Optional[str] = None
resource_type: Optional[str] = None
start_time: Optional[datetime] = None

View File

@@ -0,0 +1,33 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class GeoServerConfigResponse(BaseModel):
    """GeoServer connection details attached to a project's metadata."""

    gs_base_url: Optional[str]
    gs_admin_user: Optional[str]
    gs_datastore_name: str
    # Extent payload as stored in the metadata DB (free-form dict); exact
    # shape defined by the producing row — confirm before relying on keys.
    default_extent: Optional[dict]
    srid: int
class ProjectMetaResponse(BaseModel):
    """Payload of GET /meta/project: one project plus the caller's role."""

    project_id: UUID
    name: str
    code: str
    description: Optional[str]
    gs_workspace: str
    status: str
    # Membership role of the requesting user ("owner" under auth bypass).
    project_role: str
    geoserver: Optional[GeoServerConfigResponse]
class ProjectSummaryResponse(BaseModel):
    """One row of GET /meta/projects: a project the caller can access."""

    project_id: UUID
    name: str
    code: str
    description: Optional[str]
    gs_workspace: str
    status: str
    project_role: str

View File

@@ -6,12 +6,17 @@
import time
import json
from uuid import UUID
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.infra.db.postgresql.database import db as default_db
from app.core.audit import log_audit_event, AuditAction
import logging
from jose import JWTError, jwt
from app.core.config import settings
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
# 4. 提取审计所需信息
user_id = None
username = None
# 尝试从请求状态获取当前用户
if hasattr(request.state, "user"):
user = request.state.user
user_id = getattr(user, "id", None)
username = getattr(user, "username", None)
user_id = await self._resolve_user_id(request)
project_id = self._resolve_project_id(request)
# 获取客户端信息
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
# 确定操作类型
action = self._determine_action(request)
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
await log_audit_event(
action=action,
user_id=user_id,
username=username,
project_id=project_id,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request.method,
request_path=str(request.url.path),
request_data=request_data,
response_status=response.status_code,
error_message=(
None
if response.status_code < 400
else f"HTTP {response.status_code}"
),
db=default_db,
)
except Exception as e:
# 审计失败不应影响响应
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _resolve_project_id(self, request: Request) -> UUID | None:
    """Parse the optional X-Project-Id header; None when absent or not a UUID."""
    raw = request.headers.get("X-Project-Id")
    if not raw:
        return None
    try:
        return UUID(raw)
    except ValueError:
        return None
async def _resolve_user_id(self, request: Request) -> UUID | None:
    """Best-effort: map the request's bearer token to a local metadata user id.

    Decodes the JWT with the same key/algorithm selection as the auth
    dependencies, then looks the subject up in the metadata DB. Returns None
    (never raises) on a missing/invalid token or unknown/inactive user, so
    audit logging cannot fail the request.
    """
    auth_header = request.headers.get("authorization")
    if not auth_header or not auth_header.lower().startswith("bearer "):
        return None
    token = auth_header.split(" ", 1)[1].strip()
    if not token:
        return None
    try:
        # .env files store the PEM with literal "\n" sequences; restore newlines.
        key = (
            settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
            if settings.KEYCLOAK_PUBLIC_KEY
            else settings.SECRET_KEY
        )
        algorithms = (
            [settings.KEYCLOAK_ALGORITHM]
            if settings.KEYCLOAK_PUBLIC_KEY
            else [settings.ALGORITHM]
        )
        # NOTE(review): jose rejects tokens carrying an "aud" claim unless an
        # audience (or option override) is supplied — confirm Keycloak tokens
        # decode here, otherwise user_id silently stays None.
        payload = jwt.decode(token, key, algorithms=algorithms)
        sub = payload.get("sub")
        if not sub:
            return None
        keycloak_id = UUID(sub)
    except (JWTError, ValueError):
        return None
    async with SessionLocal() as session:
        repo = MetadataRepository(session)
        user = await repo.get_user_by_keycloak_id(keycloak_id)
        if user and user.is_active:
            return user.id
    return None
def _determine_action(self, request: Request) -> str:
"""根据请求路径和方法确定操作类型"""
path = request.url.path.lower()

View File

@@ -0,0 +1,208 @@
import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict
from uuid import UUID
from psycopg_pool import AsyncConnectionPool
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.core.config import settings
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class PgEngineEntry:
    """Cached SQLAlchemy engine plus its bound sessionmaker for one project DB."""

    engine: AsyncEngine
    sessionmaker: async_sessionmaker[AsyncSession]


@dataclass(frozen=True)
class CacheKey:
    """LRU cache key: one entry per (project, database role)."""

    project_id: UUID
    db_role: str
class ProjectConnectionManager:
    """Process-wide LRU caches of per-project database connectivity.

    Three independent caches, each keyed by (project_id, db_role) and guarded
    by its own asyncio.Lock:

    * ``_pg_cache``     — SQLAlchemy async engine + sessionmaker (ORM access),
    * ``_pg_raw_cache`` — raw psycopg pools for project PostgreSQL,
    * ``_ts_cache``     — raw psycopg pools for project TimescaleDB.

    When a cache exceeds its configured capacity, the least-recently-used
    entry is evicted and its engine disposed / pool closed.
    """

    def __init__(self) -> None:
        # OrderedDict (not a plain dict) is required: LRU eviction relies on
        # move_to_end() and popitem(last=False).
        self._pg_cache: OrderedDict[CacheKey, PgEngineEntry] = OrderedDict()
        self._ts_cache: OrderedDict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_raw_cache: OrderedDict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_lock = asyncio.Lock()
        self._ts_lock = asyncio.Lock()
        self._pg_raw_lock = asyncio.Lock()

    def _normalize_pg_url(self, url: str) -> str:
        """Return *url* with the async psycopg driver forced.

        Also accepts the legacy ``postgres://`` scheme (Heroku-style DSNs),
        which SQLAlchemy 1.4+ rejects at engine creation time.
        """
        parsed = make_url(url)
        if parsed.drivername in ("postgresql", "postgres"):
            parsed = parsed.set(drivername="postgresql+psycopg")
        return str(parsed)

    @staticmethod
    def _clamp_sizes(pool_min_size: int, pool_max_size: int) -> tuple[int, int]:
        """Guarantee 1 <= min <= max for pool bounds from routing rows."""
        pool_min_size = max(1, pool_min_size)
        return pool_min_size, max(pool_min_size, pool_max_size)

    async def get_pg_sessionmaker(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> async_sessionmaker[AsyncSession]:
        """Return (creating lazily) the cached sessionmaker for one project DB.

        ``pool_max_size`` maps onto SQLAlchemy's ``max_overflow`` on top of a
        base ``pool_size`` of ``pool_min_size``.
        """
        async with self._pg_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            entry = self._pg_cache.get(key)
            if entry:
                self._pg_cache.move_to_end(key)
                return entry.sessionmaker
            pool_min_size, pool_max_size = self._clamp_sizes(
                pool_min_size, pool_max_size
            )
            engine = create_async_engine(
                self._normalize_pg_url(connection_url),
                pool_size=pool_min_size,
                max_overflow=pool_max_size - pool_min_size,
                pool_pre_ping=True,  # transparently replace stale connections
            )
            sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
            self._pg_cache[key] = PgEngineEntry(
                engine=engine,
                sessionmaker=sessionmaker,
            )
            await self._evict_pg_if_needed()
            logger.info(
                "Created PostgreSQL engine for project %s (%s)", project_id, db_role
            )
            return sessionmaker

    @staticmethod
    async def _open_pool(
        connection_url: str, pool_min_size: int, pool_max_size: int
    ) -> AsyncConnectionPool:
        """Create and open a psycopg pool with clamped size bounds."""
        pool_min_size, pool_max_size = ProjectConnectionManager._clamp_sizes(
            pool_min_size, pool_max_size
        )
        pool = AsyncConnectionPool(
            conninfo=connection_url,
            min_size=pool_min_size,
            max_size=pool_max_size,
            open=False,  # open explicitly so connection failures surface here
        )
        await pool.open()
        return pool

    async def get_timescale_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return (creating lazily) the cached TimescaleDB psycopg pool."""
        async with self._ts_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            pool = self._ts_cache.get(key)
            if pool:
                self._ts_cache.move_to_end(key)
                return pool
            pool = await self._open_pool(connection_url, pool_min_size, pool_max_size)
            self._ts_cache[key] = pool
            await self._evict_ts_if_needed()
            logger.info(
                "Created TimescaleDB pool for project %s (%s)", project_id, db_role
            )
            return pool

    async def get_pg_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return (creating lazily) the cached raw PostgreSQL psycopg pool."""
        async with self._pg_raw_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            pool = self._pg_raw_cache.get(key)
            if pool:
                self._pg_raw_cache.move_to_end(key)
                return pool
            pool = await self._open_pool(connection_url, pool_min_size, pool_max_size)
            self._pg_raw_cache[key] = pool
            await self._evict_pg_raw_if_needed()
            logger.info(
                "Created PostgreSQL pool for project %s (%s)", project_id, db_role
            )
            return pool

    async def _evict_pg_if_needed(self) -> None:
        """Dispose least-recently-used engines beyond PROJECT_PG_CACHE_SIZE."""
        while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
            key, entry = self._pg_cache.popitem(last=False)
            await entry.engine.dispose()
            logger.info(
                "Evicted PostgreSQL engine for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def _evict_ts_if_needed(self) -> None:
        """Close least-recently-used TS pools beyond PROJECT_TS_CACHE_SIZE."""
        while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
            key, pool = self._ts_cache.popitem(last=False)
            await pool.close()
            logger.info(
                "Evicted TimescaleDB pool for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def _evict_pg_raw_if_needed(self) -> None:
        """Close least-recently-used raw PG pools beyond PROJECT_PG_CACHE_SIZE."""
        while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
            key, pool = self._pg_raw_cache.popitem(last=False)
            await pool.close()
            logger.info(
                "Evicted PostgreSQL pool for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def close_all(self) -> None:
        """Dispose every cached engine and close every pool (app shutdown)."""
        async with self._pg_lock:
            for key, entry in list(self._pg_cache.items()):
                await entry.engine.dispose()
                logger.info(
                    "Closed PostgreSQL engine for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._pg_cache.clear()
        async with self._ts_lock:
            for key, pool in list(self._ts_cache.items()):
                await pool.close()
                logger.info(
                    "Closed TimescaleDB pool for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._ts_cache.clear()
        async with self._pg_raw_lock:
            for key, pool in list(self._pg_raw_cache.items()):
                await pool.close()
                logger.info(
                    "Closed PostgreSQL pool for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._pg_raw_cache.clear()


# Module-level singleton shared by all request handlers.
project_connection_manager = ProjectConnectionManager()

View File

@@ -0,0 +1,3 @@
from .database import get_metadata_session, close_metadata_engine
# Public API of the metadata DB package: the request-scoped session
# dependency and the engine shutdown hook.
__all__ = ["get_metadata_session", "close_metadata_engine"]

View File

@@ -0,0 +1,27 @@
import logging
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
logger = logging.getLogger(__name__)
# Single shared async engine for the central metadata database (system_hub).
engine = create_async_engine(
    settings.METADATA_DATABASE_URI,
    pool_size=settings.METADATA_DB_POOL_SIZE,
    max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
    pool_pre_ping=True,  # transparently replace stale pooled connections
)

# Session factory; expire_on_commit=False keeps ORM objects usable after commit.
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)


async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding one metadata-DB session per request."""
    async with SessionLocal() as session:
        yield session


async def close_metadata_engine() -> None:
    """Dispose the shared engine; call once at application shutdown."""
    await engine.dispose()
    logger.info("Metadata database engine disposed.")

View File

@@ -0,0 +1,115 @@
from datetime import datetime, timezone
from uuid import UUID

from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base shared by all metadata ORM models."""

    pass
class User(Base):
    """Application user record mirrored from a Keycloak account."""

    __tablename__ = "users"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    # Identity link to the Keycloak account this row mirrors.
    keycloak_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), unique=True, index=True
    )
    username: Mapped[str] = mapped_column(String(50), unique=True)
    email: Mapped[str] = mapped_column(String(100), unique=True)
    # Global role; per-project roles live in user_project_membership.
    role: Mapped[str] = mapped_column(String(20), default="user")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
    # Free-form extra attributes — presumably synced from Keycloak; confirm.
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Fix: datetime.utcnow is deprecated and returns a NAIVE datetime,
    # which is ambiguous when stored into a DateTime(timezone=True) column
    # (interpreted in the server's session timezone). Use aware UTC values.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    # onupdate keeps updated_at current on every ORM UPDATE automatically.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
    last_login_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
class Project(Base):
    """A tenant project; each project owns one GeoServer workspace."""

    __tablename__ = "projects"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    name: Mapped[str] = mapped_column(String(100))
    # Short unique code for programmatic reference.
    code: Mapped[str] = mapped_column(String(50), unique=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Dedicated GeoServer workspace name for this project.
    gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
    status: Mapped[str] = mapped_column(String(20), default="active")
    # Fix: replace deprecated/naive datetime.utcnow with aware UTC values
    # (the columns are DateTime(timezone=True)); add onupdate for updated_at.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
class ProjectDatabase(Base):
    """Per-project database routing record (one row per project + db_role)."""

    __tablename__ = "project_databases"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    # Logical role of this database within the project; lookups filter on
    # (project_id, db_role). Exact role vocabulary not visible here — confirm.
    db_role: Mapped[str] = mapped_column(String(20))
    db_type: Mapped[str] = mapped_column(String(20))
    # Connection DSN stored encrypted at rest; decrypted via the app's
    # encryptor when routing is resolved.
    dsn_encrypted: Mapped[str] = mapped_column(Text)
    # Connection-pool sizing hints consumed when a pool is created.
    pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
    pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
class ProjectGeoServerConfig(Base):
    """Per-project GeoServer connection settings (one row per project)."""

    __tablename__ = "project_geoserver_configs"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    project_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), unique=True, index=True
    )
    gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
    # Admin password stored encrypted at rest; decrypted on read.
    gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
    # Default map extent as JSON — structure not visible here; confirm schema.
    default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    srid: Mapped[int] = mapped_column(Integer, default=4326)
    # Fix: replace deprecated/naive datetime.utcnow with aware UTC values
    # (column is DateTime(timezone=True)); add onupdate so edits are tracked.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
class UserProjectMembership(Base):
    """Associates a user with a project and a per-project role."""

    __tablename__ = "user_project_membership"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    # Role scoped to this project only; the user's global role lives on User.
    project_role: Mapped[str] = mapped_column(String(20), default="viewer")
class AuditLog(Base):
    """Append-only audit trail entry for a user/API action."""

    __tablename__ = "audit_logs"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    # Nullable: some actions may not be attributable to a user/project.
    user_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True), nullable=True, index=True
    )
    project_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True), nullable=True, index=True
    )
    action: Mapped[str] = mapped_column(String(50))
    resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
    resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # 45 chars accommodates full-length IPv6 textual addresses.
    ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
    request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
    request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
    request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Fix: datetime.utcnow is deprecated and naive; the column is
    # DateTime(timezone=True), so default to an aware UTC timestamp.
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )

View File

@@ -1,24 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
from app.auth.project_dependencies import get_project_pg_connection
router = APIRouter()
# 创建支持数据库选择的连接依赖函数
# 动态项目 PostgreSQL 连接依赖
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
@router.get("/scada-info")

View File

@@ -1,42 +1,32 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from app.infra.db.postgresql.database import (
get_database_instance as get_postgres_database_instance,
from app.auth.project_dependencies import (
get_project_pg_connection,
get_project_timescale_connection,
)
router = APIRouter()
# 创建支持数据库选择的连接依赖函数
# 动态项目 TimescaleDB 连接依赖
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_timescale_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# PostgreSQL 数据库连接依赖函数
# 动态项目 PostgreSQL 连接依赖
async def get_postgres_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_postgres_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# --- Realtime Endpoints ---

View File

@@ -1,220 +1,112 @@
from typing import Optional, List
from datetime import datetime
import json
from app.infra.db.postgresql.database import Database
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
import logging
from typing import Optional, List
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.schemas.audit import AuditLogResponse
from app.infra.db.metadata import models
logger = logging.getLogger(__name__)
class AuditRepository:
"""审计日志数据访问层"""
"""审计日志数据访问层system_hub"""
def __init__(self, db: Database):
self.db = db
def __init__(self, session: AsyncSession):
self.session = session
async def create_log(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
action: str = "",
action: str,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None
) -> Optional[AuditLogResponse]:
"""
创建审计日志
Args:
参数说明见 AuditLogCreate
Returns:
创建的审计日志对象
"""
query = """
INSERT INTO audit_logs (
user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message
)
VALUES (
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
%(request_data)s, %(response_status)s, %(error_message)s
)
RETURNING id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, {
'user_id': user_id,
'username': username,
'action': action,
'resource_type': resource_type,
'resource_id': resource_id,
'ip_address': ip_address,
'user_agent': user_agent,
'request_method': request_method,
'request_path': request_path,
'request_data': json.dumps(request_data) if request_data else None,
'response_status': response_status,
'error_message': error_message
})
row = await cur.fetchone()
if row:
return AuditLogResponse(**row)
except Exception as e:
logger.error(f"Error creating audit log: {e}")
raise
return None
) -> AuditLogResponse:
log = models.AuditLog(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
timestamp=datetime.utcnow(),
)
self.session.add(log)
await self.session.commit()
await self.session.refresh(log)
return AuditLogResponse.model_validate(log)
async def get_logs(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
skip: int = 0,
limit: int = 100
limit: int = 100,
) -> List[AuditLogResponse]:
"""
查询审计日志
Args:
user_id: 用户ID过滤
username: 用户名过滤
action: 操作类型过滤
resource_type: 资源类型过滤
start_time: 开始时间
end_time: 结束时间
skip: 跳过记录数
limit: 限制记录数
Returns:
审计日志列表
"""
# 构建动态查询
conditions = []
params = {'skip': skip, 'limit': limit}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
conditions.append(models.AuditLog.timestamp <= end_time)
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
FROM audit_logs
{where_clause}
ORDER BY timestamp DESC
LIMIT %(limit)s OFFSET %(skip)s
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
rows = await cur.fetchall()
return [AuditLogResponse(**row) for row in rows]
except Exception as e:
logger.error(f"Error querying audit logs: {e}")
raise
stmt = (
select(models.AuditLog)
.where(*conditions)
.order_by(models.AuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
)
result = await self.session.execute(stmt)
return [
AuditLogResponse.model_validate(log)
for log in result.scalars().all()
]
async def get_log_count(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
end_time: Optional[datetime] = None,
) -> int:
"""
获取审计日志数量
Args:
参数同 get_logs
Returns:
日志总数
"""
conditions = []
params = {}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
conditions.append(models.AuditLog.timestamp <= end_time)
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT COUNT(*) as count
FROM audit_logs
{where_clause}
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
result = await cur.fetchone()
return result['count'] if result else 0
except Exception as e:
logger.error(f"Error counting audit logs: {e}")
return 0
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
result = await self.session.execute(stmt)
return int(result.scalar() or 0)

View File

@@ -0,0 +1,164 @@
from dataclasses import dataclass
from typing import Optional, List
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.encryption import get_encryptor
from app.infra.db.metadata import models
@dataclass(frozen=True)
class ProjectDbRouting:
    """Resolved (decrypted) connection routing for one project database."""

    project_id: UUID
    # Logical role the DSN serves (matches project_databases.db_role).
    db_role: str
    db_type: str
    # Plaintext DSN — already decrypted; do not log or persist this value.
    dsn: str
    pool_min_size: int
    pool_max_size: int
@dataclass(frozen=True)
class ProjectGeoServerInfo:
    """GeoServer settings for a project, with the admin password decrypted."""

    project_id: UUID
    gs_base_url: Optional[str]
    gs_admin_user: Optional[str]
    # Plaintext admin password (or None); handle as a secret.
    gs_admin_password: Optional[str]
    gs_datastore_name: str
    default_extent: Optional[dict]
    srid: int
@dataclass(frozen=True)
class ProjectSummary:
    """Lightweight project listing entry, including the caller's role."""

    project_id: UUID
    name: str
    code: str
    description: Optional[str]
    gs_workspace: str
    status: str
    # Role of the requesting user within this project.
    project_role: str
class MetadataRepository:
    """Read-side access layer for the metadata database ("system_hub").

    Resolves users, projects, memberships and per-project connection
    routing. Secrets stored encrypted (DSNs, GeoServer passwords) are
    decrypted via app.core.encryption before being returned.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    @staticmethod
    def _to_summary(project: "models.Project", role: str) -> ProjectSummary:
        """Map a Project row plus a role string onto a ProjectSummary DTO."""
        return ProjectSummary(
            project_id=project.id,
            name=project.name,
            code=project.code,
            description=project.description,
            gs_workspace=project.gs_workspace,
            status=project.status,
            project_role=role,
        )

    async def get_user_by_keycloak_id(
        self, keycloak_id: UUID
    ) -> Optional[models.User]:
        """Return the user linked to *keycloak_id*, or None if unknown."""
        result = await self.session.execute(
            select(models.User).where(models.User.keycloak_id == keycloak_id)
        )
        return result.scalar_one_or_none()

    async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
        """Return the project with the given id, or None."""
        result = await self.session.execute(
            select(models.Project).where(models.Project.id == project_id)
        )
        return result.scalar_one_or_none()

    async def get_membership_role(
        self, project_id: UUID, user_id: UUID
    ) -> Optional[str]:
        """Return the user's role within the project, or None if not a member."""
        result = await self.session.execute(
            select(models.UserProjectMembership.project_role).where(
                models.UserProjectMembership.project_id == project_id,
                models.UserProjectMembership.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def get_project_db_routing(
        self, project_id: UUID, db_role: str
    ) -> Optional[ProjectDbRouting]:
        """Return decrypted routing for (project, db_role), or None."""
        result = await self.session.execute(
            select(models.ProjectDatabase).where(
                models.ProjectDatabase.project_id == project_id,
                models.ProjectDatabase.db_role == db_role,
            )
        )
        record = result.scalar_one_or_none()
        if record is None:
            return None
        return ProjectDbRouting(
            project_id=record.project_id,
            db_role=record.db_role,
            db_type=record.db_type,
            # DSN is encrypted at rest; decrypt only on the way out.
            dsn=get_encryptor().decrypt(record.dsn_encrypted),
            pool_min_size=record.pool_min_size,
            pool_max_size=record.pool_max_size,
        )

    async def get_geoserver_config(
        self, project_id: UUID
    ) -> Optional[ProjectGeoServerInfo]:
        """Return the project's GeoServer config (password decrypted), or None."""
        result = await self.session.execute(
            select(models.ProjectGeoServerConfig).where(
                models.ProjectGeoServerConfig.project_id == project_id
            )
        )
        record = result.scalar_one_or_none()
        if record is None:
            return None
        password = (
            get_encryptor().decrypt(record.gs_admin_password_encrypted)
            if record.gs_admin_password_encrypted
            else None
        )
        return ProjectGeoServerInfo(
            project_id=record.project_id,
            gs_base_url=record.gs_base_url,
            gs_admin_user=record.gs_admin_user,
            gs_admin_password=password,
            gs_datastore_name=record.gs_datastore_name,
            default_extent=record.default_extent,
            srid=record.srid,
        )

    async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
        """List projects the user belongs to, each with the user's role,
        ordered by project name."""
        stmt = (
            select(models.Project, models.UserProjectMembership.project_role)
            .join(
                models.UserProjectMembership,
                models.UserProjectMembership.project_id == models.Project.id,
            )
            .where(models.UserProjectMembership.user_id == user_id)
            .order_by(models.Project.name)
        )
        result = await self.session.execute(stmt)
        return [self._to_summary(project, role) for project, role in result.all()]

    async def list_all_projects(self) -> List[ProjectSummary]:
        """List every project ordered by name, reporting "owner" as the
        role for each entry — presumably reserved for superusers; confirm
        at call sites."""
        result = await self.session.execute(
            select(models.Project).order_by(models.Project.name)
        )
        return [
            self._to_summary(project, "owner")
            for project in result.scalars().all()
        ]

View File

@@ -9,6 +9,8 @@ import app.services.project_info as project_info
from app.api.v1.router import api_router
from app.infra.db.timescaledb.database import db as tsdb
from app.infra.db.postgresql.database import db as pgdb
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import close_metadata_engine
from app.services.tjnetwork import open_project
from app.core.config import settings
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
# 清理资源
await tsdb.close()
await pgdb.close()
await project_connection_manager.close_all()
await close_metadata_engine()
logger.info("Database connections closed")