Compare commits

...

15 Commits

32 changed files with 1590 additions and 353 deletions

View File

@@ -12,6 +12,7 @@ SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
# 数据加密密钥 - 用于敏感数据加密
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
ENCRYPTION_KEY=
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
# ============================================
# 数据库配置 (PostgreSQL)
@@ -31,6 +32,25 @@ TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
# ============================================
# 元数据数据库配置 (Metadata DB)
# ============================================
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="localhost"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="password"
# ============================================
# 项目连接缓存与连接池配置
# ============================================
PROJECT_PG_CACHE_SIZE=50
PROJECT_TS_CACHE_SIZE=50
PROJECT_PG_POOL_SIZE=5
PROJECT_PG_MAX_OVERFLOW=10
PROJECT_TS_POOL_MIN_SIZE=1
PROJECT_TS_POOL_MAX_SIZE=10
# ============================================
# InfluxDB 配置 (时序数据)
# ============================================
@@ -46,6 +66,12 @@ TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
# REFRESH_TOKEN_EXPIRE_DAYS=7
# ALGORITHM=HS256
# ============================================
# Keycloak JWT (可选)
# ============================================
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
# KEYCLOAK_ALGORITHM=RS256
# ============================================
# 其他配置
# ============================================

View File

@@ -1,13 +1,23 @@
NETWORK_NAME="szh"
NETWORK_NAME="tjwater"
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
KEYCLOAK_ALGORITHM="RS256"
KEYCLOAK_AUDIENCE="account"
DB_NAME="szh"
DB_NAME="tjwater"
DB_HOST="192.168.1.114"
DB_PORT="5432"
DB_USER="tjwater"
DB_PASSWORD="Tjwater@123456"
TIMESCALEDB_DB_NAME="szh"
TIMESCALEDB_DB_NAME="tjwater"
TIMESCALEDB_DB_HOST="192.168.1.114"
TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="192.168.1.114"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="Tjwater@123456"
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="

View File

@@ -87,7 +87,6 @@ Default admin accounts:
- `app/core/config.py`: Settings management using `pydantic-settings`
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
- `configs/project_info.yml`: Default project configuration (auto-loaded on startup)
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
## Important Conventions
@@ -148,7 +147,6 @@ async def delete_data(id: int, current_user = Depends(get_current_admin)):
- On startup, `main.py` automatically loads project from `project_info.name` if set
- Projects are opened via `open_project(name)` from `tjnetwork` service
- Initial project config comes from `configs/project_info.yml`
### Audit Logging

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ build/
.env
*.dump
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
.vscode/

View File

