Compare commits
15 Commits
a472639b8a
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
| 80b6970970 | |||
| 364a8c8ec2 | |||
| 52ccb8abf1 | |||
| 0bc4058f23 | |||
| 0d3e6ca4fa | |||
| 6fc3aa5209 | |||
| 1b1b0a3697 | |||
| 2826999ddc | |||
| efc05f7278 | |||
| 29209f5c63 | |||
| 020432ad0e | |||
| 780a48d927 | |||
| ff2011ae24 | |||
| f5069a5606 | |||
| eb45e4aaa5 |
26
.env.example
26
.env.example
@@ -12,6 +12,7 @@ SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
|
||||
# 数据加密密钥 - 用于敏感数据加密
|
||||
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
ENCRYPTION_KEY=
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (PostgreSQL)
|
||||
@@ -31,6 +32,25 @@ TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
# ============================================
|
||||
# 元数据数据库配置 (Metadata DB)
|
||||
# ============================================
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="localhost"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 项目连接缓存与连接池配置
|
||||
# ============================================
|
||||
PROJECT_PG_CACHE_SIZE=50
|
||||
PROJECT_TS_CACHE_SIZE=50
|
||||
PROJECT_PG_POOL_SIZE=5
|
||||
PROJECT_PG_MAX_OVERFLOW=10
|
||||
PROJECT_TS_POOL_MIN_SIZE=1
|
||||
PROJECT_TS_POOL_MAX_SIZE=10
|
||||
|
||||
# ============================================
|
||||
# InfluxDB 配置 (时序数据)
|
||||
# ============================================
|
||||
@@ -46,6 +66,12 @@ TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
# ALGORITHM=HS256
|
||||
|
||||
# ============================================
|
||||
# Keycloak JWT (可选)
|
||||
# ============================================
|
||||
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
# KEYCLOAK_ALGORITHM=RS256
|
||||
|
||||
# ============================================
|
||||
# 其他配置
|
||||
# ============================================
|
||||
|
||||
18
.env.local
18
.env.local
@@ -1,13 +1,23 @@
|
||||
NETWORK_NAME="szh"
|
||||
NETWORK_NAME="tjwater"
|
||||
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
|
||||
KEYCLOAK_ALGORITHM="RS256"
|
||||
KEYCLOAK_AUDIENCE="account"
|
||||
|
||||
DB_NAME="szh"
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="192.168.1.114"
|
||||
DB_PORT="5432"
|
||||
DB_USER="tjwater"
|
||||
DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
TIMESCALEDB_DB_NAME="szh"
|
||||
TIMESCALEDB_DB_NAME="tjwater"
|
||||
TIMESCALEDB_DB_HOST="192.168.1.114"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="192.168.1.114"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="Tjwater@123456"
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
|
||||
2
.github/copilot-instructions.md
vendored
2
.github/copilot-instructions.md
vendored
@@ -87,7 +87,6 @@ Default admin accounts:
|
||||
- `app/core/config.py`: Settings management using `pydantic-settings`
|
||||
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
||||
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
||||
- `configs/project_info.yml`: Default project configuration (auto-loaded on startup)
|
||||
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
||||
|
||||
## Important Conventions
|
||||
@@ -148,7 +147,6 @@ async def delete_data(id: int, current_user = Depends(get_current_admin)):
|
||||
|
||||
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
||||
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
||||
- Initial project config comes from `configs/project_info.yml`
|
||||
|
||||
### Audit Logging
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ build/
|
||||
.env
|
||||
*.dump
|
||||
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
||||
.vscode/
|
||||
|
||||
@@ -4,33 +4,38 @@
|
||||
仅管理员可访问
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
|
||||
from app.domain.schemas.user import UserInDB
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
from app.auth.dependencies import get_user_repository, get_db
|
||||
from app.auth.permissions import get_current_admin
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
|
||||
async def get_audit_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> AuditRepository:
|
||||
"""获取审计日志仓储"""
|
||||
return AuditRepository(db)
|
||||
return AuditRepository(session)
|
||||
|
||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志(仅管理员)
|
||||
@@ -39,7 +44,7 @@ async def get_audit_logs(
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
@@ -51,21 +56,21 @@ async def get_audit_logs(
|
||||
|
||||
@router.get("/logs/count")
|
||||
async def get_audit_logs_count(
|
||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> dict:
|
||||
"""
|
||||
获取审计日志总数(仅管理员)
|
||||
"""
|
||||
count = await audit_repo.get_log_count(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询当前用户的审计日志
|
||||
|
||||
101
app/api/v1/endpoints/meta.py
Normal file
101
app/api/v1/endpoints/meta.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.project_dependencies import (
|
||||
ProjectContext,
|
||||
get_project_context,
|
||||
get_project_pg_session,
|
||||
get_project_timescale_connection,
|
||||
get_metadata_repository,
|
||||
)
|
||||
from app.auth.metadata_dependencies import get_current_metadata_user
|
||||
from app.core.config import settings
|
||||
from app.domain.schemas.metadata import (
|
||||
GeoServerConfigResponse,
|
||||
ProjectMetaResponse,
|
||||
ProjectSummaryResponse,
|
||||
)
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||
async def get_project_metadata(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
|
||||
geoserver_payload = (
|
||||
GeoServerConfigResponse(
|
||||
gs_base_url=geoserver.gs_base_url,
|
||||
gs_admin_user=geoserver.gs_admin_user,
|
||||
gs_datastore_name=geoserver.gs_datastore_name,
|
||||
default_extent=geoserver.default_extent,
|
||||
srid=geoserver.srid,
|
||||
)
|
||||
if geoserver
|
||||
else None
|
||||
)
|
||||
return ProjectMetaResponse(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=ctx.project_role,
|
||||
geoserver=geoserver_payload,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||
async def list_user_projects(
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while listing projects for user %s",
|
||||
current_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
return [
|
||||
ProjectSummaryResponse(
|
||||
project_id=project.project_id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=project.project_role,
|
||||
)
|
||||
for project in projects
|
||||
]
|
||||
|
||||
|
||||
@router.get("/meta/db/health")
|
||||
async def project_db_health(
|
||||
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
await pg_session.execute(text("SELECT 1"))
|
||||
async with ts_conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
return {"postgres": "ok", "timescale": "ok"}
|
||||
@@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse
|
||||
from typing import Any, Dict
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api import ChangeSet
|
||||
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
|
||||
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
|
||||
from app.services.tjnetwork import (
|
||||
list_project,
|
||||
have_project,
|
||||
@@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str):
|
||||
@router.post("/openproject/")
|
||||
async def open_project_endpoint(network: str):
|
||||
open_project(network)
|
||||
|
||||
# 尝试连接指定数据库
|
||||
try:
|
||||
# 初始化 PostgreSQL 连接池
|
||||
pg_instance = await get_pg_db(network)
|
||||
async with pg_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
# 初始化 TimescaleDB 连接池
|
||||
ts_instance = await get_ts_db(network)
|
||||
async with ts_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误但不阻断项目打开,或者根据需求决定是否阻断
|
||||
# 这里选择打印错误,因为 open_project 原本只负责原生部分
|
||||
print(f"Failed to connect to databases for {network}: {str(e)}")
|
||||
# 如果数据库连接是必须的,可以抛出异常:
|
||||
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
|
||||
|
||||
return network
|
||||
|
||||
@router.post("/closeproject/")
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
|
||||
cache,
|
||||
user_management, # 新增:用户管理
|
||||
audit, # 新增:审计日志
|
||||
meta,
|
||||
)
|
||||
from app.api.v1.endpoints.network import (
|
||||
general,
|
||||
@@ -46,6 +47,7 @@ api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||
api_router.include_router(meta.router, tags=["Metadata"])
|
||||
api_router.include_router(project.router, tags=["Project"])
|
||||
|
||||
# Network Elements (Node/Link Types)
|
||||
|
||||
63
app/auth/keycloak_dependencies.py
Normal file
63
app/auth/keycloak_dependencies.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
oauth2_optional = OAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
|
||||
)
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_keycloak_sub(
|
||||
token: str | None = Depends(oauth2_optional),
|
||||
) -> UUID:
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||
else:
|
||||
key = settings.SECRET_KEY
|
||||
algorithms = [settings.ALGORITHM]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=algorithms,
|
||||
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||
)
|
||||
except JWTError as exc:
|
||||
# logger.warning("Keycloak token validation failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
return UUID(sub)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
60
app/auth/metadata_dependencies.py
Normal file
60
app/auth/metadata_dependencies.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_current_metadata_user(
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving current user",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_metadata_admin(
|
||||
user=Depends(get_current_metadata_user),
|
||||
):
|
||||
if user.is_superuser or user.role == "admin":
|
||||
return user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _AuthBypassUser:
|
||||
id: UUID = UUID(int=0)
|
||||
role: str = "admin"
|
||||
is_superuser: bool = True
|
||||
is_active: bool = True
|
||||
213
app/auth/project_dependencies.py
Normal file
213
app/auth/project_dependencies.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
DB_ROLE_BIZ_DATA = "biz_data"
|
||||
DB_ROLE_IOT_DATA = "iot_data"
|
||||
DB_TYPE_POSTGRES = "postgresql"
|
||||
DB_TYPE_TIMESCALE = "timescaledb"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectContext:
|
||||
project_id: UUID
|
||||
user_id: UUID
|
||||
project_role: str
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_project_context(
|
||||
x_project_id: str = Header(..., alias="X-Project-Id"),
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> ProjectContext:
|
||||
try:
|
||||
project_uuid = UUID(x_project_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
project = await metadata_repo.get_project_by_id(project_uuid)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
if project.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
|
||||
)
|
||||
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
|
||||
if not membership_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
|
||||
)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving project context",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
|
||||
return ProjectContext(
|
||||
project_id=project.id,
|
||||
user_id=user.id,
|
||||
project_role=membership_role,
|
||||
)
|
||||
|
||||
|
||||
async def get_project_pg_session(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with sessionmaker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_project_pg_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool = await project_connection_manager.get_pg_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_project_timescale_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_IOT_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project TimescaleDB routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_TIMESCALE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
|
||||
pool = await project_connection_manager.get_timescale_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_IOT_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
@@ -7,6 +7,7 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,18 +39,16 @@ class AuditAction:
|
||||
|
||||
async def log_audit_event(
|
||||
action: str,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None,
|
||||
db=None, # 新增:可选的数据库实例
|
||||
session=None,
|
||||
):
|
||||
"""
|
||||
记录审计日志
|
||||
@@ -57,67 +56,60 @@ async def log_audit_event(
|
||||
Args:
|
||||
action: 操作类型
|
||||
user_id: 用户ID
|
||||
username: 用户名
|
||||
project_id: 项目ID
|
||||
resource_type: 资源类型
|
||||
resource_id: 资源ID
|
||||
ip_address: IP地址
|
||||
user_agent: User-Agent
|
||||
request_method: 请求方法
|
||||
request_path: 请求路径
|
||||
request_data: 请求数据(敏感字段需脱敏)
|
||||
response_status: 响应状态码
|
||||
error_message: 错误消息
|
||||
db: 数据库实例(可选,如果不提供则尝试获取)
|
||||
session: 元数据库会话(可选)
|
||||
"""
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
|
||||
try:
|
||||
# 脱敏敏感数据
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
# 如果没有提供数据库实例,尝试从全局获取
|
||||
if db is None:
|
||||
try:
|
||||
from app.infra.db.postgresql.database import db as default_db
|
||||
|
||||
# 仅当连接池已初始化时使用
|
||||
if default_db.pool:
|
||||
db = default_db
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 如果仍然没有数据库实例
|
||||
if db is None:
|
||||
# 在某些上下文中可能无法获取,此时静默失败
|
||||
logger.warning("No database instance provided for audit logging")
|
||||
return
|
||||
|
||||
audit_repo = AuditRepository(db)
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
if session is None:
|
||||
async with SessionLocal() as session:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
)
|
||||
else:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Audit log created: action={action}, user={username or user_id}, "
|
||||
f"resource={resource_type}:{resource_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 审计日志失败不应影响业务流程
|
||||
logger.error(f"Failed to create audit log: {e}", exc_info=True)
|
||||
logger.info(
|
||||
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
|
||||
action,
|
||||
user_id,
|
||||
project_id,
|
||||
resource_type,
|
||||
resource_id,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_sensitive_data(data: dict) -> dict:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -15,6 +17,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 数据加密密钥 (使用 Fernet)
|
||||
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
|
||||
DATABASE_ENCRYPTION_KEY: str = "" # project_databases.dsn_encrypted 专用
|
||||
|
||||
# Database Config (PostgreSQL)
|
||||
DB_NAME: str = "tjwater"
|
||||
@@ -35,13 +38,45 @@ class Settings(BaseSettings):
|
||||
INFLUXDB_ORG: str = "org"
|
||||
INFLUXDB_BUCKET: str = "bucket"
|
||||
|
||||
# Metadata Database Config (PostgreSQL)
|
||||
METADATA_DB_NAME: str = "system_hub"
|
||||
METADATA_DB_HOST: str = "localhost"
|
||||
METADATA_DB_PORT: str = "5432"
|
||||
METADATA_DB_USER: str = "postgres"
|
||||
METADATA_DB_PASSWORD: str = "password"
|
||||
|
||||
METADATA_DB_POOL_SIZE: int = 5
|
||||
METADATA_DB_MAX_OVERFLOW: int = 10
|
||||
|
||||
PROJECT_PG_CACHE_SIZE: int = 50
|
||||
PROJECT_TS_CACHE_SIZE: int = 50
|
||||
PROJECT_PG_POOL_SIZE: int = 5
|
||||
PROJECT_PG_MAX_OVERFLOW: int = 10
|
||||
PROJECT_TS_POOL_MIN_SIZE: int = 1
|
||||
PROJECT_TS_POOL_MAX_SIZE: int = 10
|
||||
|
||||
# Keycloak JWT (optional override)
|
||||
KEYCLOAK_PUBLIC_KEY: str = ""
|
||||
KEYCLOAK_ALGORITHM: str = "RS256"
|
||||
KEYCLOAK_AUDIENCE: str = ""
|
||||
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
db_password = quote_plus(self.DB_PASSWORD)
|
||||
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
@property
|
||||
def METADATA_DATABASE_URI(self) -> str:
|
||||
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
|
||||
return (
|
||||
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
|
||||
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=Path(__file__).resolve().parents[2] / ".env",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -3,75 +3,94 @@ from typing import Optional
|
||||
import base64
|
||||
import os
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Encryptor:
|
||||
"""
|
||||
使用 Fernet (对称加密) 实现数据加密/解密
|
||||
适用于加密敏感配置、用户数据等
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, key: Optional[bytes] = None):
|
||||
"""
|
||||
初始化加密器
|
||||
|
||||
|
||||
Args:
|
||||
key: 加密密钥,如果为 None 则从环境变量读取
|
||||
"""
|
||||
if key is None:
|
||||
key_str = os.getenv("ENCRYPTION_KEY")
|
||||
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"ENCRYPTION_KEY not found in environment variables. "
|
||||
"ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
key = key_str.encode()
|
||||
|
||||
|
||||
self.fernet = Fernet(key)
|
||||
|
||||
|
||||
def encrypt(self, data: str) -> str:
|
||||
"""
|
||||
加密字符串
|
||||
|
||||
|
||||
Args:
|
||||
data: 待加密的明文字符串
|
||||
|
||||
|
||||
Returns:
|
||||
Base64 编码的加密字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
|
||||
encrypted_bytes = self.fernet.encrypt(data.encode())
|
||||
return encrypted_bytes.decode()
|
||||
|
||||
|
||||
def decrypt(self, data: str) -> str:
|
||||
"""
|
||||
解密字符串
|
||||
|
||||
|
||||
Args:
|
||||
data: Base64 编码的加密字符串
|
||||
|
||||
|
||||
Returns:
|
||||
解密后的明文字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
|
||||
decrypted_bytes = self.fernet.decrypt(data.encode())
|
||||
return decrypted_bytes.decode()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_key() -> str:
|
||||
"""
|
||||
生成新的 Fernet 加密密钥
|
||||
|
||||
|
||||
Returns:
|
||||
Base64 编码的密钥字符串
|
||||
"""
|
||||
key = Fernet.generate_key()
|
||||
return key.decode()
|
||||
|
||||
|
||||
# 全局加密器实例(懒加载)
|
||||
_encryptor: Optional[Encryptor] = None
|
||||
_database_encryptor: Optional[Encryptor] = None
|
||||
|
||||
|
||||
def is_encryption_configured() -> bool:
|
||||
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
|
||||
|
||||
|
||||
def is_database_encryption_configured() -> bool:
|
||||
return bool(
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
|
||||
|
||||
def get_encryptor() -> Encryptor:
|
||||
"""获取全局加密器实例"""
|
||||
@@ -80,6 +99,26 @@ def get_encryptor() -> Encryptor:
|
||||
_encryptor = Encryptor()
|
||||
return _encryptor
|
||||
|
||||
|
||||
def get_database_encryptor() -> Encryptor:
|
||||
"""获取 project DB DSN 专用加密器实例"""
|
||||
global _database_encryptor
|
||||
if _database_encryptor is None:
|
||||
key_str = (
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
_database_encryptor = Encryptor(key=key_str.encode())
|
||||
return _database_encryptor
|
||||
|
||||
|
||||
# 向后兼容(延迟加载)
|
||||
def __getattr__(name):
|
||||
if name == "encryptor":
|
||||
|
||||
@@ -1,45 +1,42 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class AuditLogCreate(BaseModel):
|
||||
"""创建审计日志"""
|
||||
user_id: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: str
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
request_method: Optional[str] = None
|
||||
request_path: Optional[str] = None
|
||||
request_data: Optional[dict] = None
|
||||
response_status: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class AuditLogResponse(BaseModel):
|
||||
"""审计日志响应"""
|
||||
id: int
|
||||
user_id: Optional[int]
|
||||
username: Optional[str]
|
||||
id: UUID
|
||||
user_id: Optional[UUID]
|
||||
project_id: Optional[UUID]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
request_method: Optional[str]
|
||||
request_path: Optional[str]
|
||||
request_data: Optional[dict]
|
||||
response_status: Optional[int]
|
||||
error_message: Optional[str]
|
||||
timestamp: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class AuditLogQuery(BaseModel):
|
||||
"""审计日志查询参数"""
|
||||
user_id: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
start_time: Optional[datetime] = None
|
||||
|
||||
33
app/domain/schemas/metadata.py
Normal file
33
app/domain/schemas/metadata.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeoServerConfigResponse(BaseModel):
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
class ProjectMetaResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
geoserver: Optional[GeoServerConfigResponse]
|
||||
|
||||
|
||||
class ProjectSummaryResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
@@ -6,12 +6,17 @@
|
||||
|
||||
import time
|
||||
import json
|
||||
from uuid import UUID
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.infra.db.postgresql.database import db as default_db
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = None
|
||||
username = None
|
||||
|
||||
# 尝试从请求状态获取当前用户
|
||||
if hasattr(request.state, "user"):
|
||||
user = request.state.user
|
||||
user_id = getattr(user, "id", None)
|
||||
username = getattr(user, "username", None)
|
||||
user_id = await self._resolve_user_id(request)
|
||||
project_id = self._resolve_project_id(request)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
error_message=(
|
||||
None
|
||||
if response.status_code < 400
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
db=default_db,
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return response
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
return None
|
||||
try:
|
||||
return UUID(project_header)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
key = (
|
||||
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else settings.SECRET_KEY
|
||||
)
|
||||
algorithms = (
|
||||
[settings.KEYCLOAK_ALGORITHM]
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else [settings.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
return None
|
||||
keycloak_id = UUID(sub)
|
||||
except (JWTError, ValueError):
|
||||
return None
|
||||
|
||||
async with SessionLocal() as session:
|
||||
repo = MetadataRepository(session)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
if user and user.is_active:
|
||||
return user.id
|
||||
return None
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
|
||||
211
app/infra/db/dynamic_manager.py
Normal file
211
app/infra/db/dynamic_manager.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PgEngineEntry:
|
||||
engine: AsyncEngine
|
||||
sessionmaker: async_sessionmaker[AsyncSession]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheKey:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
|
||||
|
||||
class ProjectConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
|
||||
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_lock = asyncio.Lock()
|
||||
self._ts_lock = asyncio.Lock()
|
||||
self._pg_raw_lock = asyncio.Lock()
|
||||
|
||||
def _normalize_pg_url(self, url: str) -> str:
|
||||
parsed = make_url(url)
|
||||
if parsed.drivername == "postgresql":
|
||||
parsed = parsed.set(drivername="postgresql+psycopg")
|
||||
return str(parsed)
|
||||
|
||||
async def get_pg_sessionmaker(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> async_sessionmaker[AsyncSession]:
|
||||
async with self._pg_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_cache.get(key)
|
||||
if entry:
|
||||
self._pg_cache.move_to_end(key)
|
||||
return entry.sessionmaker
|
||||
|
||||
normalized_url = self._normalize_pg_url(connection_url)
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
engine = create_async_engine(
|
||||
normalized_url,
|
||||
pool_size=pool_min_size,
|
||||
max_overflow=max(0, pool_max_size - pool_min_size),
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
self._pg_cache[key] = PgEngineEntry(
|
||||
engine=engine,
|
||||
sessionmaker=sessionmaker,
|
||||
)
|
||||
await self._evict_pg_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return sessionmaker
|
||||
|
||||
async def get_timescale_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._ts_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._ts_cache.get(key)
|
||||
if pool:
|
||||
self._ts_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._ts_cache[key] = pool
|
||||
await self._evict_ts_if_needed()
|
||||
logger.info(
|
||||
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def get_pg_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._pg_raw_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._pg_raw_cache.get(key)
|
||||
if pool:
|
||||
self._pg_raw_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._pg_raw_cache[key] = pool
|
||||
await self._evict_pg_raw_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def _evict_pg_if_needed(self) -> None:
|
||||
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, entry = self._pg_cache.popitem(last=False)
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_ts_if_needed(self) -> None:
|
||||
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
|
||||
key, pool = self._ts_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_pg_raw_if_needed(self) -> None:
|
||||
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, pool = self._pg_raw_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
async with self._pg_lock:
|
||||
for key, entry in list(self._pg_cache.items()):
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Closed PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_cache.clear()
|
||||
|
||||
async with self._ts_lock:
|
||||
for key, pool in list(self._ts_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._ts_cache.clear()
|
||||
|
||||
async with self._pg_raw_lock:
|
||||
for key, pool in list(self._pg_raw_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_raw_cache.clear()
|
||||
|
||||
|
||||
project_connection_manager = ProjectConnectionManager()
|
||||
3
app/infra/db/metadata/__init__.py
Normal file
3
app/infra/db/metadata/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .database import get_metadata_session, close_metadata_engine
|
||||
|
||||
__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||
27
app/infra/db/metadata/database.py
Normal file
27
app/infra/db/metadata/database.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.METADATA_DATABASE_URI,
|
||||
pool_size=settings.METADATA_DB_POOL_SIZE,
|
||||
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_metadata_engine() -> None:
|
||||
await engine.dispose()
|
||||
logger.info("Metadata database engine disposed.")
|
||||
115
app/infra/db/metadata/models.py
Normal file
115
app/infra/db/metadata/models.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
keycloak_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class ProjectDatabase(Base):
|
||||
__tablename__ = "project_databases"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
db_role: Mapped[str] = mapped_column(String(20))
|
||||
db_type: Mapped[str] = mapped_column(String(20))
|
||||
dsn_encrypted: Mapped[str] = mapped_column(Text)
|
||||
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
|
||||
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
|
||||
|
||||
|
||||
class ProjectGeoServerConfig(Base):
|
||||
__tablename__ = "project_geoserver_configs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
|
||||
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
srid: Mapped[int] = mapped_column(Integer, default=4326)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class UserProjectMembership(Base):
|
||||
__tablename__ = "user_project_membership"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
project_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50))
|
||||
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -26,7 +33,7 @@ class Database:
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: default")
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
from app.auth.project_dependencies import get_project_pg_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -27,7 +34,7 @@ class Database:
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(
|
||||
f"TimescaleDB connection pool initialized for database: default"
|
||||
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
||||
@@ -46,7 +53,9 @@ class Database:
|
||||
def get_pgconn_string(self, db_name=None):
|
||||
"""Get the TimescaleDB connection string."""
|
||||
target_db_name = db_name or self.db_name
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
if target_db_name:
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
return timescaledb_info.get_pgconn_string()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
|
||||
@@ -1,42 +1,32 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .schemas.realtime import RealtimeRepository
|
||||
from .schemas.scheme import SchemeRepository
|
||||
from .schemas.scada import ScadaRepository
|
||||
from .composite_queries import CompositeQueries
|
||||
from app.infra.db.postgresql.database import (
|
||||
get_database_instance as get_postgres_database_instance,
|
||||
from app.auth.project_dependencies import (
|
||||
get_project_pg_connection,
|
||||
get_project_timescale_connection,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 TimescaleDB 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# PostgreSQL 数据库连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_postgres_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_postgres_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# --- Realtime Endpoints ---
|
||||
|
||||
@@ -1,220 +1,112 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
"""审计日志数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: str = "",
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None
|
||||
) -> Optional[AuditLogResponse]:
|
||||
"""
|
||||
创建审计日志
|
||||
|
||||
Args:
|
||||
参数说明见 AuditLogCreate
|
||||
|
||||
Returns:
|
||||
创建的审计日志对象
|
||||
"""
|
||||
query = """
|
||||
INSERT INTO audit_logs (
|
||||
user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message
|
||||
)
|
||||
VALUES (
|
||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
||||
%(request_data)s, %(response_status)s, %(error_message)s
|
||||
)
|
||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'action': action,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'ip_address': ip_address,
|
||||
'user_agent': user_agent,
|
||||
'request_method': request_method,
|
||||
'request_path': request_path,
|
||||
'request_data': json.dumps(request_data) if request_data else None,
|
||||
'response_status': response_status,
|
||||
'error_message': error_message
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return AuditLogResponse(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating audit log: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
) -> AuditLogResponse:
|
||||
log = models.AuditLog(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return AuditLogResponse.model_validate(log)
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
limit: int = 100,
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志
|
||||
|
||||
Args:
|
||||
user_id: 用户ID过滤
|
||||
username: 用户名过滤
|
||||
action: 操作类型过滤
|
||||
resource_type: 资源类型过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
skip: 跳过记录数
|
||||
limit: 限制记录数
|
||||
|
||||
Returns:
|
||||
审计日志列表
|
||||
"""
|
||||
# 构建动态查询
|
||||
conditions = []
|
||||
params = {'skip': skip, 'limit': limit}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
rows = await cur.fetchall()
|
||||
return [AuditLogResponse(**row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying audit logs: {e}")
|
||||
raise
|
||||
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = (
|
||||
select(models.AuditLog)
|
||||
.where(*conditions)
|
||||
.order_by(models.AuditLog.timestamp.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
AuditLogResponse.model_validate(log)
|
||||
for log in result.scalars().all()
|
||||
]
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
获取审计日志数量
|
||||
|
||||
Args:
|
||||
参数同 get_logs
|
||||
|
||||
Returns:
|
||||
日志总数
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['count'] if result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting audit logs: {e}")
|
||||
return 0
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
197
app/infra/repositories/metadata_repository.py
Normal file
197
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.encryption import (
|
||||
get_database_encryptor,
|
||||
get_encryptor,
|
||||
is_database_encryption_configured,
|
||||
is_encryption_configured,
|
||||
)
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
|
||||
def _normalize_postgres_dsn(dsn: str) -> str:
|
||||
if not dsn or "://" not in dsn:
|
||||
return dsn
|
||||
scheme, rest = dsn.split("://", 1)
|
||||
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
|
||||
return dsn
|
||||
if "@" not in rest:
|
||||
return dsn
|
||||
userinfo, hostinfo = rest.rsplit("@", 1)
|
||||
if ":" not in userinfo:
|
||||
return dsn
|
||||
username, password = userinfo.split(":", 1)
|
||||
if "@" not in password:
|
||||
return dsn
|
||||
password = password.replace("@", "%40")
|
||||
return f"{scheme}://{username}:{password}@{hostinfo}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProjectDbRouting:
    """Resolved connection routing for one project database: the decrypted
    DSN plus pool sizing, read from the system_hub metadata."""

    project_id: UUID
    # Logical role of the database within the project (e.g. "biz_data").
    db_role: str
    # Engine kind as stored in metadata (e.g. "postgresql") — other values
    # are possible; confirm against the project_databases table.
    db_type: str
    # Plaintext DSN, already decrypted and '@'-normalized by the repository.
    dsn: str
    pool_min_size: int
    pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProjectGeoServerInfo:
    """Per-project GeoServer access details as resolved from metadata."""

    project_id: UUID
    gs_base_url: Optional[str]
    gs_admin_user: Optional[str]
    # Admin password; decrypted when an encryption key is configured,
    # otherwise the raw stored value is passed through. None when unset.
    gs_admin_password: Optional[str]
    gs_datastore_name: str
    # Default map extent stored as a dict — exact schema not defined here;
    # confirm against the writer of project_geoserver_configs.
    default_extent: Optional[dict]
    # Spatial reference id — presumably an EPSG code; verify with callers.
    srid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProjectSummary:
    """Lightweight project listing entry, paired with the role of the user
    the listing was built for."""

    project_id: UUID
    name: str
    code: str
    description: Optional[str]
    # GeoServer workspace name associated with the project.
    gs_workspace: str
    status: str
    # Requesting user's role in the project; admin listings report "owner".
    project_role: str
|
||||
|
||||
|
||||
class MetadataRepository:
    """Metadata access layer for the system_hub database.

    Provides read-only lookups of users, projects, memberships,
    per-project database routing (with DSN decryption) and GeoServer
    configuration, via an injected async SQLAlchemy session.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
        """Return the user bound to the given Keycloak id, or None."""
        result = await self.session.execute(
            select(models.User).where(models.User.keycloak_id == keycloak_id)
        )
        return result.scalar_one_or_none()

    async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
        """Return the project with the given id, or None if absent."""
        result = await self.session.execute(
            select(models.Project).where(models.Project.id == project_id)
        )
        return result.scalar_one_or_none()

    async def get_membership_role(
        self, project_id: UUID, user_id: UUID
    ) -> Optional[str]:
        """Return the user's role within the project, or None when the
        user is not a member."""
        result = await self.session.execute(
            select(models.UserProjectMembership.project_role).where(
                models.UserProjectMembership.project_id == project_id,
                models.UserProjectMembership.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def get_project_db_routing(
        self, project_id: UUID, db_role: str
    ) -> Optional[ProjectDbRouting]:
        """Resolve decrypted connection routing for one project database.

        Args:
            project_id: Project that owns the database.
            db_role: Logical role of the database (e.g. "biz_data").

        Returns:
            ProjectDbRouting carrying the plaintext, '@'-normalized DSN,
            or None when no matching record exists.

        Raises:
            ValueError: If DATABASE_ENCRYPTION_KEY is not configured, or
                if the stored dsn_encrypted value cannot be decrypted
                with it.
        """
        result = await self.session.execute(
            select(models.ProjectDatabase).where(
                models.ProjectDatabase.project_id == project_id,
                models.ProjectDatabase.db_role == db_role,
            )
        )
        record = result.scalar_one_or_none()
        if not record:
            return None
        if not is_database_encryption_configured():
            raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
        encryptor = get_database_encryptor()
        try:
            dsn = encryptor.decrypt(record.dsn_encrypted)
        except InvalidToken as exc:
            # Chain the original error (PEP 3134) so the decryption failure
            # stays visible in tracebacks while callers keep catching
            # ValueError.
            raise ValueError(
                "Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
                "or invalid dsn_encrypted value"
            ) from exc
        dsn = _normalize_postgres_dsn(dsn)
        return ProjectDbRouting(
            project_id=record.project_id,
            db_role=record.db_role,
            db_type=record.db_type,
            dsn=dsn,
            pool_min_size=record.pool_min_size,
            pool_max_size=record.pool_max_size,
        )

    async def get_geoserver_config(
        self, project_id: UUID
    ) -> Optional[ProjectGeoServerInfo]:
        """Return the project's GeoServer configuration, or None.

        The stored admin password is decrypted when an encryption key is
        configured. NOTE(review): when no key is configured the raw stored
        value is handed back as the password — confirm this legacy
        plaintext fallback is intentional.
        """
        result = await self.session.execute(
            select(models.ProjectGeoServerConfig).where(
                models.ProjectGeoServerConfig.project_id == project_id
            )
        )
        record = result.scalar_one_or_none()
        if not record:
            return None
        if record.gs_admin_password_encrypted:
            if is_encryption_configured():
                encryptor = get_encryptor()
                password = encryptor.decrypt(record.gs_admin_password_encrypted)
            else:
                password = record.gs_admin_password_encrypted
        else:
            password = None
        return ProjectGeoServerInfo(
            project_id=record.project_id,
            gs_base_url=record.gs_base_url,
            gs_admin_user=record.gs_admin_user,
            gs_admin_password=password,
            gs_datastore_name=record.gs_datastore_name,
            default_extent=record.default_extent,
            srid=record.srid,
        )

    async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
        """List the projects the user is a member of, each with the user's
        role, ordered by project name."""
        stmt = (
            select(models.Project, models.UserProjectMembership.project_role)
            .join(
                models.UserProjectMembership,
                models.UserProjectMembership.project_id == models.Project.id,
            )
            .where(models.UserProjectMembership.user_id == user_id)
            .order_by(models.Project.name)
        )
        result = await self.session.execute(stmt)
        return [
            ProjectSummary(
                project_id=project.id,
                name=project.name,
                code=project.code,
                description=project.description,
                gs_workspace=project.gs_workspace,
                status=project.status,
                project_role=role,
            )
            for project, role in result.all()
        ]

    async def list_all_projects(self) -> List[ProjectSummary]:
        """List every project ordered by name; each entry is reported with
        role "owner" (admin-style listing)."""
        result = await self.session.execute(
            select(models.Project).order_by(models.Project.name)
        )
        return [
            ProjectSummary(
                project_id=project.id,
                name=project.name,
                code=project.code,
                description=project.description,
                gs_workspace=project.gs_workspace,
                status=project.status,
                project_role="owner",
            )
            for project in result.scalars().all()
        ]
|
||||
37
app/main.py
37
app/main.py
@@ -9,6 +9,8 @@ import app.services.project_info as project_info
|
||||
from app.api.v1.router import api_router
|
||||
from app.infra.db.timescaledb.database import db as tsdb
|
||||
from app.infra.db.postgresql.database import db as pgdb
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import close_metadata_engine
|
||||
from app.services.tjnetwork import open_project
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -33,7 +35,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
await tsdb.open()
|
||||
await pgdb.open()
|
||||
|
||||
|
||||
# 将数据库实例存储到 app.state,供依赖项使用
|
||||
app.state.db = pgdb
|
||||
logger.info("Database connection pool initialized and stored in app.state")
|
||||
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
|
||||
# 清理资源
|
||||
await tsdb.close()
|
||||
await pgdb.close()
|
||||
await project_connection_manager.close_all()
|
||||
await close_metadata_engine()
|
||||
logger.info("Database connections closed")
|
||||
|
||||
|
||||
@@ -58,22 +62,25 @@ app = FastAPI(
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 允许所有来源
|
||||
allow_credentials=True, # 允许传递凭证(Cookie、HTTP 头等)
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有 HTTP 头
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# 添加审计中间件(可选,记录关键操作)
|
||||
# 如果需要启用审计日志,取消下面的注释
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# Include Routers
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
# Legacy routers without version prefix
|
||||
app.include_router(api_router)
|
||||
|
||||
# 配置中间件
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
# 添加审计中间件(可选,记录关键操作)
|
||||
app.add_middleware(AuditMiddleware)
|
||||
# 配置 CORS 中间件
|
||||
# 确保这是你最后一个添加的 app.add_middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:3000", # 必须明确指定
|
||||
"http://127.0.0.1:3000", # 建议同时加上这个
|
||||
],
|
||||
allow_credentials=True, # 既然这里是 True
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@@ -1235,7 +1235,7 @@ def run_simulation(
|
||||
starttime = time.time()
|
||||
if simulation_type.upper() == "REALTIME":
|
||||
TimescaleInternalStorage.store_realtime_simulation(
|
||||
node_result, link_result, modify_pattern_start_time
|
||||
node_result, link_result, modify_pattern_start_time, db_name=name
|
||||
)
|
||||
elif simulation_type.upper() == "EXTENDED":
|
||||
TimescaleInternalStorage.store_scheme_simulation(
|
||||
@@ -1245,6 +1245,7 @@ def run_simulation(
|
||||
link_result,
|
||||
modify_pattern_start_time,
|
||||
num_periods_result,
|
||||
db_name=name,
|
||||
)
|
||||
endtime = time.time()
|
||||
logging.info("store time: %f", endtime - starttime)
|
||||
|
||||
33
scripts/encrypt_string.py
Normal file
33
scripts/encrypt_string.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录添加到 python 路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from app.core.encryption import get_database_encryptor
|
||||
|
||||
|
||||
def main() -> int:
    """Encrypt one plaintext string with the database encryptor and print
    the resulting token.

    Input sources, in priority order: piped/redirected stdin, argv[1],
    interactive prompt. Returns a process exit code (0 on success, 1 on
    empty input).
    """
    text = None

    # Piped or redirected stdin wins over command-line arguments.
    if not sys.stdin.isatty():
        piped = sys.stdin.read()
        if piped:
            text = piped.rstrip("\r\n")

    if text is None and len(sys.argv) > 1:
        text = sys.argv[1]

    if text is None:
        try:
            text = input("请输入要加密的文本: ")
        except EOFError:
            text = ""

    if not text.strip():
        print("Error: plaintext string cannot be empty.", file=sys.stderr)
        return 1

    print(get_database_encryptor().encrypt(text))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -16,6 +16,6 @@ if __name__ == "__main__":
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
workers=2, # 这里可以设置多进程
|
||||
workers=4, # 这里可以设置多进程
|
||||
loop="asyncio",
|
||||
)
|
||||
|
||||
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
def __init__(self, record):
|
||||
self._record = record
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._record
|
||||
|
||||
|
||||
class _DummyEncryptor:
|
||||
def __init__(self, decrypted=None, raise_invalid_token=False):
|
||||
self._decrypted = decrypted
|
||||
self._raise_invalid_token = raise_invalid_token
|
||||
self.encrypted_values = []
|
||||
|
||||
def decrypt(self, _value):
|
||||
if self._raise_invalid_token:
|
||||
raise InvalidToken()
|
||||
return self._decrypted
|
||||
|
||||
def _build_record(dsn_encrypted: str):
|
||||
return SimpleNamespace(
|
||||
project_id=uuid4(),
|
||||
db_role="biz_data",
|
||||
db_type="postgresql",
|
||||
dsn_encrypted=dsn_encrypted,
|
||||
pool_min_size=1,
|
||||
pool_max_size=5,
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_token_with_plaintext_dsn_value_raises_clear_error(monkeypatch):
    """A stored plaintext DSN (never encrypted) must surface as a clear
    ValueError instead of a raw InvalidToken, without touching commit."""
    record = _build_record("postgresql://user:p@ss@localhost:5432/db")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    failing_encryptor = _DummyEncryptor(raise_invalid_token=True)
    repo = MetadataRepository(session)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: failing_encryptor,
    )

    with pytest.raises(
        ValueError,
        match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
    ):
        asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
    session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
def test_invalid_token_with_non_dsn_value_raises_clear_error(monkeypatch):
    """Undecryptable garbage ciphertext must also map to the same clear
    ValueError message, with no commit issued."""
    record = _build_record("gAAAAABinvalidciphertext")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    repo = MetadataRepository(session)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: _DummyEncryptor(raise_invalid_token=True),
    )

    with pytest.raises(
        ValueError,
        match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
    ):
        asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
    session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
def test_encrypted_dsn_decrypts_without_migration(monkeypatch):
    """A properly encrypted DSN decrypts and is '@'-normalized, and the
    repository performs no write back (no commit)."""
    record = _build_record("encrypted-value")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    repo = MetadataRepository(session)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: _DummyEncryptor(decrypted="postgresql://u:p@ss@host/db"),
    )

    routing = asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))

    assert routing.dsn == "postgresql://u:p%40ss@host/db"
    session.commit.assert_not_awaited()
|
||||
Reference in New Issue
Block a user