Compare commits
20 Commits
dc38313cdc
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
| 80b6970970 | |||
| 364a8c8ec2 | |||
| 52ccb8abf1 | |||
| 0bc4058f23 | |||
| 0d3e6ca4fa | |||
| 6fc3aa5209 | |||
| 1b1b0a3697 | |||
| 2826999ddc | |||
| efc05f7278 | |||
| 29209f5c63 | |||
| 020432ad0e | |||
| 780a48d927 | |||
| ff2011ae24 | |||
| f5069a5606 | |||
| eb45e4aaa5 | |||
| a472639b8a | |||
| a0987105dc | |||
| a41be9c362 | |||
| 63b31b46b9 | |||
| e4f864a28c |
48
.env.example
48
.env.example
@@ -1,6 +1,6 @@
|
|||||||
# TJWater Server 环境变量配置模板
|
# TJWater Server 环境变量配置模板
|
||||||
# 复制此文件为 .env 并填写实际值
|
# 复制此文件为 .env 并填写实际值
|
||||||
|
NETWORK_NAME="szh"
|
||||||
# ============================================
|
# ============================================
|
||||||
# 安全配置 (必填)
|
# 安全配置 (必填)
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -12,24 +12,44 @@ SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
|
|||||||
# 数据加密密钥 - 用于敏感数据加密
|
# 数据加密密钥 - 用于敏感数据加密
|
||||||
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||||
ENCRYPTION_KEY=
|
ENCRYPTION_KEY=
|
||||||
|
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 数据库配置 (PostgreSQL)
|
# 数据库配置 (PostgreSQL)
|
||||||
# ============================================
|
# ============================================
|
||||||
DB_NAME=tjwater
|
DB_NAME="tjwater"
|
||||||
DB_HOST=localhost
|
DB_HOST="localhost"
|
||||||
DB_PORT=5432
|
DB_PORT="5432"
|
||||||
DB_USER=postgres
|
DB_USER="postgres"
|
||||||
DB_PASSWORD=password
|
DB_PASSWORD="password"
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 数据库配置 (TimescaleDB)
|
# 数据库配置 (TimescaleDB)
|
||||||
# ============================================
|
# ============================================
|
||||||
TIMESCALEDB_DB_NAME=szh
|
TIMESCALEDB_DB_NAME="szh"
|
||||||
TIMESCALEDB_DB_HOST=localhost
|
TIMESCALEDB_DB_HOST="localhost"
|
||||||
TIMESCALEDB_DB_PORT=5433
|
TIMESCALEDB_DB_PORT="5433"
|
||||||
TIMESCALEDB_DB_USER=tjwater
|
TIMESCALEDB_DB_USER="tjwater"
|
||||||
TIMESCALEDB_DB_PASSWORD=Tjwater@123456
|
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 元数据数据库配置 (Metadata DB)
|
||||||
|
# ============================================
|
||||||
|
METADATA_DB_NAME="system_hub"
|
||||||
|
METADATA_DB_HOST="localhost"
|
||||||
|
METADATA_DB_PORT="5432"
|
||||||
|
METADATA_DB_USER="tjwater"
|
||||||
|
METADATA_DB_PASSWORD="password"
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 项目连接缓存与连接池配置
|
||||||
|
# ============================================
|
||||||
|
PROJECT_PG_CACHE_SIZE=50
|
||||||
|
PROJECT_TS_CACHE_SIZE=50
|
||||||
|
PROJECT_PG_POOL_SIZE=5
|
||||||
|
PROJECT_PG_MAX_OVERFLOW=10
|
||||||
|
PROJECT_TS_POOL_MIN_SIZE=1
|
||||||
|
PROJECT_TS_POOL_MAX_SIZE=10
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# InfluxDB 配置 (时序数据)
|
# InfluxDB 配置 (时序数据)
|
||||||
@@ -46,6 +66,12 @@ TIMESCALEDB_DB_PASSWORD=Tjwater@123456
|
|||||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||||
# ALGORITHM=HS256
|
# ALGORITHM=HS256
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# Keycloak JWT (可选)
|
||||||
|
# ============================================
|
||||||
|
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||||
|
# KEYCLOAK_ALGORITHM=RS256
|
||||||
|
|
||||||
# ============================================
|
# ============================================
|
||||||
# 其他配置
|
# 其他配置
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|||||||
23
.env.local
Normal file
23
.env.local
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
NETWORK_NAME="tjwater"
|
||||||
|
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
|
||||||
|
KEYCLOAK_ALGORITHM="RS256"
|
||||||
|
KEYCLOAK_AUDIENCE="account"
|
||||||
|
|
||||||
|
DB_NAME="tjwater"
|
||||||
|
DB_HOST="192.168.1.114"
|
||||||
|
DB_PORT="5432"
|
||||||
|
DB_USER="tjwater"
|
||||||
|
DB_PASSWORD="Tjwater@123456"
|
||||||
|
|
||||||
|
TIMESCALEDB_DB_NAME="tjwater"
|
||||||
|
TIMESCALEDB_DB_HOST="192.168.1.114"
|
||||||
|
TIMESCALEDB_DB_PORT="5433"
|
||||||
|
TIMESCALEDB_DB_USER="tjwater"
|
||||||
|
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||||
|
|
||||||
|
METADATA_DB_NAME="system_hub"
|
||||||
|
METADATA_DB_HOST="192.168.1.114"
|
||||||
|
METADATA_DB_PORT="5432"
|
||||||
|
METADATA_DB_USER="tjwater"
|
||||||
|
METADATA_DB_PASSWORD="Tjwater@123456"
|
||||||
|
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||||
2
.github/copilot-instructions.md
vendored
2
.github/copilot-instructions.md
vendored
@@ -87,7 +87,6 @@ Default admin accounts:
|
|||||||
- `app/core/config.py`: Settings management using `pydantic-settings`
|
- `app/core/config.py`: Settings management using `pydantic-settings`
|
||||||
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
||||||
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
||||||
- `configs/project_info.yml`: Default project configuration (auto-loaded on startup)
|
|
||||||
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
||||||
|
|
||||||
## Important Conventions
|
## Important Conventions
|
||||||
@@ -148,7 +147,6 @@ async def delete_data(id: int, current_user = Depends(get_current_admin)):
|
|||||||
|
|
||||||
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
||||||
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
||||||
- Initial project config comes from `configs/project_info.yml`
|
|
||||||
|
|
||||||
### Audit Logging
|
### Audit Logging
|
||||||
|
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ build/
|
|||||||
.env
|
.env
|
||||||
*.dump
|
*.dump
|
||||||
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
||||||
|
.vscode/
|
||||||
|
|||||||
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
FROM continuumio/miniconda3:latest
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装 Python 3.12 和 pymetis (通过 conda-forge 避免编译问题)
|
||||||
|
RUN conda install -y -c conda-forge python=3.12 pymetis && \
|
||||||
|
conda clean -afy
|
||||||
|
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 将代码放入子目录 'app',将数据放入子目录 'db_inp'
|
||||||
|
# 这样临时文件默认会生成在 /app 下,而代码在 /app/app 下,实现了分离
|
||||||
|
COPY app ./app
|
||||||
|
COPY db_inp ./db_inp
|
||||||
|
COPY temp ./temp
|
||||||
|
COPY .env .
|
||||||
|
|
||||||
|
# 设置 PYTHONPATH 以便 uvicorn 找到 app 模块
|
||||||
|
ENV PYTHONPATH=/app
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||||
@@ -103,6 +103,7 @@ def burst_analysis(
|
|||||||
if isinstance(burst_ID, list):
|
if isinstance(burst_ID, list):
|
||||||
if (burst_size is not None) and (type(burst_size) is not list):
|
if (burst_size is not None) and (type(burst_size) is not list):
|
||||||
return json.dumps("Type mismatch.")
|
return json.dumps("Type mismatch.")
|
||||||
|
# 转化为列表形式
|
||||||
elif isinstance(burst_ID, str):
|
elif isinstance(burst_ID, str):
|
||||||
burst_ID = [burst_ID]
|
burst_ID = [burst_ID]
|
||||||
if burst_size is not None:
|
if burst_size is not None:
|
||||||
@@ -344,18 +345,42 @@ def flushing_analysis(
|
|||||||
# status['setting'] = 0.1036 * pow(valve_k, -3.105)
|
# status['setting'] = 0.1036 * pow(valve_k, -3.105)
|
||||||
# cs.append(status)
|
# cs.append(status)
|
||||||
# set_status(new_name,cs)
|
# set_status(new_name,cs)
|
||||||
units = get_option(new_name)
|
options = get_option(new_name)
|
||||||
|
units = options["UNITS"]
|
||||||
# step 2. set the emitter coefficient of drainage node or add flush flow to the drainage node
|
# step 2. set the emitter coefficient of drainage node or add flush flow to the drainage node
|
||||||
|
# 新建 pattern
|
||||||
|
time_option = get_time(new_name)
|
||||||
|
hydraulic_step = time_option["HYDRAULIC TIMESTEP"]
|
||||||
|
secs = from_clock_to_seconds_2(hydraulic_step)
|
||||||
|
cs_pattern = ChangeSet()
|
||||||
|
pt = {}
|
||||||
|
factors = []
|
||||||
|
tmp_duration = modify_total_duration
|
||||||
|
while tmp_duration > 0:
|
||||||
|
factors.append(1.0)
|
||||||
|
tmp_duration = tmp_duration - secs
|
||||||
|
pt["id"] = "flushing_pt"
|
||||||
|
pt["factors"] = factors
|
||||||
|
cs_pattern.append(pt)
|
||||||
|
add_pattern(new_name, cs_pattern)
|
||||||
|
# 为 emitter_demand 添加新的 pattern
|
||||||
emitter_demand = get_demand(new_name, drainage_node_ID)
|
emitter_demand = get_demand(new_name, drainage_node_ID)
|
||||||
cs = ChangeSet()
|
cs = ChangeSet()
|
||||||
if flushing_flow > 0:
|
if flushing_flow > 0:
|
||||||
for r in emitter_demand["demands"]:
|
if units == "LPS":
|
||||||
if units == "LPS":
|
emitter_demand["demands"].append(
|
||||||
r["demand"] += flushing_flow / 3.6
|
{
|
||||||
elif units == "CMH":
|
"demand": flushing_flow / 3.6,
|
||||||
r["demand"] += flushing_flow
|
"pattern": "flushing_pt",
|
||||||
cs.append(emitter_demand)
|
"category": None,
|
||||||
set_demand(new_name, cs)
|
}
|
||||||
|
)
|
||||||
|
elif units == "CMH":
|
||||||
|
emitter_demand["demands"].append(
|
||||||
|
{"demand": flushing_flow, "pattern": "flushing_pt", "category": None}
|
||||||
|
)
|
||||||
|
cs.append(emitter_demand)
|
||||||
|
set_demand(new_name, cs)
|
||||||
else:
|
else:
|
||||||
pipes = get_node_links(new_name, drainage_node_ID)
|
pipes = get_node_links(new_name, drainage_node_ID)
|
||||||
flush_diameter = 50
|
flush_diameter = 50
|
||||||
@@ -484,7 +509,7 @@ def contaminant_simulation(
|
|||||||
# step 2. set pattern
|
# step 2. set pattern
|
||||||
if source_pattern != None:
|
if source_pattern != None:
|
||||||
pt = get_pattern(new_name, source_pattern)
|
pt = get_pattern(new_name, source_pattern)
|
||||||
if pt == None:
|
if len(pt) == 0:
|
||||||
str_response = str("cant find source_pattern")
|
str_response = str("cant find source_pattern")
|
||||||
return str_response
|
return str_response
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,33 +4,38 @@
|
|||||||
仅管理员可访问
|
仅管理员可访问
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from uuid import UUID
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query
|
||||||
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
|
from app.domain.schemas.audit import AuditLogResponse
|
||||||
from app.domain.schemas.user import UserInDB
|
|
||||||
from app.infra.repositories.audit_repository import AuditRepository
|
from app.infra.repositories.audit_repository import AuditRepository
|
||||||
from app.auth.dependencies import get_user_repository, get_db
|
from app.auth.metadata_dependencies import (
|
||||||
from app.auth.permissions import get_current_admin
|
get_current_metadata_admin,
|
||||||
from app.infra.db.postgresql.database import Database
|
get_current_metadata_user,
|
||||||
|
)
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
|
async def get_audit_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> AuditRepository:
|
||||||
"""获取审计日志仓储"""
|
"""获取审计日志仓储"""
|
||||||
return AuditRepository(db)
|
return AuditRepository(session)
|
||||||
|
|
||||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_admin),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
"""
|
||||||
查询审计日志(仅管理员)
|
查询审计日志(仅管理员)
|
||||||
@@ -39,7 +44,7 @@ async def get_audit_logs(
|
|||||||
"""
|
"""
|
||||||
logs = await audit_repo.get_logs(
|
logs = await audit_repo.get_logs(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@@ -51,21 +56,21 @@ async def get_audit_logs(
|
|||||||
|
|
||||||
@router.get("/logs/count")
|
@router.get("/logs/count")
|
||||||
async def get_audit_logs_count(
|
async def get_audit_logs_count(
|
||||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_admin),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取审计日志总数(仅管理员)
|
获取审计日志总数(仅管理员)
|
||||||
"""
|
"""
|
||||||
count = await audit_repo.get_log_count(
|
count = await audit_repo.get_log_count(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
|
|||||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||||
skip: int = Query(0, ge=0),
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
current_user: UserInDB = Depends(get_current_admin),
|
current_user=Depends(get_current_metadata_user),
|
||||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
"""
|
||||||
查询当前用户的审计日志
|
查询当前用户的审计日志
|
||||||
|
|||||||
101
app/api/v1/endpoints/meta.py
Normal file
101
app/api/v1/endpoints/meta.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from psycopg import AsyncConnection
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.project_dependencies import (
|
||||||
|
ProjectContext,
|
||||||
|
get_project_context,
|
||||||
|
get_project_pg_session,
|
||||||
|
get_project_timescale_connection,
|
||||||
|
get_metadata_repository,
|
||||||
|
)
|
||||||
|
from app.auth.metadata_dependencies import get_current_metadata_user
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.domain.schemas.metadata import (
|
||||||
|
GeoServerConfigResponse,
|
||||||
|
ProjectMetaResponse,
|
||||||
|
ProjectSummaryResponse,
|
||||||
|
)
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||||
|
async def get_project_metadata(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||||
|
)
|
||||||
|
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
|
||||||
|
geoserver_payload = (
|
||||||
|
GeoServerConfigResponse(
|
||||||
|
gs_base_url=geoserver.gs_base_url,
|
||||||
|
gs_admin_user=geoserver.gs_admin_user,
|
||||||
|
gs_datastore_name=geoserver.gs_datastore_name,
|
||||||
|
default_extent=geoserver.default_extent,
|
||||||
|
srid=geoserver.srid,
|
||||||
|
)
|
||||||
|
if geoserver
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return ProjectMetaResponse(
|
||||||
|
project_id=project.id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role=ctx.project_role,
|
||||||
|
geoserver=geoserver_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||||
|
async def list_user_projects(
|
||||||
|
current_user=Depends(get_current_metadata_user),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||||
|
except SQLAlchemyError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Metadata DB error while listing projects for user %s",
|
||||||
|
current_user.id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Metadata database error: {exc}",
|
||||||
|
) from exc
|
||||||
|
return [
|
||||||
|
ProjectSummaryResponse(
|
||||||
|
project_id=project.project_id,
|
||||||
|
name=project.name,
|
||||||
|
code=project.code,
|
||||||
|
description=project.description,
|
||||||
|
gs_workspace=project.gs_workspace,
|
||||||
|
status=project.status,
|
||||||
|
project_role=project.project_role,
|
||||||
|
)
|
||||||
|
for project in projects
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/meta/db/health")
|
||||||
|
async def project_db_health(
|
||||||
|
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||||
|
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||||
|
):
|
||||||
|
await pg_session.execute(text("SELECT 1"))
|
||||||
|
async with ts_conn.cursor() as cur:
|
||||||
|
await cur.execute("SELECT 1")
|
||||||
|
return {"postgres": "ok", "timescale": "ok"}
|
||||||
@@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
import app.services.project_info as project_info
|
import app.services.project_info as project_info
|
||||||
from app.native.api import ChangeSet
|
from app.native.api import ChangeSet
|
||||||
|
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
|
||||||
|
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
|
||||||
from app.services.tjnetwork import (
|
from app.services.tjnetwork import (
|
||||||
list_project,
|
list_project,
|
||||||
have_project,
|
have_project,
|
||||||
@@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str):
|
|||||||
@router.post("/openproject/")
|
@router.post("/openproject/")
|
||||||
async def open_project_endpoint(network: str):
|
async def open_project_endpoint(network: str):
|
||||||
open_project(network)
|
open_project(network)
|
||||||
|
|
||||||
|
# 尝试连接指定数据库
|
||||||
|
try:
|
||||||
|
# 初始化 PostgreSQL 连接池
|
||||||
|
pg_instance = await get_pg_db(network)
|
||||||
|
async with pg_instance.get_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute("SELECT 1")
|
||||||
|
|
||||||
|
# 初始化 TimescaleDB 连接池
|
||||||
|
ts_instance = await get_ts_db(network)
|
||||||
|
async with ts_instance.get_connection() as conn:
|
||||||
|
async with conn.cursor() as cur:
|
||||||
|
await cur.execute("SELECT 1")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 记录错误但不阻断项目打开,或者根据需求决定是否阻断
|
||||||
|
# 这里选择打印错误,因为 open_project 原本只负责原生部分
|
||||||
|
print(f"Failed to connect to databases for {network}: {str(e)}")
|
||||||
|
# 如果数据库连接是必须的,可以抛出异常:
|
||||||
|
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
|
||||||
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
@router.post("/closeproject/")
|
@router.post("/closeproject/")
|
||||||
|
|||||||
@@ -115,7 +115,15 @@ class PressureSensorPlacement(BaseModel):
|
|||||||
def run_simulation_manually_by_date(
|
def run_simulation_manually_by_date(
|
||||||
network_name: str, base_date: datetime, start_time: str, duration: int
|
network_name: str, base_date: datetime, start_time: str, duration: int
|
||||||
) -> None:
|
) -> None:
|
||||||
start_hour, start_minute, start_second = map(int, start_time.split(":"))
|
time_parts = list(map(int, start_time.split(":")))
|
||||||
|
if len(time_parts) == 2:
|
||||||
|
start_hour, start_minute = time_parts
|
||||||
|
start_second = 0
|
||||||
|
elif len(time_parts) == 3:
|
||||||
|
start_hour, start_minute, start_second = time_parts
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
|
||||||
|
|
||||||
start_datetime = base_date.replace(
|
start_datetime = base_date.replace(
|
||||||
hour=start_hour, minute=start_minute, second=start_second
|
hour=start_hour, minute=start_minute, second=start_second
|
||||||
)
|
)
|
||||||
@@ -196,10 +204,8 @@ async def burst_analysis_endpoint(
|
|||||||
async def fastapi_burst_analysis(
|
async def fastapi_burst_analysis(
|
||||||
network: str = Query(...),
|
network: str = Query(...),
|
||||||
modify_pattern_start_time: str = Query(...),
|
modify_pattern_start_time: str = Query(...),
|
||||||
burst_ID: list | str = Query(..., alias="burst_ID[]"), # 添加别名以匹配 URL
|
burst_ID: list[str] = Query(...),
|
||||||
burst_size: list | float | int = Query(
|
burst_size: list[float] = Query(...),
|
||||||
..., alias="burst_size[]"
|
|
||||||
), # 添加别名以匹配 URL
|
|
||||||
modify_total_duration: int = Query(...),
|
modify_total_duration: int = Query(...),
|
||||||
scheme_name: str = Query(...),
|
scheme_name: str = Query(...),
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -656,13 +662,9 @@ async def fastapi_run_simulation_manually_by_date(
|
|||||||
globals.realtime_region_pipe_flow_and_demand_id,
|
globals.realtime_region_pipe_flow_and_demand_id,
|
||||||
)
|
)
|
||||||
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
|
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
|
||||||
thread = threading.Thread(
|
run_simulation_manually_by_date(
|
||||||
target=lambda: run_simulation_manually_by_date(
|
item["name"], base_date, item["start_time"], item["duration"]
|
||||||
item["name"], base_date, item["start_time"], item["duration"]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
thread.start()
|
|
||||||
thread.join()
|
|
||||||
return {"status": "success"}
|
return {"status": "success"}
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {"status": "error", "message": str(exc)}
|
return {"status": "error", "message": str(exc)}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
|
|||||||
cache,
|
cache,
|
||||||
user_management, # 新增:用户管理
|
user_management, # 新增:用户管理
|
||||||
audit, # 新增:审计日志
|
audit, # 新增:审计日志
|
||||||
|
meta,
|
||||||
)
|
)
|
||||||
from app.api.v1.endpoints.network import (
|
from app.api.v1.endpoints.network import (
|
||||||
general,
|
general,
|
||||||
@@ -46,6 +47,7 @@ api_router = APIRouter()
|
|||||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||||
|
api_router.include_router(meta.router, tags=["Metadata"])
|
||||||
api_router.include_router(project.router, tags=["Project"])
|
api_router.include_router(project.router, tags=["Project"])
|
||||||
|
|
||||||
# Network Elements (Node/Link Types)
|
# Network Elements (Node/Link Types)
|
||||||
|
|||||||
63
app/auth/keycloak_dependencies.py
Normal file
63
app/auth/keycloak_dependencies.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
oauth2_optional = OAuth2PasswordBearer(
|
||||||
|
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_keycloak_sub(
|
||||||
|
token: str | None = Depends(oauth2_optional),
|
||||||
|
) -> UUID:
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Not authenticated",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||||
|
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||||
|
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||||
|
else:
|
||||||
|
key = settings.SECRET_KEY
|
||||||
|
algorithms = [settings.ALGORITHM]
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
key,
|
||||||
|
algorithms=algorithms,
|
||||||
|
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||||
|
)
|
||||||
|
except JWTError as exc:
|
||||||
|
# logger.warning("Keycloak token validation failed: %s", exc)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if not sub:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing subject claim",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return UUID(sub)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid subject claim",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
) from exc
|
||||||
60
app/auth/metadata_dependencies.py
Normal file
60
app/auth/metadata_dependencies.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> MetadataRepository:
|
||||||
|
return MetadataRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_metadata_user(
|
||||||
|
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||||
|
except SQLAlchemyError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Metadata DB error while resolving current user",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Metadata database error: {exc}",
|
||||||
|
) from exc
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_metadata_admin(
|
||||||
|
user=Depends(get_current_metadata_user),
|
||||||
|
):
|
||||||
|
if user.is_superuser or user.role == "admin":
|
||||||
|
return user
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _AuthBypassUser:
|
||||||
|
id: UUID = UUID(int=0)
|
||||||
|
role: str = "admin"
|
||||||
|
is_superuser: bool = True
|
||||||
|
is_active: bool = True
|
||||||
213
app/auth/project_dependencies.py
Normal file
213
app/auth/project_dependencies.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
|
from psycopg import AsyncConnection
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.dynamic_manager import project_connection_manager
|
||||||
|
from app.infra.db.metadata.database import get_metadata_session
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
DB_ROLE_BIZ_DATA = "biz_data"
|
||||||
|
DB_ROLE_IOT_DATA = "iot_data"
|
||||||
|
DB_TYPE_POSTGRES = "postgresql"
|
||||||
|
DB_TYPE_TIMESCALE = "timescaledb"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProjectContext:
|
||||||
|
project_id: UUID
|
||||||
|
user_id: UUID
|
||||||
|
project_role: str
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_repository(
|
||||||
|
session: AsyncSession = Depends(get_metadata_session),
|
||||||
|
) -> MetadataRepository:
|
||||||
|
return MetadataRepository(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_context(
|
||||||
|
x_project_id: str = Header(..., alias="X-Project-Id"),
|
||||||
|
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> ProjectContext:
|
||||||
|
try:
|
||||||
|
project_uuid = UUID(x_project_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
project = await metadata_repo.get_project_by_id(project_uuid)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||||
|
)
|
||||||
|
if project.status != "active":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
|
||||||
|
)
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||||
|
)
|
||||||
|
|
||||||
|
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
|
||||||
|
if not membership_role:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
|
||||||
|
)
|
||||||
|
except SQLAlchemyError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Metadata DB error while resolving project context",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Metadata database error: {exc}",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
return ProjectContext(
|
||||||
|
project_id=project.id,
|
||||||
|
user_id=user.id,
|
||||||
|
project_role=membership_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_pg_session(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
try:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Invalid project PostgreSQL routing DSN configuration",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||||
|
) from exc
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_POSTGRES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_BIZ_DATA,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with sessionmaker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_pg_connection(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncConnection, None]:
|
||||||
|
try:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Invalid project PostgreSQL routing DSN configuration",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||||
|
) from exc
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_POSTGRES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project PostgreSQL type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||||
|
pool = await project_connection_manager.get_pg_pool(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_BIZ_DATA,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
yield conn
|
||||||
|
|
||||||
|
|
||||||
|
async def get_project_timescale_connection(
|
||||||
|
ctx: ProjectContext = Depends(get_project_context),
|
||||||
|
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||||
|
) -> AsyncGenerator[AsyncConnection, None]:
|
||||||
|
try:
|
||||||
|
routing = await metadata_repo.get_project_db_routing(
|
||||||
|
ctx.project_id, DB_ROLE_IOT_DATA
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error(
|
||||||
|
"Invalid project TimescaleDB routing DSN configuration",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
|
||||||
|
) from exc
|
||||||
|
if not routing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project TimescaleDB not configured",
|
||||||
|
)
|
||||||
|
if routing.db_type != DB_TYPE_TIMESCALE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Project TimescaleDB type mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
|
||||||
|
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
|
||||||
|
pool = await project_connection_manager.get_timescale_pool(
|
||||||
|
ctx.project_id,
|
||||||
|
DB_ROLE_IOT_DATA,
|
||||||
|
routing.dsn,
|
||||||
|
pool_min_size,
|
||||||
|
pool_max_size,
|
||||||
|
)
|
||||||
|
async with pool.connection() as conn:
|
||||||
|
yield conn
|
||||||
@@ -7,6 +7,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,18 +39,16 @@ class AuditAction:
|
|||||||
|
|
||||||
async def log_audit_event(
|
async def log_audit_event(
|
||||||
action: str,
|
action: str,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
user_agent: Optional[str] = None,
|
|
||||||
request_method: Optional[str] = None,
|
request_method: Optional[str] = None,
|
||||||
request_path: Optional[str] = None,
|
request_path: Optional[str] = None,
|
||||||
request_data: Optional[dict] = None,
|
request_data: Optional[dict] = None,
|
||||||
response_status: Optional[int] = None,
|
response_status: Optional[int] = None,
|
||||||
error_message: Optional[str] = None,
|
session=None,
|
||||||
db=None, # 新增:可选的数据库实例
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
记录审计日志
|
记录审计日志
|
||||||
@@ -57,67 +56,60 @@ async def log_audit_event(
|
|||||||
Args:
|
Args:
|
||||||
action: 操作类型
|
action: 操作类型
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
username: 用户名
|
project_id: 项目ID
|
||||||
resource_type: 资源类型
|
resource_type: 资源类型
|
||||||
resource_id: 资源ID
|
resource_id: 资源ID
|
||||||
ip_address: IP地址
|
ip_address: IP地址
|
||||||
user_agent: User-Agent
|
|
||||||
request_method: 请求方法
|
request_method: 请求方法
|
||||||
request_path: 请求路径
|
request_path: 请求路径
|
||||||
request_data: 请求数据(敏感字段需脱敏)
|
request_data: 请求数据(敏感字段需脱敏)
|
||||||
response_status: 响应状态码
|
response_status: 响应状态码
|
||||||
error_message: 错误消息
|
session: 元数据库会话(可选)
|
||||||
db: 数据库实例(可选,如果不提供则尝试获取)
|
|
||||||
"""
|
"""
|
||||||
|
from app.infra.db.metadata.database import SessionLocal
|
||||||
from app.infra.repositories.audit_repository import AuditRepository
|
from app.infra.repositories.audit_repository import AuditRepository
|
||||||
|
|
||||||
try:
|
if request_data:
|
||||||
# 脱敏敏感数据
|
request_data = sanitize_sensitive_data(request_data)
|
||||||
if request_data:
|
|
||||||
request_data = sanitize_sensitive_data(request_data)
|
|
||||||
|
|
||||||
# 如果没有提供数据库实例,尝试从全局获取
|
|
||||||
if db is None:
|
|
||||||
try:
|
|
||||||
from app.infra.db.postgresql.database import db as default_db
|
|
||||||
|
|
||||||
# 仅当连接池已初始化时使用
|
|
||||||
if default_db.pool:
|
|
||||||
db = default_db
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 如果仍然没有数据库实例
|
|
||||||
if db is None:
|
|
||||||
# 在某些上下文中可能无法获取,此时静默失败
|
|
||||||
logger.warning("No database instance provided for audit logging")
|
|
||||||
return
|
|
||||||
|
|
||||||
audit_repo = AuditRepository(db)
|
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
async with SessionLocal() as session:
|
||||||
|
audit_repo = AuditRepository(session)
|
||||||
|
await audit_repo.create_log(
|
||||||
|
user_id=user_id,
|
||||||
|
project_id=project_id,
|
||||||
|
action=action,
|
||||||
|
resource_type=resource_type,
|
||||||
|
resource_id=resource_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
request_method=request_method,
|
||||||
|
request_path=request_path,
|
||||||
|
request_data=request_data,
|
||||||
|
response_status=response_status,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
audit_repo = AuditRepository(session)
|
||||||
await audit_repo.create_log(
|
await audit_repo.create_log(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
action=action,
|
action=action,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
|
||||||
request_method=request_method,
|
request_method=request_method,
|
||||||
request_path=request_path,
|
request_path=request_path,
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
response_status=response_status,
|
response_status=response_status,
|
||||||
error_message=error_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Audit log created: action={action}, user={username or user_id}, "
|
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
|
||||||
f"resource={resource_type}:{resource_id}"
|
action,
|
||||||
)
|
user_id,
|
||||||
|
project_id,
|
||||||
except Exception as e:
|
resource_type,
|
||||||
# 审计日志失败不应影响业务流程
|
resource_id,
|
||||||
logger.error(f"Failed to create audit log: {e}", exc_info=True)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sanitize_sensitive_data(data: dict) -> dict:
|
def sanitize_sensitive_data(data: dict) -> dict:
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from pydantic_settings import BaseSettings
|
from pathlib import Path
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -15,6 +17,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# 数据加密密钥 (使用 Fernet)
|
# 数据加密密钥 (使用 Fernet)
|
||||||
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
|
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
|
||||||
|
DATABASE_ENCRYPTION_KEY: str = "" # project_databases.dsn_encrypted 专用
|
||||||
|
|
||||||
# Database Config (PostgreSQL)
|
# Database Config (PostgreSQL)
|
||||||
DB_NAME: str = "tjwater"
|
DB_NAME: str = "tjwater"
|
||||||
@@ -35,13 +38,45 @@ class Settings(BaseSettings):
|
|||||||
INFLUXDB_ORG: str = "org"
|
INFLUXDB_ORG: str = "org"
|
||||||
INFLUXDB_BUCKET: str = "bucket"
|
INFLUXDB_BUCKET: str = "bucket"
|
||||||
|
|
||||||
|
# Metadata Database Config (PostgreSQL)
|
||||||
|
METADATA_DB_NAME: str = "system_hub"
|
||||||
|
METADATA_DB_HOST: str = "localhost"
|
||||||
|
METADATA_DB_PORT: str = "5432"
|
||||||
|
METADATA_DB_USER: str = "postgres"
|
||||||
|
METADATA_DB_PASSWORD: str = "password"
|
||||||
|
|
||||||
|
METADATA_DB_POOL_SIZE: int = 5
|
||||||
|
METADATA_DB_MAX_OVERFLOW: int = 10
|
||||||
|
|
||||||
|
PROJECT_PG_CACHE_SIZE: int = 50
|
||||||
|
PROJECT_TS_CACHE_SIZE: int = 50
|
||||||
|
PROJECT_PG_POOL_SIZE: int = 5
|
||||||
|
PROJECT_PG_MAX_OVERFLOW: int = 10
|
||||||
|
PROJECT_TS_POOL_MIN_SIZE: int = 1
|
||||||
|
PROJECT_TS_POOL_MAX_SIZE: int = 10
|
||||||
|
|
||||||
|
# Keycloak JWT (optional override)
|
||||||
|
KEYCLOAK_PUBLIC_KEY: str = ""
|
||||||
|
KEYCLOAK_ALGORITHM: str = "RS256"
|
||||||
|
KEYCLOAK_AUDIENCE: str = ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
db_password = quote_plus(self.DB_PASSWORD)
|
||||||
|
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||||
|
|
||||||
class Config:
|
@property
|
||||||
env_file = ".env"
|
def METADATA_DATABASE_URI(self) -> str:
|
||||||
extra = "ignore"
|
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
|
||||||
|
return (
|
||||||
|
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
|
||||||
|
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=Path(__file__).resolve().parents[2] / ".env",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ from typing import Optional
|
|||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
class Encryptor:
|
class Encryptor:
|
||||||
"""
|
"""
|
||||||
使用 Fernet (对称加密) 实现数据加密/解密
|
使用 Fernet (对称加密) 实现数据加密/解密
|
||||||
@@ -17,10 +20,10 @@ class Encryptor:
|
|||||||
key: 加密密钥,如果为 None 则从环境变量读取
|
key: 加密密钥,如果为 None 则从环境变量读取
|
||||||
"""
|
"""
|
||||||
if key is None:
|
if key is None:
|
||||||
key_str = os.getenv("ENCRYPTION_KEY")
|
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
|
||||||
if not key_str:
|
if not key_str:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ENCRYPTION_KEY not found in environment variables. "
|
"ENCRYPTION_KEY not found in environment variables or .env. "
|
||||||
"Generate one using: Encryptor.generate_key()"
|
"Generate one using: Encryptor.generate_key()"
|
||||||
)
|
)
|
||||||
key = key_str.encode()
|
key = key_str.encode()
|
||||||
@@ -70,8 +73,24 @@ class Encryptor:
|
|||||||
key = Fernet.generate_key()
|
key = Fernet.generate_key()
|
||||||
return key.decode()
|
return key.decode()
|
||||||
|
|
||||||
|
|
||||||
# 全局加密器实例(懒加载)
|
# 全局加密器实例(懒加载)
|
||||||
_encryptor: Optional[Encryptor] = None
|
_encryptor: Optional[Encryptor] = None
|
||||||
|
_database_encryptor: Optional[Encryptor] = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_encryption_configured() -> bool:
|
||||||
|
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def is_database_encryption_configured() -> bool:
|
||||||
|
return bool(
|
||||||
|
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||||
|
or settings.DATABASE_ENCRYPTION_KEY
|
||||||
|
or os.getenv("ENCRYPTION_KEY")
|
||||||
|
or settings.ENCRYPTION_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_encryptor() -> Encryptor:
|
def get_encryptor() -> Encryptor:
|
||||||
"""获取全局加密器实例"""
|
"""获取全局加密器实例"""
|
||||||
@@ -80,6 +99,26 @@ def get_encryptor() -> Encryptor:
|
|||||||
_encryptor = Encryptor()
|
_encryptor = Encryptor()
|
||||||
return _encryptor
|
return _encryptor
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_encryptor() -> Encryptor:
|
||||||
|
"""获取 project DB DSN 专用加密器实例"""
|
||||||
|
global _database_encryptor
|
||||||
|
if _database_encryptor is None:
|
||||||
|
key_str = (
|
||||||
|
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||||
|
or settings.DATABASE_ENCRYPTION_KEY
|
||||||
|
or os.getenv("ENCRYPTION_KEY")
|
||||||
|
or settings.ENCRYPTION_KEY
|
||||||
|
)
|
||||||
|
if not key_str:
|
||||||
|
raise ValueError(
|
||||||
|
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
|
||||||
|
"Generate one using: Encryptor.generate_key()"
|
||||||
|
)
|
||||||
|
_database_encryptor = Encryptor(key=key_str.encode())
|
||||||
|
return _database_encryptor
|
||||||
|
|
||||||
|
|
||||||
# 向后兼容(延迟加载)
|
# 向后兼容(延迟加载)
|
||||||
def __getattr__(name):
|
def __getattr__(name):
|
||||||
if name == "encryptor":
|
if name == "encryptor":
|
||||||
|
|||||||
@@ -1,45 +1,42 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Any
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
class AuditLogCreate(BaseModel):
|
class AuditLogCreate(BaseModel):
|
||||||
"""创建审计日志"""
|
"""创建审计日志"""
|
||||||
user_id: Optional[int] = None
|
user_id: Optional[UUID] = None
|
||||||
username: Optional[str] = None
|
project_id: Optional[UUID] = None
|
||||||
action: str
|
action: str
|
||||||
resource_type: Optional[str] = None
|
resource_type: Optional[str] = None
|
||||||
resource_id: Optional[str] = None
|
resource_id: Optional[str] = None
|
||||||
ip_address: Optional[str] = None
|
ip_address: Optional[str] = None
|
||||||
user_agent: Optional[str] = None
|
|
||||||
request_method: Optional[str] = None
|
request_method: Optional[str] = None
|
||||||
request_path: Optional[str] = None
|
request_path: Optional[str] = None
|
||||||
request_data: Optional[dict] = None
|
request_data: Optional[dict] = None
|
||||||
response_status: Optional[int] = None
|
response_status: Optional[int] = None
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
class AuditLogResponse(BaseModel):
|
class AuditLogResponse(BaseModel):
|
||||||
"""审计日志响应"""
|
"""审计日志响应"""
|
||||||
id: int
|
id: UUID
|
||||||
user_id: Optional[int]
|
user_id: Optional[UUID]
|
||||||
username: Optional[str]
|
project_id: Optional[UUID]
|
||||||
action: str
|
action: str
|
||||||
resource_type: Optional[str]
|
resource_type: Optional[str]
|
||||||
resource_id: Optional[str]
|
resource_id: Optional[str]
|
||||||
ip_address: Optional[str]
|
ip_address: Optional[str]
|
||||||
user_agent: Optional[str]
|
|
||||||
request_method: Optional[str]
|
request_method: Optional[str]
|
||||||
request_path: Optional[str]
|
request_path: Optional[str]
|
||||||
request_data: Optional[dict]
|
request_data: Optional[dict]
|
||||||
response_status: Optional[int]
|
response_status: Optional[int]
|
||||||
error_message: Optional[str]
|
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
class AuditLogQuery(BaseModel):
|
class AuditLogQuery(BaseModel):
|
||||||
"""审计日志查询参数"""
|
"""审计日志查询参数"""
|
||||||
user_id: Optional[int] = None
|
user_id: Optional[UUID] = None
|
||||||
username: Optional[str] = None
|
project_id: Optional[UUID] = None
|
||||||
action: Optional[str] = None
|
action: Optional[str] = None
|
||||||
resource_type: Optional[str] = None
|
resource_type: Optional[str] = None
|
||||||
start_time: Optional[datetime] = None
|
start_time: Optional[datetime] = None
|
||||||
|
|||||||
33
app/domain/schemas/metadata.py
Normal file
33
app/domain/schemas/metadata.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GeoServerConfigResponse(BaseModel):
|
||||||
|
gs_base_url: Optional[str]
|
||||||
|
gs_admin_user: Optional[str]
|
||||||
|
gs_datastore_name: str
|
||||||
|
default_extent: Optional[dict]
|
||||||
|
srid: int
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectMetaResponse(BaseModel):
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str]
|
||||||
|
gs_workspace: str
|
||||||
|
status: str
|
||||||
|
project_role: str
|
||||||
|
geoserver: Optional[GeoServerConfigResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectSummaryResponse(BaseModel):
|
||||||
|
project_id: UUID
|
||||||
|
name: str
|
||||||
|
code: str
|
||||||
|
description: Optional[str]
|
||||||
|
gs_workspace: str
|
||||||
|
status: str
|
||||||
|
project_role: str
|
||||||
@@ -6,12 +6,17 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
from uuid import UUID
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from app.infra.db.postgresql.database import db as default_db
|
|
||||||
from app.core.audit import log_audit_event, AuditAction
|
from app.core.audit import log_audit_event, AuditAction
|
||||||
import logging
|
import logging
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.infra.db.metadata.database import SessionLocal
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
# 4. 提取审计所需信息
|
# 4. 提取审计所需信息
|
||||||
user_id = None
|
user_id = await self._resolve_user_id(request)
|
||||||
username = None
|
project_id = self._resolve_project_id(request)
|
||||||
|
|
||||||
# 尝试从请求状态获取当前用户
|
|
||||||
if hasattr(request.state, "user"):
|
|
||||||
user = request.state.user
|
|
||||||
user_id = getattr(user, "id", None)
|
|
||||||
username = getattr(user, "username", None)
|
|
||||||
|
|
||||||
# 获取客户端信息
|
# 获取客户端信息
|
||||||
ip_address = request.client.host if request.client else None
|
ip_address = request.client.host if request.client else None
|
||||||
user_agent = request.headers.get("user-agent")
|
|
||||||
|
|
||||||
# 确定操作类型
|
# 确定操作类型
|
||||||
action = self._determine_action(request)
|
action = self._determine_action(request)
|
||||||
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
await log_audit_event(
|
await log_audit_event(
|
||||||
action=action,
|
action=action,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
project_id=project_id,
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
ip_address=ip_address,
|
ip_address=ip_address,
|
||||||
user_agent=user_agent,
|
|
||||||
request_method=request.method,
|
request_method=request.method,
|
||||||
request_path=str(request.url.path),
|
request_path=str(request.url.path),
|
||||||
request_data=request_data,
|
request_data=request_data,
|
||||||
response_status=response.status_code,
|
response_status=response.status_code,
|
||||||
error_message=(
|
|
||||||
None
|
|
||||||
if response.status_code < 400
|
|
||||||
else f"HTTP {response.status_code}"
|
|
||||||
),
|
|
||||||
db=default_db,
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 审计失败不应影响响应
|
# 审计失败不应影响响应
|
||||||
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||||
|
project_header = request.headers.get("X-Project-Id")
|
||||||
|
if not project_header:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return UUID(project_header)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||||
|
auth_header = request.headers.get("authorization")
|
||||||
|
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||||
|
return None
|
||||||
|
token = auth_header.split(" ", 1)[1].strip()
|
||||||
|
if not token:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
key = (
|
||||||
|
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY
|
||||||
|
else settings.SECRET_KEY
|
||||||
|
)
|
||||||
|
algorithms = (
|
||||||
|
[settings.KEYCLOAK_ALGORITHM]
|
||||||
|
if settings.KEYCLOAK_PUBLIC_KEY
|
||||||
|
else [settings.ALGORITHM]
|
||||||
|
)
|
||||||
|
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||||
|
sub = payload.get("sub")
|
||||||
|
if not sub:
|
||||||
|
return None
|
||||||
|
keycloak_id = UUID(sub)
|
||||||
|
except (JWTError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async with SessionLocal() as session:
|
||||||
|
repo = MetadataRepository(session)
|
||||||
|
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||||
|
if user and user.is_active:
|
||||||
|
return user.id
|
||||||
|
return None
|
||||||
|
|
||||||
def _determine_action(self, request: Request) -> str:
|
def _determine_action(self, request: Request) -> str:
|
||||||
"""根据请求路径和方法确定操作类型"""
|
"""根据请求路径和方法确定操作类型"""
|
||||||
path = request.url.path.lower()
|
path = request.url.path.lower()
|
||||||
|
|||||||
211
app/infra/db/dynamic_manager.py
Normal file
211
app/infra/db/dynamic_manager.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
from psycopg.rows import dict_row
|
||||||
|
from sqlalchemy.engine.url import make_url
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncEngine,
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PgEngineEntry:
|
||||||
|
engine: AsyncEngine
|
||||||
|
sessionmaker: async_sessionmaker[AsyncSession]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CacheKey:
|
||||||
|
project_id: UUID
|
||||||
|
db_role: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectConnectionManager:
    """LRU-bounded caches of per-project database connectivity objects.

    Three independent caches are kept, each guarded by its own
    ``asyncio.Lock`` so creation/eviction in one does not block the others:

    * ``_pg_cache``     — SQLAlchemy async engines + sessionmakers (PostgreSQL)
    * ``_ts_cache``     — psycopg ``AsyncConnectionPool`` objects (TimescaleDB)
    * ``_pg_raw_cache`` — psycopg ``AsyncConnectionPool`` objects (raw PostgreSQL)

    ``OrderedDict`` gives LRU behaviour: ``move_to_end`` on a cache hit and
    ``popitem(last=False)`` when evicting the least-recently-used entry.
    """

    def __init__(self) -> None:
        self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
        self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_lock = asyncio.Lock()
        self._ts_lock = asyncio.Lock()
        self._pg_raw_lock = asyncio.Lock()

    def _normalize_pg_url(self, url: str) -> str:
        """Force the async psycopg driver on bare ``postgresql://`` URLs."""
        parsed = make_url(url)
        if parsed.drivername == "postgresql":
            parsed = parsed.set(drivername="postgresql+psycopg")
        return str(parsed)

    @staticmethod
    def _clamp_pool_sizes(pool_min_size: int, pool_max_size: int) -> tuple[int, int]:
        """Return sane pool bounds: min >= 1 and max >= min."""
        pool_min_size = max(1, pool_min_size)
        return pool_min_size, max(pool_min_size, pool_max_size)

    async def _open_psycopg_pool(
        self,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Create and open a psycopg pool that returns rows as dicts.

        Shared by the TimescaleDB and raw-PostgreSQL code paths, which were
        previously duplicated.
        """
        pool = AsyncConnectionPool(
            conninfo=connection_url,
            min_size=pool_min_size,
            max_size=pool_max_size,
            open=False,  # open explicitly below so failures surface here
            kwargs={"row_factory": dict_row},
        )
        await pool.open()
        return pool

    async def get_pg_sessionmaker(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> async_sessionmaker[AsyncSession]:
        """Return a cached sessionmaker for (project, role), creating it on miss."""
        async with self._pg_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            entry = self._pg_cache.get(key)
            if entry:
                self._pg_cache.move_to_end(key)
                return entry.sessionmaker

            normalized_url = self._normalize_pg_url(connection_url)
            pool_min_size, pool_max_size = self._clamp_pool_sizes(
                pool_min_size, pool_max_size
            )
            engine = create_async_engine(
                normalized_url,
                pool_size=pool_min_size,
                max_overflow=max(0, pool_max_size - pool_min_size),
                pool_pre_ping=True,
            )
            sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
            self._pg_cache[key] = PgEngineEntry(
                engine=engine,
                sessionmaker=sessionmaker,
            )
            await self._evict_pg_if_needed()
            logger.info(
                "Created PostgreSQL engine for project %s (%s)", project_id, db_role
            )
            return sessionmaker

    async def get_timescale_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return a cached TimescaleDB pool for (project, role), creating on miss."""
        async with self._ts_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            pool = self._ts_cache.get(key)
            if pool:
                self._ts_cache.move_to_end(key)
                return pool

            pool_min_size, pool_max_size = self._clamp_pool_sizes(
                pool_min_size, pool_max_size
            )
            pool = await self._open_psycopg_pool(
                connection_url, pool_min_size, pool_max_size
            )
            self._ts_cache[key] = pool
            await self._evict_ts_if_needed()
            logger.info(
                "Created TimescaleDB pool for project %s (%s)", project_id, db_role
            )
            return pool

    async def get_pg_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return a cached raw-PostgreSQL pool for (project, role), creating on miss."""
        async with self._pg_raw_lock:
            key = CacheKey(project_id=project_id, db_role=db_role)
            pool = self._pg_raw_cache.get(key)
            if pool:
                self._pg_raw_cache.move_to_end(key)
                return pool

            pool_min_size, pool_max_size = self._clamp_pool_sizes(
                pool_min_size, pool_max_size
            )
            pool = await self._open_psycopg_pool(
                connection_url, pool_min_size, pool_max_size
            )
            self._pg_raw_cache[key] = pool
            await self._evict_pg_raw_if_needed()
            logger.info(
                "Created PostgreSQL pool for project %s (%s)", project_id, db_role
            )
            return pool

    async def _evict_pg_if_needed(self) -> None:
        """Dispose least-recently-used engines beyond the configured cap."""
        while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
            key, entry = self._pg_cache.popitem(last=False)
            await entry.engine.dispose()
            logger.info(
                "Evicted PostgreSQL engine for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def _evict_ts_if_needed(self) -> None:
        """Close least-recently-used TimescaleDB pools beyond the configured cap."""
        while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
            key, pool = self._ts_cache.popitem(last=False)
            await pool.close()
            logger.info(
                "Evicted TimescaleDB pool for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def _evict_pg_raw_if_needed(self) -> None:
        """Close least-recently-used raw PostgreSQL pools beyond the configured cap."""
        while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
            key, pool = self._pg_raw_cache.popitem(last=False)
            await pool.close()
            logger.info(
                "Evicted PostgreSQL pool for project %s (%s)",
                key.project_id,
                key.db_role,
            )

    async def close_all(self) -> None:
        """Dispose every cached engine/pool; intended for application shutdown."""
        async with self._pg_lock:
            for key, entry in list(self._pg_cache.items()):
                await entry.engine.dispose()
                logger.info(
                    "Closed PostgreSQL engine for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._pg_cache.clear()

        async with self._ts_lock:
            for key, pool in list(self._ts_cache.items()):
                await pool.close()
                logger.info(
                    "Closed TimescaleDB pool for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._ts_cache.clear()

        async with self._pg_raw_lock:
            for key, pool in list(self._pg_raw_cache.items()):
                await pool.close()
                logger.info(
                    "Closed PostgreSQL pool for project %s (%s)",
                    key.project_id,
                    key.db_role,
                )
            self._pg_raw_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared by the application's request handlers.
project_connection_manager = ProjectConnectionManager()
|
||||||
3
app/infra/db/metadata/__init__.py
Normal file
3
app/infra/db/metadata/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Public API of the metadata package: re-export the session factory and
# shutdown helper from the database module.
from .database import get_metadata_session, close_metadata_engine

__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||||
27
app/infra/db/metadata/database.py
Normal file
27
app/infra/db/metadata/database.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import logging
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.core.config import settings

logger = logging.getLogger(__name__)

# Engine for the central metadata database. Pool sizing comes from
# application settings; pre-ping validates connections before use so stale
# ones are dropped transparently.
engine = create_async_engine(
    settings.METADATA_DATABASE_URI,
    pool_size=settings.METADATA_DB_POOL_SIZE,
    max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
    pool_pre_ping=True,
)

# expire_on_commit=False keeps ORM objects usable after commit in async code.
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a metadata-database session, closed automatically afterwards."""
    session = SessionLocal()
    async with session:
        yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def close_metadata_engine() -> None:
    """Dispose the shared metadata engine; call once on application shutdown."""
    await engine.dispose()
    logger.info("Metadata database engine disposed.")
|
||||||
115
app/infra/db/metadata/models.py
Normal file
115
app/infra/db/metadata/models.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from datetime import datetime, timezone
from uuid import UUID

from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
    """Declarative base shared by every metadata ORM model."""
    pass
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
    """Platform user record keyed to an external Keycloak identity."""

    __tablename__ = "users"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    # Unique + indexed so lookups by token subject are a single index probe.
    keycloak_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), unique=True, index=True
    )
    username: Mapped[str] = mapped_column(String(50), unique=True)
    email: Mapped[str] = mapped_column(String(100), unique=True)
    role: Mapped[str] = mapped_column(String(20), default="user")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
    # Free-form extra attributes; schema not enforced here.
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Timezone-aware UTC defaults: datetime.utcnow is deprecated (3.12) and
    # produces naive values for these timezone=True columns.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    last_login_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Project(Base):
    """Tenant project; each project owns one GeoServer workspace."""

    __tablename__ = "projects"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    name: Mapped[str] = mapped_column(String(100))
    # Short unique code used for lookups/URLs.
    code: Mapped[str] = mapped_column(String(50), unique=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
    status: Mapped[str] = mapped_column(String(20), default="active")
    # Timezone-aware UTC defaults: datetime.utcnow is deprecated (3.12) and
    # produces naive values for these timezone=True columns.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDatabase(Base):
    """Connection-routing record for one database role within a project."""

    __tablename__ = "project_databases"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    # Logical role of this database within the project (routing key).
    db_role: Mapped[str] = mapped_column(String(20))
    db_type: Mapped[str] = mapped_column(String(20))
    # DSN stored as ciphertext; the name indicates it must be decrypted
    # before use.
    dsn_encrypted: Mapped[str] = mapped_column(Text)
    pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
    pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
||||||
|
|
||||||
|
|
||||||
|
class ProjectGeoServerConfig(Base):
    """Per-project GeoServer connection and publishing defaults."""

    __tablename__ = "project_geoserver_configs"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    # One config row per project.
    project_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), unique=True, index=True
    )
    gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
    # Ciphertext; the name indicates decryption happens elsewhere.
    gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
        Text, nullable=True
    )
    gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
    default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    srid: Mapped[int] = mapped_column(Integer, default=4326)
    # Timezone-aware UTC default: datetime.utcnow is deprecated (3.12) and
    # produces naive values for this timezone=True column.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
|
||||||
|
|
||||||
|
|
||||||
|
class UserProjectMembership(Base):
    """Join table linking users to projects with a per-project role."""

    __tablename__ = "user_project_membership"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
    project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLog(Base):
    """Request-level audit trail entry."""

    __tablename__ = "audit_logs"

    id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
    user_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True), nullable=True, index=True
    )
    project_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True), nullable=True, index=True
    )
    action: Mapped[str] = mapped_column(String(50))
    resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
    resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
    # 45 chars fits a full textual IPv6 address.
    ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
    request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
    request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
    request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Timezone-aware UTC default: datetime.utcnow is deprecated (3.12) and
    # produces naive values for this timezone=True column.
    timestamp: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )
|
||||||
@@ -17,7 +17,14 @@ class Database:
|
|||||||
def init_pool(self, db_name=None):
|
def init_pool(self, db_name=None):
|
||||||
"""Initialize the connection pool."""
|
"""Initialize the connection pool."""
|
||||||
# Use provided db_name, or the one from constructor, or default from config
|
# Use provided db_name, or the one from constructor, or default from config
|
||||||
conn_string = postgresql_info.get_pgconn_string()
|
target_db_name = db_name or self.db_name
|
||||||
|
|
||||||
|
# Get connection string, handling default case where target_db_name might be None
|
||||||
|
if target_db_name:
|
||||||
|
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
|
||||||
|
else:
|
||||||
|
conn_string = postgresql_info.get_pgconn_string()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||||
conninfo=conn_string,
|
conninfo=conn_string,
|
||||||
@@ -26,7 +33,7 @@ class Database:
|
|||||||
open=False, # Don't open immediately, wait for startup
|
open=False, # Don't open immediately, wait for startup
|
||||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||||
)
|
)
|
||||||
logger.info(f"PostgreSQL connection pool initialized for database: default")
|
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,24 +1,18 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from typing import Optional
|
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
|
|
||||||
from .database import get_database_instance
|
|
||||||
from .scada_info import ScadaRepository
|
from .scada_info import ScadaRepository
|
||||||
from .scheme import SchemeRepository
|
from .scheme import SchemeRepository
|
||||||
|
from app.auth.project_dependencies import get_project_pg_connection
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# 创建支持数据库选择的连接依赖函数
|
# 动态项目 PostgreSQL 连接依赖
|
||||||
async def get_database_connection(
|
async def get_database_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
yield conn
|
||||||
instance = await get_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/scada-info")
|
@router.get("/scada-info")
|
||||||
|
|||||||
@@ -17,7 +17,14 @@ class Database:
|
|||||||
def init_pool(self, db_name=None):
|
def init_pool(self, db_name=None):
|
||||||
"""Initialize the connection pool."""
|
"""Initialize the connection pool."""
|
||||||
# Use provided db_name, or the one from constructor, or default from config
|
# Use provided db_name, or the one from constructor, or default from config
|
||||||
conn_string = timescaledb_info.get_pgconn_string()
|
target_db_name = db_name or self.db_name
|
||||||
|
|
||||||
|
# Get connection string, handling default case where target_db_name might be None
|
||||||
|
if target_db_name:
|
||||||
|
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||||
|
else:
|
||||||
|
conn_string = timescaledb_info.get_pgconn_string()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||||
conninfo=conn_string,
|
conninfo=conn_string,
|
||||||
@@ -27,7 +34,7 @@ class Database:
|
|||||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"TimescaleDB connection pool initialized for database: default"
|
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
||||||
@@ -46,7 +53,9 @@ class Database:
|
|||||||
def get_pgconn_string(self, db_name=None):
|
def get_pgconn_string(self, db_name=None):
|
||||||
"""Get the TimescaleDB connection string."""
|
"""Get the TimescaleDB connection string."""
|
||||||
target_db_name = db_name or self.db_name
|
target_db_name = db_name or self.db_name
|
||||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
if target_db_name:
|
||||||
|
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||||
|
return timescaledb_info.get_pgconn_string()
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_connection(self) -> AsyncGenerator:
|
async def get_connection(self) -> AsyncGenerator:
|
||||||
|
|||||||
@@ -1,42 +1,32 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
|
|
||||||
from .database import get_database_instance
|
|
||||||
from .schemas.realtime import RealtimeRepository
|
from .schemas.realtime import RealtimeRepository
|
||||||
from .schemas.scheme import SchemeRepository
|
from .schemas.scheme import SchemeRepository
|
||||||
from .schemas.scada import ScadaRepository
|
from .schemas.scada import ScadaRepository
|
||||||
from .composite_queries import CompositeQueries
|
from .composite_queries import CompositeQueries
|
||||||
from app.infra.db.postgresql.database import (
|
from app.auth.project_dependencies import (
|
||||||
get_database_instance as get_postgres_database_instance,
|
get_project_pg_connection,
|
||||||
|
get_project_timescale_connection,
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# 创建支持数据库选择的连接依赖函数
|
# 动态项目 TimescaleDB 连接依赖
|
||||||
async def get_database_connection(
|
async def get_database_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
yield conn
|
||||||
instance = await get_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
|
||||||
|
|
||||||
|
|
||||||
# PostgreSQL 数据库连接依赖函数
|
# 动态项目 PostgreSQL 连接依赖
|
||||||
async def get_postgres_connection(
|
async def get_postgres_connection(
|
||||||
db_name: Optional[str] = Query(
|
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
yield conn
|
||||||
instance = await get_postgres_database_instance(db_name)
|
|
||||||
async with instance.get_connection() as conn:
|
|
||||||
yield conn
|
|
||||||
|
|
||||||
|
|
||||||
# --- Realtime Endpoints ---
|
# --- Realtime Endpoints ---
|
||||||
|
|||||||
@@ -1,220 +1,112 @@
|
|||||||
from typing import Optional, List
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
from typing import Optional, List
|
||||||
from app.infra.db.postgresql.database import Database
|
from uuid import UUID
|
||||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
|
||||||
import logging
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.domain.schemas.audit import AuditLogResponse
|
||||||
|
from app.infra.db.metadata import models
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class AuditRepository:
|
class AuditRepository:
|
||||||
"""审计日志数据访问层"""
|
"""审计日志数据访问层(system_hub)"""
|
||||||
|
|
||||||
def __init__(self, db: Database):
|
def __init__(self, session: AsyncSession):
|
||||||
self.db = db
|
self.session = session
|
||||||
|
|
||||||
async def create_log(
|
async def create_log(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
action: str,
|
||||||
username: Optional[str] = None,
|
user_id: Optional[UUID] = None,
|
||||||
action: str = "",
|
project_id: Optional[UUID] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
resource_id: Optional[str] = None,
|
resource_id: Optional[str] = None,
|
||||||
ip_address: Optional[str] = None,
|
ip_address: Optional[str] = None,
|
||||||
user_agent: Optional[str] = None,
|
|
||||||
request_method: Optional[str] = None,
|
request_method: Optional[str] = None,
|
||||||
request_path: Optional[str] = None,
|
request_path: Optional[str] = None,
|
||||||
request_data: Optional[dict] = None,
|
request_data: Optional[dict] = None,
|
||||||
response_status: Optional[int] = None,
|
response_status: Optional[int] = None,
|
||||||
error_message: Optional[str] = None
|
) -> AuditLogResponse:
|
||||||
) -> Optional[AuditLogResponse]:
|
log = models.AuditLog(
|
||||||
"""
|
user_id=user_id,
|
||||||
创建审计日志
|
project_id=project_id,
|
||||||
|
action=action,
|
||||||
Args:
|
resource_type=resource_type,
|
||||||
参数说明见 AuditLogCreate
|
resource_id=resource_id,
|
||||||
|
ip_address=ip_address,
|
||||||
Returns:
|
request_method=request_method,
|
||||||
创建的审计日志对象
|
request_path=request_path,
|
||||||
"""
|
request_data=request_data,
|
||||||
query = """
|
response_status=response_status,
|
||||||
INSERT INTO audit_logs (
|
timestamp=datetime.utcnow(),
|
||||||
user_id, username, action, resource_type, resource_id,
|
)
|
||||||
ip_address, user_agent, request_method, request_path,
|
self.session.add(log)
|
||||||
request_data, response_status, error_message
|
await self.session.commit()
|
||||||
)
|
await self.session.refresh(log)
|
||||||
VALUES (
|
return AuditLogResponse.model_validate(log)
|
||||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
|
||||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
|
||||||
%(request_data)s, %(response_status)s, %(error_message)s
|
|
||||||
)
|
|
||||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
|
||||||
ip_address, user_agent, request_method, request_path,
|
|
||||||
request_data, response_status, error_message, timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, {
|
|
||||||
'user_id': user_id,
|
|
||||||
'username': username,
|
|
||||||
'action': action,
|
|
||||||
'resource_type': resource_type,
|
|
||||||
'resource_id': resource_id,
|
|
||||||
'ip_address': ip_address,
|
|
||||||
'user_agent': user_agent,
|
|
||||||
'request_method': request_method,
|
|
||||||
'request_path': request_path,
|
|
||||||
'request_data': json.dumps(request_data) if request_data else None,
|
|
||||||
'response_status': response_status,
|
|
||||||
'error_message': error_message
|
|
||||||
})
|
|
||||||
row = await cur.fetchone()
|
|
||||||
if row:
|
|
||||||
return AuditLogResponse(**row)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating audit log: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_logs(
|
async def get_logs(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
action: Optional[str] = None,
|
action: Optional[str] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
start_time: Optional[datetime] = None,
|
start_time: Optional[datetime] = None,
|
||||||
end_time: Optional[datetime] = None,
|
end_time: Optional[datetime] = None,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100
|
limit: int = 100,
|
||||||
) -> List[AuditLogResponse]:
|
) -> List[AuditLogResponse]:
|
||||||
"""
|
|
||||||
查询审计日志
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID过滤
|
|
||||||
username: 用户名过滤
|
|
||||||
action: 操作类型过滤
|
|
||||||
resource_type: 资源类型过滤
|
|
||||||
start_time: 开始时间
|
|
||||||
end_time: 结束时间
|
|
||||||
skip: 跳过记录数
|
|
||||||
limit: 限制记录数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
审计日志列表
|
|
||||||
"""
|
|
||||||
# 构建动态查询
|
|
||||||
conditions = []
|
conditions = []
|
||||||
params = {'skip': skip, 'limit': limit}
|
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
conditions.append("user_id = %(user_id)s")
|
conditions.append(models.AuditLog.user_id == user_id)
|
||||||
params['user_id'] = user_id
|
if project_id is not None:
|
||||||
|
conditions.append(models.AuditLog.project_id == project_id)
|
||||||
if username:
|
|
||||||
conditions.append("username = %(username)s")
|
|
||||||
params['username'] = username
|
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
conditions.append("action = %(action)s")
|
conditions.append(models.AuditLog.action == action)
|
||||||
params['action'] = action
|
|
||||||
|
|
||||||
if resource_type:
|
if resource_type:
|
||||||
conditions.append("resource_type = %(resource_type)s")
|
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||||
params['resource_type'] = resource_type
|
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
conditions.append("timestamp >= %(start_time)s")
|
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||||
params['start_time'] = start_time
|
|
||||||
|
|
||||||
if end_time:
|
if end_time:
|
||||||
conditions.append("timestamp <= %(end_time)s")
|
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||||
params['end_time'] = end_time
|
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
stmt = (
|
||||||
|
select(models.AuditLog)
|
||||||
query = f"""
|
.where(*conditions)
|
||||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
.order_by(models.AuditLog.timestamp.desc())
|
||||||
ip_address, user_agent, request_method, request_path,
|
.offset(skip)
|
||||||
request_data, response_status, error_message, timestamp
|
.limit(limit)
|
||||||
FROM audit_logs
|
)
|
||||||
{where_clause}
|
result = await self.session.execute(stmt)
|
||||||
ORDER BY timestamp DESC
|
return [
|
||||||
LIMIT %(limit)s OFFSET %(skip)s
|
AuditLogResponse.model_validate(log)
|
||||||
"""
|
for log in result.scalars().all()
|
||||||
|
]
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, params)
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
return [AuditLogResponse(**row) for row in rows]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error querying audit logs: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_log_count(
|
async def get_log_count(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[int] = None,
|
user_id: Optional[UUID] = None,
|
||||||
username: Optional[str] = None,
|
project_id: Optional[UUID] = None,
|
||||||
action: Optional[str] = None,
|
action: Optional[str] = None,
|
||||||
resource_type: Optional[str] = None,
|
resource_type: Optional[str] = None,
|
||||||
start_time: Optional[datetime] = None,
|
start_time: Optional[datetime] = None,
|
||||||
end_time: Optional[datetime] = None
|
end_time: Optional[datetime] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
|
||||||
获取审计日志数量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
参数同 get_logs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
日志总数
|
|
||||||
"""
|
|
||||||
conditions = []
|
conditions = []
|
||||||
params = {}
|
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
conditions.append("user_id = %(user_id)s")
|
conditions.append(models.AuditLog.user_id == user_id)
|
||||||
params['user_id'] = user_id
|
if project_id is not None:
|
||||||
|
conditions.append(models.AuditLog.project_id == project_id)
|
||||||
if username:
|
|
||||||
conditions.append("username = %(username)s")
|
|
||||||
params['username'] = username
|
|
||||||
|
|
||||||
if action:
|
if action:
|
||||||
conditions.append("action = %(action)s")
|
conditions.append(models.AuditLog.action == action)
|
||||||
params['action'] = action
|
|
||||||
|
|
||||||
if resource_type:
|
if resource_type:
|
||||||
conditions.append("resource_type = %(resource_type)s")
|
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||||
params['resource_type'] = resource_type
|
|
||||||
|
|
||||||
if start_time:
|
if start_time:
|
||||||
conditions.append("timestamp >= %(start_time)s")
|
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||||
params['start_time'] = start_time
|
|
||||||
|
|
||||||
if end_time:
|
if end_time:
|
||||||
conditions.append("timestamp <= %(end_time)s")
|
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||||
params['end_time'] = end_time
|
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||||
|
result = await self.session.execute(stmt)
|
||||||
query = f"""
|
return int(result.scalar() or 0)
|
||||||
SELECT COUNT(*) as count
|
|
||||||
FROM audit_logs
|
|
||||||
{where_clause}
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self.db.get_connection() as conn:
|
|
||||||
async with conn.cursor() as cur:
|
|
||||||
await cur.execute(query, params)
|
|
||||||
result = await cur.fetchone()
|
|
||||||
return result['count'] if result else 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error counting audit logs: {e}")
|
|
||||||
return 0
|
|
||||||
|
|||||||
197
app/infra/repositories/metadata_repository.py
Normal file
197
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from cryptography.fernet import InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.encryption import (
|
||||||
|
get_database_encryptor,
|
||||||
|
get_encryptor,
|
||||||
|
is_database_encryption_configured,
|
||||||
|
is_encryption_configured,
|
||||||
|
)
|
||||||
|
from app.infra.db.metadata import models
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_postgres_dsn(dsn: str) -> str:
|
||||||
|
if not dsn or "://" not in dsn:
|
||||||
|
return dsn
|
||||||
|
scheme, rest = dsn.split("://", 1)
|
||||||
|
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
|
||||||
|
return dsn
|
||||||
|
if "@" not in rest:
|
||||||
|
return dsn
|
||||||
|
userinfo, hostinfo = rest.rsplit("@", 1)
|
||||||
|
if ":" not in userinfo:
|
||||||
|
return dsn
|
||||||
|
username, password = userinfo.split(":", 1)
|
||||||
|
if "@" not in password:
|
||||||
|
return dsn
|
||||||
|
password = password.replace("@", "%40")
|
||||||
|
return f"{scheme}://{username}:{password}@{hostinfo}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ProjectDbRouting:
    """Connection routing info resolved for one (project, db_role) pair."""

    project_id: UUID
    db_role: str
    db_type: str
    # NOTE(review): this appears to hold the usable DSN while the model stores
    # dsn_encrypted — confirm decryption happens before construction.
    dsn: str
    pool_min_size: int
    pool_max_size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ProjectGeoServerInfo:
    """GeoServer connection/publishing settings resolved for one project."""

    project_id: UUID
    gs_base_url: Optional[str]
    gs_admin_user: Optional[str]
    # NOTE(review): plaintext here, while the model stores
    # gs_admin_password_encrypted — confirm decryption precedes construction.
    gs_admin_password: Optional[str]
    gs_datastore_name: str
    default_extent: Optional[dict]
    srid: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ProjectSummary:
    """Lightweight project listing entry plus the caller's role in it."""

    project_id: UUID  # project primary key
    name: str  # display name
    code: str  # short project code
    description: Optional[str]  # free-form description, may be absent
    gs_workspace: str  # GeoServer workspace bound to the project
    status: str  # project lifecycle status value
    project_role: str  # requesting user's role ("owner" in admin listings)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataRepository:
    """Read-only access layer for the system_hub metadata database.

    Wraps an ``AsyncSession`` and exposes lookups for users, projects,
    memberships, per-project database routing (with DSN decryption) and
    GeoServer configuration.  No method here commits the session.
    """

    def __init__(self, session: AsyncSession):
        # The session is owned and managed (commit/close) by the caller.
        self.session = session

    async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
        """Return the user mapped to the given Keycloak id, or None."""
        result = await self.session.execute(
            select(models.User).where(models.User.keycloak_id == keycloak_id)
        )
        return result.scalar_one_or_none()

    async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
        """Return the project with the given id, or None if absent."""
        result = await self.session.execute(
            select(models.Project).where(models.Project.id == project_id)
        )
        return result.scalar_one_or_none()

    async def get_membership_role(
        self, project_id: UUID, user_id: UUID
    ) -> Optional[str]:
        """Return the user's role within the project, or None if not a member."""
        result = await self.session.execute(
            select(models.UserProjectMembership.project_role).where(
                models.UserProjectMembership.project_id == project_id,
                models.UserProjectMembership.user_id == user_id,
            )
        )
        return result.scalar_one_or_none()

    async def get_project_db_routing(
        self, project_id: UUID, db_role: str
    ) -> Optional[ProjectDbRouting]:
        """Return decrypted DB routing for (project, role), or None if absent.

        Raises:
            ValueError: if DATABASE_ENCRYPTION_KEY is not configured, or if
                the stored ``dsn_encrypted`` cannot be decrypted with it.
        """
        result = await self.session.execute(
            select(models.ProjectDatabase).where(
                models.ProjectDatabase.project_id == project_id,
                models.ProjectDatabase.db_role == db_role,
            )
        )
        record = result.scalar_one_or_none()
        if not record:
            return None
        # DSNs are stored encrypted; fail fast without the key rather than
        # hand ciphertext back to callers as if it were a usable DSN.
        if not is_database_encryption_configured():
            raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
        encryptor = get_database_encryptor()
        try:
            dsn = encryptor.decrypt(record.dsn_encrypted)
        except InvalidToken:
            # Re-raise the cryptography-level error with a message naming
            # the two realistic causes (wrong key vs. bad stored value).
            raise ValueError(
                "Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
                "or invalid dsn_encrypted value"
            )
        # Percent-encode a raw '@' inside the password so URL parsing works.
        dsn = _normalize_postgres_dsn(dsn)
        return ProjectDbRouting(
            project_id=record.project_id,
            db_role=record.db_role,
            db_type=record.db_type,
            dsn=dsn,
            pool_min_size=record.pool_min_size,
            pool_max_size=record.pool_max_size,
        )

    async def get_geoserver_config(
        self, project_id: UUID
    ) -> Optional[ProjectGeoServerInfo]:
        """Return the project's GeoServer settings, or None if not configured.

        The admin password is decrypted when the app-level encryptor is
        configured; otherwise the stored value is returned unchanged.
        NOTE(review): that fallback silently treats the stored value as
        plaintext -- confirm this is intended rather than an error case.
        """
        result = await self.session.execute(
            select(models.ProjectGeoServerConfig).where(
                models.ProjectGeoServerConfig.project_id == project_id
            )
        )
        record = result.scalar_one_or_none()
        if not record:
            return None
        if record.gs_admin_password_encrypted:
            if is_encryption_configured():
                encryptor = get_encryptor()
                password = encryptor.decrypt(record.gs_admin_password_encrypted)
            else:
                # No encryptor configured: pass the stored value through.
                password = record.gs_admin_password_encrypted
        else:
            password = None
        return ProjectGeoServerInfo(
            project_id=record.project_id,
            gs_base_url=record.gs_base_url,
            gs_admin_user=record.gs_admin_user,
            gs_admin_password=password,
            gs_datastore_name=record.gs_datastore_name,
            default_extent=record.default_extent,
            srid=record.srid,
        )

    async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
        """Return summaries of all projects the user is a member of, by name."""
        stmt = (
            select(models.Project, models.UserProjectMembership.project_role)
            .join(
                models.UserProjectMembership,
                models.UserProjectMembership.project_id == models.Project.id,
            )
            .where(models.UserProjectMembership.user_id == user_id)
            .order_by(models.Project.name)
        )
        result = await self.session.execute(stmt)
        return [
            ProjectSummary(
                project_id=project.id,
                name=project.name,
                code=project.code,
                description=project.description,
                gs_workspace=project.gs_workspace,
                status=project.status,
                project_role=role,
            )
            for project, role in result.all()
        ]

    async def list_all_projects(self) -> List[ProjectSummary]:
        """Return summaries of every project, ordered by name.

        The role is hard-coded to "owner": this listing bypasses membership
        (intended for admin-level callers).
        """
        result = await self.session.execute(
            select(models.Project).order_by(models.Project.name)
        )
        return [
            ProjectSummary(
                project_id=project.id,
                name=project.name,
                code=project.code,
                description=project.description,
                gs_workspace=project.gs_workspace,
                status=project.status,
                project_role="owner",
            )
            for project in result.scalars().all()
        ]
|
||||||
35
app/main.py
35
app/main.py
@@ -9,6 +9,8 @@ import app.services.project_info as project_info
|
|||||||
from app.api.v1.router import api_router
|
from app.api.v1.router import api_router
|
||||||
from app.infra.db.timescaledb.database import db as tsdb
|
from app.infra.db.timescaledb.database import db as tsdb
|
||||||
from app.infra.db.postgresql.database import db as pgdb
|
from app.infra.db.postgresql.database import db as pgdb
|
||||||
|
from app.infra.db.dynamic_manager import project_connection_manager
|
||||||
|
from app.infra.db.metadata.database import close_metadata_engine
|
||||||
from app.services.tjnetwork import open_project
|
from app.services.tjnetwork import open_project
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
|
|||||||
# 清理资源
|
# 清理资源
|
||||||
await tsdb.close()
|
await tsdb.close()
|
||||||
await pgdb.close()
|
await pgdb.close()
|
||||||
|
await project_connection_manager.close_all()
|
||||||
|
await close_metadata_engine()
|
||||||
logger.info("Database connections closed")
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
|
||||||
@@ -58,22 +62,25 @@ app = FastAPI(
|
|||||||
redoc_url="/redoc",
|
redoc_url="/redoc",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 配置 CORS 中间件
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # 允许所有来源
|
|
||||||
allow_credentials=True, # 允许传递凭证(Cookie、HTTP 头等)
|
|
||||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
|
||||||
allow_headers=["*"], # 允许所有 HTTP 头
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
|
||||||
|
|
||||||
# 添加审计中间件(可选,记录关键操作)
|
|
||||||
# 如果需要启用审计日志,取消下面的注释
|
|
||||||
app.add_middleware(AuditMiddleware)
|
|
||||||
|
|
||||||
# Include Routers
|
# Include Routers
|
||||||
app.include_router(api_router, prefix="/api/v1")
|
app.include_router(api_router, prefix="/api/v1")
|
||||||
# Legacy Routers without version prefix
|
# Legacy Routers without version prefix
|
||||||
app.include_router(api_router)
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
# 配置中间件
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
# 添加审计中间件(可选,记录关键操作)
|
||||||
|
app.add_middleware(AuditMiddleware)
|
||||||
|
# 配置 CORS 中间件
|
||||||
|
# 确保这是你最后一个添加的 app.add_middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=[
|
||||||
|
"http://localhost:3000", # 必须明确指定
|
||||||
|
"http://127.0.0.1:3000", # 建议同时加上这个
|
||||||
|
],
|
||||||
|
allow_credentials=True, # 既然这里是 True
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,22 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
|
||||||
|
|
||||||
# 获取当前项目根目录的路径
|
# 从环境变量 NETWORK_NAME 读取
|
||||||
_current_file = os.path.abspath(__file__)
|
name = os.getenv("NETWORK_NAME")
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(_current_file)))
|
|
||||||
|
|
||||||
# 尝试读取 .yml 或 .yaml 文件
|
|
||||||
config_file = os.path.join(project_root, "configs", "project_info.yml")
|
|
||||||
if not os.path.exists(config_file):
|
|
||||||
config_file = os.path.join(project_root, "configs", "project_info.yaml")
|
|
||||||
|
|
||||||
if not os.path.exists(config_file):
|
|
||||||
raise FileNotFoundError(f"未找到项目配置文件 (project_info.yaml 或 .yml): {os.path.dirname(config_file)}")
|
|
||||||
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
|
||||||
_config = yaml.safe_load(f)
|
|
||||||
|
|
||||||
if not _config or 'name' not in _config:
|
|
||||||
raise KeyError(f"项目配置文件中缺少 'name' 配置: {config_file}")
|
|
||||||
|
|
||||||
name = _config['name']
|
|
||||||
|
|||||||
@@ -1190,12 +1190,12 @@ def run_simulation(
|
|||||||
if modify_valve_opening[valve_name] == 0:
|
if modify_valve_opening[valve_name] == 0:
|
||||||
valve_status["status"] = "CLOSED"
|
valve_status["status"] = "CLOSED"
|
||||||
valve_status["setting"] = 0
|
valve_status["setting"] = 0
|
||||||
if modify_valve_opening[valve_name] < 1:
|
elif modify_valve_opening[valve_name] < 1:
|
||||||
valve_status["status"] = "OPEN"
|
valve_status["status"] = "OPEN"
|
||||||
valve_status["setting"] = 0.1036 * pow(
|
valve_status["setting"] = 0.1036 * pow(
|
||||||
modify_valve_opening[valve_name], -3.105
|
modify_valve_opening[valve_name], -3.105
|
||||||
)
|
)
|
||||||
if modify_valve_opening[valve_name] == 1:
|
elif modify_valve_opening[valve_name] == 1:
|
||||||
valve_status["status"] = "OPEN"
|
valve_status["status"] = "OPEN"
|
||||||
valve_status["setting"] = 0
|
valve_status["setting"] = 0
|
||||||
cs = ChangeSet()
|
cs = ChangeSet()
|
||||||
@@ -1235,7 +1235,7 @@ def run_simulation(
|
|||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
if simulation_type.upper() == "REALTIME":
|
if simulation_type.upper() == "REALTIME":
|
||||||
TimescaleInternalStorage.store_realtime_simulation(
|
TimescaleInternalStorage.store_realtime_simulation(
|
||||||
node_result, link_result, modify_pattern_start_time
|
node_result, link_result, modify_pattern_start_time, db_name=name
|
||||||
)
|
)
|
||||||
elif simulation_type.upper() == "EXTENDED":
|
elif simulation_type.upper() == "EXTENDED":
|
||||||
TimescaleInternalStorage.store_scheme_simulation(
|
TimescaleInternalStorage.store_scheme_simulation(
|
||||||
@@ -1245,6 +1245,7 @@ def run_simulation(
|
|||||||
link_result,
|
link_result,
|
||||||
modify_pattern_start_time,
|
modify_pattern_start_time,
|
||||||
num_periods_result,
|
num_periods_result,
|
||||||
|
db_name=name,
|
||||||
)
|
)
|
||||||
endtime = time.time()
|
endtime = time.time()
|
||||||
logging.info("store time: %f", endtime - starttime)
|
logging.info("store time: %f", endtime - starttime)
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
name: szh
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
FROM python:3.12-slim
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY requirements.txt .
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
COPY app ./app
|
|
||||||
COPY resources ./resources
|
|
||||||
|
|
||||||
ENV PYTHONPATH=/app
|
|
||||||
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
||||||
@@ -95,7 +95,6 @@ prometheus_client==0.24.1
|
|||||||
psycopg==3.2.5
|
psycopg==3.2.5
|
||||||
psycopg-binary==3.2.5
|
psycopg-binary==3.2.5
|
||||||
psycopg-pool==3.3.0
|
psycopg-pool==3.3.0
|
||||||
psycopg2==2.9.10
|
|
||||||
PuLP==3.1.1
|
PuLP==3.1.1
|
||||||
py-key-value-aio==0.3.0
|
py-key-value-aio==0.3.0
|
||||||
py-key-value-shared==0.3.0
|
py-key-value-shared==0.3.0
|
||||||
@@ -157,8 +156,6 @@ starlette==0.50.0
|
|||||||
threadpoolctl==3.6.0
|
threadpoolctl==3.6.0
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
typer==0.21.1
|
typer==0.21.1
|
||||||
typing-inspection==0.4.0
|
|
||||||
typing_extensions==4.12.2
|
|
||||||
tzdata==2025.2
|
tzdata==2025.2
|
||||||
urllib3==2.2.3
|
urllib3==2.2.3
|
||||||
uvicorn==0.34.0
|
uvicorn==0.34.0
|
||||||
|
|||||||
33
scripts/encrypt_string.py
Normal file
33
scripts/encrypt_string.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 将项目根目录添加到 python 路径
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from app.core.encryption import get_database_encryptor
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Obtain a plaintext string and print its encrypted token.

    Input sources, in priority order: piped stdin, first CLI argument,
    interactive prompt.  Returns a process exit code (0 ok, 1 empty input).
    """
    text = None

    # 1) Piped stdin wins (trailing newline stripped).
    if not sys.stdin.isatty():
        piped = sys.stdin.read()
        if piped:
            text = piped.rstrip("\r\n")

    # 2) Fall back to the first command-line argument.
    if text is None and len(sys.argv) >= 2:
        text = sys.argv[1]

    # 3) Finally, prompt interactively; EOF yields an empty string.
    if text is None:
        try:
            text = input("请输入要加密的文本: ")
        except EOFError:
            text = ""

    # Reject blank/whitespace-only input.
    if not text.strip():
        print("Error: plaintext string cannot be empty.", file=sys.stderr)
        return 1

    print(get_database_encryptor().encrypt(text))
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -16,6 +16,6 @@ if __name__ == "__main__":
|
|||||||
"app.main:app",
|
"app.main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
port=8000,
|
port=8000,
|
||||||
workers=2, # 这里可以设置多进程
|
workers=4, # 这里可以设置多进程
|
||||||
loop="asyncio",
|
loop="asyncio",
|
||||||
)
|
)
|
||||||
|
|||||||
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cryptography.fernet import InvalidToken
|
||||||
|
|
||||||
|
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyResult:
|
||||||
|
def __init__(self, record):
|
||||||
|
self._record = record
|
||||||
|
|
||||||
|
def scalar_one_or_none(self):
|
||||||
|
return self._record
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyEncryptor:
|
||||||
|
def __init__(self, decrypted=None, raise_invalid_token=False):
|
||||||
|
self._decrypted = decrypted
|
||||||
|
self._raise_invalid_token = raise_invalid_token
|
||||||
|
self.encrypted_values = []
|
||||||
|
|
||||||
|
def decrypt(self, _value):
|
||||||
|
if self._raise_invalid_token:
|
||||||
|
raise InvalidToken()
|
||||||
|
return self._decrypted
|
||||||
|
|
||||||
|
def _build_record(dsn_encrypted: str):
|
||||||
|
return SimpleNamespace(
|
||||||
|
project_id=uuid4(),
|
||||||
|
db_role="biz_data",
|
||||||
|
db_type="postgresql",
|
||||||
|
dsn_encrypted=dsn_encrypted,
|
||||||
|
pool_min_size=1,
|
||||||
|
pool_max_size=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_token_with_plaintext_dsn_value_raises_clear_error(monkeypatch):
    """A plaintext DSN stored where ciphertext is expected -> clear ValueError."""
    record = _build_record("postgresql://user:p@ss@localhost:5432/db")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    repo = MetadataRepository(session)
    encryptor = _DummyEncryptor(raise_invalid_token=True)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: encryptor,
    )

    with pytest.raises(
        ValueError,
        match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
    ):
        asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
    # No migration/write-back must have happened on the failure path.
    session.commit.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_token_with_non_dsn_value_raises_clear_error(monkeypatch):
    """An undecryptable non-DSN value -> same clear ValueError."""
    record = _build_record("gAAAAABinvalidciphertext")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    repo = MetadataRepository(session)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: _DummyEncryptor(raise_invalid_token=True),
    )

    with pytest.raises(
        ValueError,
        match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
    ):
        asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
    # No migration/write-back must have happened on the failure path.
    session.commit.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encrypted_dsn_decrypts_without_migration(monkeypatch):
    """A decryptable DSN is returned normalized, with no write-back commit."""
    record = _build_record("encrypted-value")
    session = SimpleNamespace(
        execute=AsyncMock(return_value=_DummyResult(record)),
        commit=AsyncMock(),
    )
    repo = MetadataRepository(session)

    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.is_database_encryption_configured",
        lambda: True,
    )
    monkeypatch.setattr(
        "app.infra.repositories.metadata_repository.get_database_encryptor",
        lambda: _DummyEncryptor(decrypted="postgresql://u:p@ss@host/db"),
    )

    routing = asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))

    # The raw '@' in the password must come back percent-encoded.
    assert routing.dsn == "postgresql://u:p%40ss@host/db"
    session.commit.assert_not_awaited()
|
||||||
Reference in New Issue
Block a user