@@ -4,33 +4,38 @@
仅管理员可访问
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from fastapi import APIRouter, Depends, Query, Request
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
from app.domain.schemas.user import UserInDB
from fastapi import APIRouter, Depends, Query
from app.domain.schemas.audit import AuditLogResponse
from app.infra.repositories.audit_repository import AuditRepository
from app.auth.dependencies import get_user_repository, get_db
from app.auth.permissions import get_current_admin
from app.infra.db.postgresql.database import Database
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from app.infra.db.metadata.database import get_metadata_session
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
async def get_audit_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> AuditRepository:
"""获取审计日志仓储"""
return AuditRepository(db)
return AuditRepository(session)
@router.get("/logs", response_model=List[AuditLogResponse])
async def get_audit_logs(
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
username: Optional[str] = Query(None, description="用户名过滤"),
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="项目ID过滤"),
action: Optional[str] = Query(None, description="按操作类型过滤"),
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
start_time: Optional[datetime] = Query(None, description="开始时间"),
end_time: Optional[datetime] = Query(None, description="结束时间"),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询审计日志(仅管理员)
@@ -39,7 +44,7 @@ async def get_audit_logs(
"""
logs = await audit_repo.get_logs(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -51,21 +56,21 @@ async def get_audit_logs(
@router.get("/logs/count")
async def get_audit_logs_count(
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
username: Optional[str] = Query(None, description="用户名过滤"),
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="项目ID过滤"),
action: Optional[str] = Query(None, description="按操作类型过滤"),
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
start_time: Optional[datetime] = Query(None, description="开始时间"),
end_time: Optional[datetime] = Query(None, description="结束时间"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> dict:
"""
获取审计日志总数(仅管理员)
"""
count = await audit_repo.get_log_count(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
end_time: Optional[datetime] = Query(None, description="结束时间"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_user),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询当前用户的审计日志

View File

@@ -0,0 +1,101 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.project_dependencies import (
ProjectContext,
get_project_context,
get_project_pg_session,
get_project_timescale_connection,
get_metadata_repository,
)
from app.auth.metadata_dependencies import get_current_metadata_user
from app.core.config import settings
from app.domain.schemas.metadata import (
GeoServerConfigResponse,
ProjectMetaResponse,
ProjectSummaryResponse,
)
from app.infra.repositories.metadata_repository import MetadataRepository
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/meta/project", response_model=ProjectMetaResponse)
async def get_project_metadata(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
project = await metadata_repo.get_project_by_id(ctx.project_id)
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
geoserver_payload = (
GeoServerConfigResponse(
gs_base_url=geoserver.gs_base_url,
gs_admin_user=geoserver.gs_admin_user,
gs_datastore_name=geoserver.gs_datastore_name,
default_extent=geoserver.default_extent,
srid=geoserver.srid,
)
if geoserver
else None
)
return ProjectMetaResponse(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=ctx.project_role,
geoserver=geoserver_payload,
)
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
async def list_user_projects(
current_user=Depends(get_current_metadata_user),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
try:
projects = await metadata_repo.list_projects_for_user(current_user.id)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while listing projects for user %s",
current_user.id,
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
return [
ProjectSummaryResponse(
project_id=project.project_id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=project.project_role,
)
for project in projects
]
@router.get("/meta/db/health")
async def project_db_health(
pg_session: AsyncSession = Depends(get_project_pg_session),
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
):
await pg_session.execute(text("SELECT 1"))
async with ts_conn.cursor() as cur:
await cur.execute("SELECT 1")
return {"postgres": "ok", "timescale": "ok"}

View File

@@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse
from typing import Any, Dict
import app.services.project_info as project_info
from app.native.api import ChangeSet
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
from app.services.tjnetwork import (
list_project,
have_project,
@@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str):
@router.post("/openproject/")
async def open_project_endpoint(network: str):
open_project(network)
# 尝试连接指定数据库
try:
# 初始化 PostgreSQL 连接池
pg_instance = await get_pg_db(network)
async with pg_instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
# 初始化 TimescaleDB 连接池
ts_instance = await get_ts_db(network)
async with ts_instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
except Exception as e:
# 记录错误但不阻断项目打开,或者根据需求决定是否阻断
# 这里选择打印错误,因为 open_project 原本只负责原生部分
print(f"Failed to connect to databases for {network}: {str(e)}")
# 如果数据库连接是必须的,可以抛出异常:
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
return network
@router.post("/closeproject/")

View File

@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
cache,
user_management, # 新增:用户管理
audit, # 新增:审计日志
meta,
)
from app.api.v1.endpoints.network import (
general,
@@ -46,6 +47,7 @@ api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
api_router.include_router(meta.router, tags=["Metadata"])
api_router.include_router(project.router, tags=["Project"])
# Network Elements (Node/Link Types)

View File

@@ -0,0 +1,63 @@
# import logging
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.core.config import settings
oauth2_optional = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
)
# logger = logging.getLogger(__name__)
async def get_current_keycloak_sub(
token: str | None = Depends(oauth2_optional),
) -> UUID:
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
if settings.KEYCLOAK_PUBLIC_KEY:
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
algorithms = [settings.KEYCLOAK_ALGORITHM]
else:
key = settings.SECRET_KEY
algorithms = [settings.ALGORITHM]
try:
payload = jwt.decode(
token,
key,
algorithms=algorithms,
audience=settings.KEYCLOAK_AUDIENCE or None,
)
except JWTError as exc:
# logger.warning("Keycloak token validation failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
) from exc
sub = payload.get("sub")
if not sub:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing subject claim",
headers={"WWW-Authenticate": "Bearer"},
)
try:
return UUID(sub)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid subject claim",
headers={"WWW-Authenticate": "Bearer"},
) from exc

View File

@@ -0,0 +1,60 @@
from dataclasses import dataclass
from uuid import UUID
import logging
from fastapi import Depends, HTTPException, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
async def get_metadata_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
return MetadataRepository(session)
async def get_current_metadata_user(
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
try:
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while resolving current user",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
async def get_current_metadata_admin(
user=Depends(get_current_metadata_user),
):
if user.is_superuser or user.role == "admin":
return user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
)
@dataclass(frozen=True)
class _AuthBypassUser:
id: UUID = UUID(int=0)
role: str = "admin"
is_superuser: bool = True
is_active: bool = True

View File

@@ -0,0 +1,213 @@
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID
import logging
from fastapi import Depends, Header, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
DB_ROLE_BIZ_DATA = "biz_data"
DB_ROLE_IOT_DATA = "iot_data"
DB_TYPE_POSTGRES = "postgresql"
DB_TYPE_TIMESCALE = "timescaledb"
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ProjectContext:
project_id: UUID
user_id: UUID
project_role: str
async def get_metadata_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
return MetadataRepository(session)
async def get_project_context(
x_project_id: str = Header(..., alias="X-Project-Id"),
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> ProjectContext:
try:
project_uuid = UUID(x_project_id)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
) from exc
try:
project = await metadata_repo.get_project_by_id(project_uuid)
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
if project.status != "active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
)
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
if not user:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
if not membership_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while resolving project context",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
return ProjectContext(
project_id=project.id,
user_id=user.id,
project_role=membership_role,
)
async def get_project_pg_session(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncSession, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_BIZ_DATA
)
except ValueError as exc:
logger.error(
"Invalid project PostgreSQL routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL not configured",
)
if routing.db_type != DB_TYPE_POSTGRES:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
ctx.project_id,
DB_ROLE_BIZ_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with sessionmaker() as session:
yield session
async def get_project_pg_connection(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_BIZ_DATA
)
except ValueError as exc:
logger.error(
"Invalid project PostgreSQL routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL not configured",
)
if routing.db_type != DB_TYPE_POSTGRES:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
pool = await project_connection_manager.get_pg_pool(
ctx.project_id,
DB_ROLE_BIZ_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with pool.connection() as conn:
yield conn
async def get_project_timescale_connection(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_IOT_DATA
)
except ValueError as exc:
logger.error(
"Invalid project TimescaleDB routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project TimescaleDB not configured",
)
if routing.db_type != DB_TYPE_TIMESCALE:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project TimescaleDB type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
pool = await project_connection_manager.get_timescale_pool(
ctx.project_id,
DB_ROLE_IOT_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with pool.connection() as conn:
yield conn

View File

@@ -7,6 +7,7 @@
from typing import Optional
from datetime import datetime
import logging
from uuid import UUID
logger = logging.getLogger(__name__)
@@ -38,18 +39,16 @@ class AuditAction:
async def log_audit_event(
action: str,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None,
db=None, # 新增:可选的数据库实例
session=None,
):
"""
记录审计日志
@@ -57,67 +56,60 @@ async def log_audit_event(
Args:
action: 操作类型
user_id: 用户ID
username: 用户名
project_id: 项目ID
resource_type: 资源类型
resource_id: 资源ID
ip_address: IP地址
user_agent: User-Agent
request_method: 请求方法
request_path: 请求路径
request_data: 请求数据(敏感字段需脱敏)
response_status: 响应状态码
error_message: 错误消息
db: 数据库实例(可选,如果不提供则尝试获取)
session: 元数据库会话(可选)
"""
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.audit_repository import AuditRepository
try:
# 脱敏敏感数据
if request_data:
request_data = sanitize_sensitive_data(request_data)
# 如果没有提供数据库实例,尝试从全局获取
if db is None:
try:
from app.infra.db.postgresql.database import db as default_db
# 仅当连接池已初始化时使用
if default_db.pool:
db = default_db
except ImportError:
pass
# 如果仍然没有数据库实例
if db is None:
# 在某些上下文中可能无法获取,此时静默失败
logger.warning("No database instance provided for audit logging")
return
audit_repo = AuditRepository(db)
if request_data:
request_data = sanitize_sensitive_data(request_data)
if session is None:
async with SessionLocal() as session:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
)
else:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
error_message=error_message,
)
logger.info(
f"Audit log created: action={action}, user={username or user_id}, "
f"resource={resource_type}:{resource_id}"
)
except Exception as e:
# 审计日志失败不应影响业务流程
logger.error(f"Failed to create audit log: {e}", exc_info=True)
logger.info(
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
action,
user_id,
project_id,
resource_type,
resource_id,
)
def sanitize_sensitive_data(data: dict) -> dict:

View File

@@ -1,4 +1,6 @@
from pydantic_settings import BaseSettings
from pathlib import Path
from urllib.parse import quote_plus
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
@@ -15,6 +17,7 @@ class Settings(BaseSettings):
# 数据加密密钥 (使用 Fernet)
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
DATABASE_ENCRYPTION_KEY: str = "" # project_databases.dsn_encrypted 专用
# Database Config (PostgreSQL)
DB_NAME: str = "tjwater"
@@ -35,13 +38,45 @@ class Settings(BaseSettings):
INFLUXDB_ORG: str = "org"
INFLUXDB_BUCKET: str = "bucket"
# Metadata Database Config (PostgreSQL)
METADATA_DB_NAME: str = "system_hub"
METADATA_DB_HOST: str = "localhost"
METADATA_DB_PORT: str = "5432"
METADATA_DB_USER: str = "postgres"
METADATA_DB_PASSWORD: str = "password"
METADATA_DB_POOL_SIZE: int = 5
METADATA_DB_MAX_OVERFLOW: int = 10
PROJECT_PG_CACHE_SIZE: int = 50
PROJECT_TS_CACHE_SIZE: int = 50
PROJECT_PG_POOL_SIZE: int = 5
PROJECT_PG_MAX_OVERFLOW: int = 10
PROJECT_TS_POOL_MIN_SIZE: int = 1
PROJECT_TS_POOL_MAX_SIZE: int = 10
# Keycloak JWT (optional override)
KEYCLOAK_PUBLIC_KEY: str = ""
KEYCLOAK_ALGORITHM: str = "RS256"
KEYCLOAK_AUDIENCE: str = ""
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
db_password = quote_plus(self.DB_PASSWORD)
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
class Config:
env_file = ".env"
extra = "ignore"
@property
def METADATA_DATABASE_URI(self) -> str:
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
return (
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
)
model_config = SettingsConfigDict(
env_file=Path(__file__).resolve().parents[2] / ".env",
extra="ignore",
)
settings = Settings()

View File

@@ -3,75 +3,94 @@ from typing import Optional
import base64
import os
from app.core.config import settings
class Encryptor:
"""
使用 Fernet (对称加密) 实现数据加密/解密
适用于加密敏感配置、用户数据等
"""
def __init__(self, key: Optional[bytes] = None):
"""
初始化加密器
Args:
key: 加密密钥,如果为 None 则从环境变量读取
"""
if key is None:
key_str = os.getenv("ENCRYPTION_KEY")
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
if not key_str:
raise ValueError(
"ENCRYPTION_KEY not found in environment variables. "
"ENCRYPTION_KEY not found in environment variables or .env. "
"Generate one using: Encryptor.generate_key()"
)
key = key_str.encode()
self.fernet = Fernet(key)
def encrypt(self, data: str) -> str:
"""
加密字符串
Args:
data: 待加密的明文字符串
Returns:
Base64 编码的加密字符串
"""
if not data:
return data
encrypted_bytes = self.fernet.encrypt(data.encode())
return encrypted_bytes.decode()
def decrypt(self, data: str) -> str:
"""
解密字符串
Args:
data: Base64 编码的加密字符串
Returns:
解密后的明文字符串
"""
if not data:
return data
decrypted_bytes = self.fernet.decrypt(data.encode())
return decrypted_bytes.decode()
@staticmethod
def generate_key() -> str:
"""
生成新的 Fernet 加密密钥
Returns:
Base64 编码的密钥字符串
"""
key = Fernet.generate_key()
return key.decode()
# 全局加密器实例(懒加载)
_encryptor: Optional[Encryptor] = None
_database_encryptor: Optional[Encryptor] = None
def is_encryption_configured() -> bool:
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
def is_database_encryption_configured() -> bool:
return bool(
os.getenv("DATABASE_ENCRYPTION_KEY")
or settings.DATABASE_ENCRYPTION_KEY
or os.getenv("ENCRYPTION_KEY")
or settings.ENCRYPTION_KEY
)
def get_encryptor() -> Encryptor:
"""获取全局加密器实例"""
@@ -80,6 +99,26 @@ def get_encryptor() -> Encryptor:
_encryptor = Encryptor()
return _encryptor
def get_database_encryptor() -> Encryptor:
"""获取 project DB DSN 专用加密器实例"""
global _database_encryptor
if _database_encryptor is None:
key_str = (
os.getenv("DATABASE_ENCRYPTION_KEY")
or settings.DATABASE_ENCRYPTION_KEY
or os.getenv("ENCRYPTION_KEY")
or settings.ENCRYPTION_KEY
)
if not key_str:
raise ValueError(
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
"Generate one using: Encryptor.generate_key()"
)
_database_encryptor = Encryptor(key=key_str.encode())
return _database_encryptor
# 向后兼容(延迟加载)
def __getattr__(name):
if name == "encryptor":

View File

@@ -1,45 +1,42 @@
from datetime import datetime
from typing import Optional, Any
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class AuditLogCreate(BaseModel):
"""创建审计日志"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: str
resource_type: Optional[str] = None
resource_id: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
request_method: Optional[str] = None
request_path: Optional[str] = None
request_data: Optional[dict] = None
response_status: Optional[int] = None
error_message: Optional[str] = None
class AuditLogResponse(BaseModel):
"""审计日志响应"""
id: int
user_id: Optional[int]
username: Optional[str]
id: UUID
user_id: Optional[UUID]
project_id: Optional[UUID]
action: str
resource_type: Optional[str]
resource_id: Optional[str]
ip_address: Optional[str]
user_agent: Optional[str]
request_method: Optional[str]
request_path: Optional[str]
request_data: Optional[dict]
response_status: Optional[int]
error_message: Optional[str]
timestamp: datetime
model_config = ConfigDict(from_attributes=True)
class AuditLogQuery(BaseModel):
"""审计日志查询参数"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: Optional[str] = None
resource_type: Optional[str] = None
start_time: Optional[datetime] = None

View File

@@ -0,0 +1,33 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class GeoServerConfigResponse(BaseModel):
gs_base_url: Optional[str]
gs_admin_user: Optional[str]
gs_datastore_name: str
default_extent: Optional[dict]
srid: int
class ProjectMetaResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str
geoserver: Optional[GeoServerConfigResponse]
class ProjectSummaryResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str

View File

@@ -6,12 +6,17 @@
import time
import json
from uuid import UUID
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.infra.db.postgresql.database import db as default_db
from app.core.audit import log_audit_event, AuditAction
import logging
from jose import JWTError, jwt
from app.core.config import settings
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
# 4. 提取审计所需信息
user_id = None
username = None
# 尝试从请求状态获取当前用户
if hasattr(request.state, "user"):
user = request.state.user
user_id = getattr(user, "id", None)
username = getattr(user, "username", None)
user_id = await self._resolve_user_id(request)
project_id = self._resolve_project_id(request)
# 获取客户端信息
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
# 确定操作类型
action = self._determine_action(request)
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
await log_audit_event(
action=action,
user_id=user_id,
username=username,
project_id=project_id,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request.method,
request_path=str(request.url.path),
request_data=request_data,
response_status=response.status_code,
error_message=(
None
if response.status_code < 400
else f"HTTP {response.status_code}"
),
db=default_db,
)
except Exception as e:
# 审计失败不应影响响应
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id")
if not project_header:
return None
try:
return UUID(project_header)
except ValueError:
return None
async def _resolve_user_id(self, request: Request) -> UUID | None:
auth_header = request.headers.get("authorization")
if not auth_header or not auth_header.lower().startswith("bearer "):
return None
token = auth_header.split(" ", 1)[1].strip()
if not token:
return None
try:
key = (
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
if settings.KEYCLOAK_PUBLIC_KEY
else settings.SECRET_KEY
)
algorithms = (
[settings.KEYCLOAK_ALGORITHM]
if settings.KEYCLOAK_PUBLIC_KEY
else [settings.ALGORITHM]
)
payload = jwt.decode(token, key, algorithms=algorithms)
sub = payload.get("sub")
if not sub:
return None
keycloak_id = UUID(sub)
except (JWTError, ValueError):
return None
async with SessionLocal() as session:
repo = MetadataRepository(session)
user = await repo.get_user_by_keycloak_id(keycloak_id)
if user and user.is_active:
return user.id
return None
def _determine_action(self, request: Request) -> str:
"""根据请求路径和方法确定操作类型"""
path = request.url.path.lower()

View File

@@ -0,0 +1,211 @@
import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict
from uuid import UUID
from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.core.config import settings
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class PgEngineEntry:
engine: AsyncEngine
sessionmaker: async_sessionmaker[AsyncSession]
@dataclass(frozen=True)
class CacheKey:
project_id: UUID
db_role: str
class ProjectConnectionManager:
def __init__(self) -> None:
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._pg_lock = asyncio.Lock()
self._ts_lock = asyncio.Lock()
self._pg_raw_lock = asyncio.Lock()
def _normalize_pg_url(self, url: str) -> str:
parsed = make_url(url)
if parsed.drivername == "postgresql":
parsed = parsed.set(drivername="postgresql+psycopg")
return str(parsed)
async def get_pg_sessionmaker(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> async_sessionmaker[AsyncSession]:
async with self._pg_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._pg_cache.get(key)
if entry:
self._pg_cache.move_to_end(key)
return entry.sessionmaker
normalized_url = self._normalize_pg_url(connection_url)
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
engine = create_async_engine(
normalized_url,
pool_size=pool_min_size,
max_overflow=max(0, pool_max_size - pool_min_size),
pool_pre_ping=True,
)
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
self._pg_cache[key] = PgEngineEntry(
engine=engine,
sessionmaker=sessionmaker,
)
await self._evict_pg_if_needed()
logger.info(
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
)
return sessionmaker
async def get_timescale_pool(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._ts_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._ts_cache.get(key)
if pool:
self._ts_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
max_size=pool_max_size,
open=False,
kwargs={"row_factory": dict_row},
)
await pool.open()
self._ts_cache[key] = pool
await self._evict_ts_if_needed()
logger.info(
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
)
return pool
async def get_pg_pool(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._pg_raw_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._pg_raw_cache.get(key)
if pool:
self._pg_raw_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
max_size=pool_max_size,
open=False,
kwargs={"row_factory": dict_row},
)
await pool.open()
self._pg_raw_cache[key] = pool
await self._evict_pg_raw_if_needed()
logger.info(
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
)
return pool
async def _evict_pg_if_needed(self) -> None:
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
key, entry = self._pg_cache.popitem(last=False)
await entry.engine.dispose()
logger.info(
"Evicted PostgreSQL engine for project %s (%s)",
key.project_id,
key.db_role,
)
async def _evict_ts_if_needed(self) -> None:
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
key, pool = self._ts_cache.popitem(last=False)
await pool.close()
logger.info(
"Evicted TimescaleDB pool for project %s (%s)",
key.project_id,
key.db_role,
)
async def _evict_pg_raw_if_needed(self) -> None:
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
key, pool = self._pg_raw_cache.popitem(last=False)
await pool.close()
logger.info(
"Evicted PostgreSQL pool for project %s (%s)",
key.project_id,
key.db_role,
)
async def close_all(self) -> None:
async with self._pg_lock:
for key, entry in list(self._pg_cache.items()):
await entry.engine.dispose()
logger.info(
"Closed PostgreSQL engine for project %s (%s)",
key.project_id,
key.db_role,
)
self._pg_cache.clear()
async with self._ts_lock:
for key, pool in list(self._ts_cache.items()):
await pool.close()
logger.info(
"Closed TimescaleDB pool for project %s (%s)",
key.project_id,
key.db_role,
)
self._ts_cache.clear()
async with self._pg_raw_lock:
for key, pool in list(self._pg_raw_cache.items()):
await pool.close()
logger.info(
"Closed PostgreSQL pool for project %s (%s)",
key.project_id,
key.db_role,
)
self._pg_raw_cache.clear()
project_connection_manager = ProjectConnectionManager()

View File

@@ -0,0 +1,3 @@
from .database import get_metadata_session, close_metadata_engine
__all__ = ["get_metadata_session", "close_metadata_engine"]

View File

@@ -0,0 +1,27 @@
import logging
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
logger = logging.getLogger(__name__)
engine = create_async_engine(
settings.METADATA_DATABASE_URI,
pool_size=settings.METADATA_DB_POOL_SIZE,
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
pool_pre_ping=True,
)
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
async with SessionLocal() as session:
yield session
async def close_metadata_engine() -> None:
await engine.dispose()
logger.info("Metadata database engine disposed.")

View File

@@ -0,0 +1,115 @@
from datetime import datetime
from uuid import UUID
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class User(Base):
__tablename__ = "users"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
keycloak_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), unique=True, index=True
)
username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True)
role: Mapped[str] = mapped_column(String(20), default="user")
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
last_login_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class Project(Base):
__tablename__ = "projects"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
name: Mapped[str] = mapped_column(String(100))
code: Mapped[str] = mapped_column(String(50), unique=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
status: Mapped[str] = mapped_column(String(20), default="active")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
class ProjectDatabase(Base):
__tablename__ = "project_databases"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
db_role: Mapped[str] = mapped_column(String(20))
db_type: Mapped[str] = mapped_column(String(20))
dsn_encrypted: Mapped[str] = mapped_column(Text)
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
class ProjectGeoServerConfig(Base):
__tablename__ = "project_geoserver_configs"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
project_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), unique=True, index=True
)
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
Text, nullable=True
)
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
srid: Mapped[int] = mapped_column(Integer, default=4326)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
class UserProjectMembership(Base):
__tablename__ = "user_project_membership"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), nullable=True, index=True
)
project_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), nullable=True, index=True
)
action: Mapped[str] = mapped_column(String(50))
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)

View File

@@ -17,7 +17,14 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = postgresql_info.get_pgconn_string()
target_db_name = db_name or self.db_name
# Get connection string, handling default case where target_db_name might be None
if target_db_name:
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
else:
conn_string = postgresql_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -26,7 +33,7 @@ class Database:
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(f"PostgreSQL connection pool initialized for database: default")
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")
raise

View File

@@ -1,24 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
from app.auth.project_dependencies import get_project_pg_connection
router = APIRouter()
# 创建支持数据库选择的连接依赖函数
# 动态项目 PostgreSQL 连接依赖
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
@router.get("/scada-info")

View File

@@ -17,7 +17,14 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = timescaledb_info.get_pgconn_string()
target_db_name = db_name or self.db_name
# Get connection string, handling default case where target_db_name might be None
if target_db_name:
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
else:
conn_string = timescaledb_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -27,7 +34,7 @@ class Database:
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"TimescaleDB connection pool initialized for database: default"
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
)
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
@@ -46,7 +53,9 @@ class Database:
def get_pgconn_string(self, db_name=None):
"""Get the TimescaleDB connection string."""
target_db_name = db_name or self.db_name
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
if target_db_name:
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
return timescaledb_info.get_pgconn_string()
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:

View File

@@ -1,42 +1,32 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from app.infra.db.postgresql.database import (
get_database_instance as get_postgres_database_instance,
from app.auth.project_dependencies import (
get_project_pg_connection,
get_project_timescale_connection,
)
router = APIRouter()
# 创建支持数据库选择的连接依赖函数
# 动态项目 TimescaleDB 连接依赖
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_timescale_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# PostgreSQL 数据库连接依赖函数
# 动态项目 PostgreSQL 连接依赖
async def get_postgres_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_postgres_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# --- Realtime Endpoints ---

View File

@@ -1,220 +1,112 @@
from typing import Optional, List
from datetime import datetime
import json
from app.infra.db.postgresql.database import Database
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
import logging
from typing import Optional, List
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.schemas.audit import AuditLogResponse
from app.infra.db.metadata import models
logger = logging.getLogger(__name__)
class AuditRepository:
"""审计日志数据访问层"""
def __init__(self, db: Database):
self.db = db
"""审计日志数据访问层system_hub"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_log(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
action: str = "",
action: str,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None
) -> Optional[AuditLogResponse]:
"""
创建审计日志
Args:
参数说明见 AuditLogCreate
Returns:
创建的审计日志对象
"""
query = """
INSERT INTO audit_logs (
user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message
)
VALUES (
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
%(request_data)s, %(response_status)s, %(error_message)s
)
RETURNING id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, {
'user_id': user_id,
'username': username,
'action': action,
'resource_type': resource_type,
'resource_id': resource_id,
'ip_address': ip_address,
'user_agent': user_agent,
'request_method': request_method,
'request_path': request_path,
'request_data': json.dumps(request_data) if request_data else None,
'response_status': response_status,
'error_message': error_message
})
row = await cur.fetchone()
if row:
return AuditLogResponse(**row)
except Exception as e:
logger.error(f"Error creating audit log: {e}")
raise
return None
) -> AuditLogResponse:
log = models.AuditLog(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
timestamp=datetime.utcnow(),
)
self.session.add(log)
await self.session.commit()
await self.session.refresh(log)
return AuditLogResponse.model_validate(log)
async def get_logs(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
skip: int = 0,
limit: int = 100
limit: int = 100,
) -> List[AuditLogResponse]:
"""
查询审计日志
Args:
user_id: 用户ID过滤
username: 用户名过滤
action: 操作类型过滤
resource_type: 资源类型过滤
start_time: 开始时间
end_time: 结束时间
skip: 跳过记录数
limit: 限制记录数
Returns:
审计日志列表
"""
# 构建动态查询
conditions = []
params = {'skip': skip, 'limit': limit}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
FROM audit_logs
{where_clause}
ORDER BY timestamp DESC
LIMIT %(limit)s OFFSET %(skip)s
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
rows = await cur.fetchall()
return [AuditLogResponse(**row) for row in rows]
except Exception as e:
logger.error(f"Error querying audit logs: {e}")
raise
conditions.append(models.AuditLog.timestamp <= end_time)
stmt = (
select(models.AuditLog)
.where(*conditions)
.order_by(models.AuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
)
result = await self.session.execute(stmt)
return [
AuditLogResponse.model_validate(log)
for log in result.scalars().all()
]
async def get_log_count(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
end_time: Optional[datetime] = None,
) -> int:
"""
获取审计日志数量
Args:
参数同 get_logs
Returns:
日志总数
"""
conditions = []
params = {}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT COUNT(*) as count
FROM audit_logs
{where_clause}
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
result = await cur.fetchone()
return result['count'] if result else 0
except Exception as e:
logger.error(f"Error counting audit logs: {e}")
return 0
conditions.append(models.AuditLog.timestamp <= end_time)
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
result = await self.session.execute(stmt)
return int(result.scalar() or 0)

View File

@@ -0,0 +1,197 @@
from dataclasses import dataclass
from typing import Optional, List
from uuid import UUID
from cryptography.fernet import InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.encryption import (
get_database_encryptor,
get_encryptor,
is_database_encryption_configured,
is_encryption_configured,
)
from app.infra.db.metadata import models
def _normalize_postgres_dsn(dsn: str) -> str:
if not dsn or "://" not in dsn:
return dsn
scheme, rest = dsn.split("://", 1)
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
return dsn
if "@" not in rest:
return dsn
userinfo, hostinfo = rest.rsplit("@", 1)
if ":" not in userinfo:
return dsn
username, password = userinfo.split(":", 1)
if "@" not in password:
return dsn
password = password.replace("@", "%40")
return f"{scheme}://{username}:{password}@{hostinfo}"
@dataclass(frozen=True)
class ProjectDbRouting:
project_id: UUID
db_role: str
db_type: str
dsn: str
pool_min_size: int
pool_max_size: int
@dataclass(frozen=True)
class ProjectGeoServerInfo:
project_id: UUID
gs_base_url: Optional[str]
gs_admin_user: Optional[str]
gs_admin_password: Optional[str]
gs_datastore_name: str
default_extent: Optional[dict]
srid: int
@dataclass(frozen=True)
class ProjectSummary:
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str
class MetadataRepository:
"""元数据访问层system_hub"""
def __init__(self, session: AsyncSession):
self.session = session
async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
result = await self.session.execute(
select(models.User).where(models.User.keycloak_id == keycloak_id)
)
return result.scalar_one_or_none()
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
result = await self.session.execute(
select(models.Project).where(models.Project.id == project_id)
)
return result.scalar_one_or_none()
async def get_membership_role(
self, project_id: UUID, user_id: UUID
) -> Optional[str]:
result = await self.session.execute(
select(models.UserProjectMembership.project_role).where(
models.UserProjectMembership.project_id == project_id,
models.UserProjectMembership.user_id == user_id,
)
)
return result.scalar_one_or_none()
async def get_project_db_routing(
self, project_id: UUID, db_role: str
) -> Optional[ProjectDbRouting]:
result = await self.session.execute(
select(models.ProjectDatabase).where(
models.ProjectDatabase.project_id == project_id,
models.ProjectDatabase.db_role == db_role,
)
)
record = result.scalar_one_or_none()
if not record:
return None
if not is_database_encryption_configured():
raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
encryptor = get_database_encryptor()
try:
dsn = encryptor.decrypt(record.dsn_encrypted)
except InvalidToken:
raise ValueError(
"Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
"or invalid dsn_encrypted value"
)
dsn = _normalize_postgres_dsn(dsn)
return ProjectDbRouting(
project_id=record.project_id,
db_role=record.db_role,
db_type=record.db_type,
dsn=dsn,
pool_min_size=record.pool_min_size,
pool_max_size=record.pool_max_size,
)
async def get_geoserver_config(
self, project_id: UUID
) -> Optional[ProjectGeoServerInfo]:
result = await self.session.execute(
select(models.ProjectGeoServerConfig).where(
models.ProjectGeoServerConfig.project_id == project_id
)
)
record = result.scalar_one_or_none()
if not record:
return None
if record.gs_admin_password_encrypted:
if is_encryption_configured():
encryptor = get_encryptor()
password = encryptor.decrypt(record.gs_admin_password_encrypted)
else:
password = record.gs_admin_password_encrypted
else:
password = None
return ProjectGeoServerInfo(
project_id=record.project_id,
gs_base_url=record.gs_base_url,
gs_admin_user=record.gs_admin_user,
gs_admin_password=password,
gs_datastore_name=record.gs_datastore_name,
default_extent=record.default_extent,
srid=record.srid,
)
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
stmt = (
select(models.Project, models.UserProjectMembership.project_role)
.join(
models.UserProjectMembership,
models.UserProjectMembership.project_id == models.Project.id,
)
.where(models.UserProjectMembership.user_id == user_id)
.order_by(models.Project.name)
)
result = await self.session.execute(stmt)
return [
ProjectSummary(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=role,
)
for project, role in result.all()
]
async def list_all_projects(self) -> List[ProjectSummary]:
result = await self.session.execute(
select(models.Project).order_by(models.Project.name)
)
return [
ProjectSummary(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role="owner",
)
for project in result.scalars().all()
]

View File

@@ -9,6 +9,8 @@ import app.services.project_info as project_info
from app.api.v1.router import api_router
from app.infra.db.timescaledb.database import db as tsdb
from app.infra.db.postgresql.database import db as pgdb
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import close_metadata_engine
from app.services.tjnetwork import open_project
from app.core.config import settings
@@ -33,7 +35,7 @@ async def lifespan(app: FastAPI):
await tsdb.open()
await pgdb.open()
# 将数据库实例存储到 app.state供依赖项使用
app.state.db = pgdb
logger.info("Database connection pool initialized and stored in app.state")
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
# 清理资源
await tsdb.close()
await pgdb.close()
await project_connection_manager.close_all()
await close_metadata_engine()
logger.info("Database connections closed")
@@ -58,22 +62,25 @@ app = FastAPI(
redoc_url="/redoc",
)
# 配置 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源
allow_credentials=True, # 允许传递凭证Cookie、HTTP 头等)
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 头
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
# 添加审计中间件(可选,记录关键操作)
# 如果需要启用审计日志,取消下面的注释
app.add_middleware(AuditMiddleware)
# Include Routers
app.include_router(api_router, prefix="/api/v1")
# Legcy Routers without version prefix
app.include_router(api_router)
# 配置中间件
app.add_middleware(GZipMiddleware, minimum_size=1000)
# 添加审计中间件(可选,记录关键操作)
app.add_middleware(AuditMiddleware)
# 配置 CORS 中间件
# 确保这是你最后一个添加的 app.add_middleware
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000", # 必须明确指定
"http://127.0.0.1:3000", # 建议同时加上这个
],
allow_credentials=True, # 既然这里是 True
allow_methods=["*"],
allow_headers=["*"],
)

View File

@@ -1235,7 +1235,7 @@ def run_simulation(
starttime = time.time()
if simulation_type.upper() == "REALTIME":
TimescaleInternalStorage.store_realtime_simulation(
node_result, link_result, modify_pattern_start_time
node_result, link_result, modify_pattern_start_time, db_name=name
)
elif simulation_type.upper() == "EXTENDED":
TimescaleInternalStorage.store_scheme_simulation(
@@ -1245,6 +1245,7 @@ def run_simulation(
link_result,
modify_pattern_start_time,
num_periods_result,
db_name=name,
)
endtime = time.time()
logging.info("store time: %f", endtime - starttime)

33
scripts/encrypt_string.py Normal file
View File

@@ -0,0 +1,33 @@
import os
import sys
# 将项目根目录添加到 python 路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from app.core.encryption import get_database_encryptor
def main() -> int:
plaintext = None
if not sys.stdin.isatty():
stdin_text = sys.stdin.read()
if stdin_text != "":
plaintext = stdin_text.rstrip("\r\n")
if plaintext is None and len(sys.argv) >= 2:
plaintext = sys.argv[1]
if plaintext is None:
try:
plaintext = input("请输入要加密的文本: ")
except EOFError:
plaintext = ""
if not plaintext.strip():
print("Error: plaintext string cannot be empty.", file=sys.stderr)
return 1
token = get_database_encryptor().encrypt(plaintext)
print(token)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -16,6 +16,6 @@ if __name__ == "__main__":
"app.main:app",
host="0.0.0.0",
port=8000,
workers=2, # 这里可以设置多进程
workers=4, # 这里可以设置多进程
loop="asyncio",
)

View File

@@ -0,0 +1,119 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from cryptography.fernet import InvalidToken
from app.infra.repositories.metadata_repository import MetadataRepository
class _DummyResult:
def __init__(self, record):
self._record = record
def scalar_one_or_none(self):
return self._record
class _DummyEncryptor:
def __init__(self, decrypted=None, raise_invalid_token=False):
self._decrypted = decrypted
self._raise_invalid_token = raise_invalid_token
self.encrypted_values = []
def decrypt(self, _value):
if self._raise_invalid_token:
raise InvalidToken()
return self._decrypted
def _build_record(dsn_encrypted: str):
return SimpleNamespace(
project_id=uuid4(),
db_role="biz_data",
db_type="postgresql",
dsn_encrypted=dsn_encrypted,
pool_min_size=1,
pool_max_size=5,
)
def test_invalid_token_with_plaintext_dsn_value_raises_clear_error(monkeypatch):
record = _build_record("postgresql://user:p@ss@localhost:5432/db")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
encryptor = _DummyEncryptor(raise_invalid_token=True)
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: encryptor,
)
with pytest.raises(
ValueError,
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
):
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
session.commit.assert_not_awaited()
def test_invalid_token_with_non_dsn_value_raises_clear_error(monkeypatch):
record = _build_record("gAAAAABinvalidciphertext")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: _DummyEncryptor(raise_invalid_token=True),
)
with pytest.raises(
ValueError,
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
):
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
session.commit.assert_not_awaited()
def test_encrypted_dsn_decrypts_without_migration(monkeypatch):
record = _build_record("encrypted-value")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: _DummyEncryptor(decrypted="postgresql://u:p@ss@host/db"),
)
routing = asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
assert routing.dsn == "postgresql://u:p%40ss@host/db"
session.commit.assert_not_awaited()