Compare commits


26 Commits

SHA1 Message Date
80b6970970 Add unit tests for database encryption handling 2026-02-25 16:54:14 +08:00
364a8c8ec2 Add string-encryption script to support text encryption 2026-02-25 16:54:09 +08:00
52ccb8abf1 Implement database connection-string encryption 2026-02-25 16:36:53 +08:00
0bc4058f23 Update encryptor to read keys from environment variables or config 2026-02-24 17:03:25 +08:00
0d3e6ca4fa Refactor middleware ordering and add database connection logging 2026-02-24 17:03:06 +08:00
6fc3aa5209 Add logging and exception handling to improve error management 2026-02-24 17:02:56 +08:00
1b1b0a3697 Add row_factory parameter to support dict-style rows 2026-02-24 17:02:48 +08:00
2826999ddc Fix database connection URL when the password contains "@" 2026-02-24 17:01:39 +08:00
efc05f7278 Add KEYCLOAK_AUDIENCE to fix frontend/backend authentication failures 2026-02-24 15:15:13 +08:00
29209f5c63 Update .gitignore 2026-02-24 10:46:33 +08:00
020432ad0e Remove the AUTH_DISABLED parameter 2026-02-24 10:45:53 +08:00
780a48d927 Refactor database connection management; add metadata support 2026-02-11 18:57:47 +08:00
ff2011ae24 Update agent instructions 2026-02-11 11:00:55 +08:00
f5069a5606 Unify new-database connections under the openproject API 2026-02-11 11:00:44 +08:00
eb45e4aaa5 Adjust code to support project switching and opening connections to different databases 2026-02-11 10:42:40 +08:00
a472639b8a Add Dockerfile; adjust parameter-format validation in simulations 2026-02-10 15:25:03 +08:00
a0987105dc Adjust environment variable configuration for Docker packaging 2026-02-09 15:31:21 +08:00
a41be9c362 Add a new pattern for emitter_demand and use it to simulate pipe flushing 2026-02-06 18:24:15 +08:00
63b31b46b9 Fix flow-unit bug in the pipe flushing algorithm 2026-02-06 17:46:56 +08:00
e4f864a28c Update accepted parameter format for burst analysis 2026-02-06 16:59:46 +08:00
dc38313cdc Fix scheme computed properties not displaying 2026-02-06 11:32:47 +08:00
f19962510a Add scheme_name parameter to flushing_analysis 2026-02-05 16:13:41 +08:00
6434cae21c Unify scheme_type naming 2026-02-05 15:39:56 +08:00
a85ff8e215 Add Copilot project description file 2026-02-05 10:47:54 +08:00
2794114000 Unify scheme_name naming convention 2026-02-05 10:47:38 +08:00
4c208abe55 Optimize valve-closure analysis: cache network topology and incremental graph handling 2026-02-05 10:46:46 +08:00
46 changed files with 2175 additions and 602 deletions


@@ -1,6 +1,6 @@
# TJWater Server environment variable configuration template
# Copy this file to .env and fill in the actual values
NETWORK_NAME="szh"
# ============================================
# Security settings (required)
# ============================================
@@ -12,24 +12,44 @@ SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
# Data encryption key - used to encrypt sensitive data
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
ENCRYPTION_KEY=
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
# ============================================
# Database settings (PostgreSQL)
# ============================================
DB_NAME=tjwater
DB_HOST=localhost
DB_PORT=5432
DB_USER=postgres
DB_PASSWORD=password
DB_NAME="tjwater"
DB_HOST="localhost"
DB_PORT="5432"
DB_USER="postgres"
DB_PASSWORD="password"
# ============================================
# Database settings (TimescaleDB)
# ============================================
TIMESCALEDB_DB_NAME=szh
TIMESCALEDB_DB_HOST=localhost
TIMESCALEDB_DB_PORT=5433
TIMESCALEDB_DB_USER=tjwater
TIMESCALEDB_DB_PASSWORD=Tjwater@123456
TIMESCALEDB_DB_NAME="szh"
TIMESCALEDB_DB_HOST="localhost"
TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
# ============================================
# Metadata database settings (Metadata DB)
# ============================================
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="localhost"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="password"
# ============================================
# Project connection cache and pool settings
# ============================================
PROJECT_PG_CACHE_SIZE=50
PROJECT_TS_CACHE_SIZE=50
PROJECT_PG_POOL_SIZE=5
PROJECT_PG_MAX_OVERFLOW=10
PROJECT_TS_POOL_MIN_SIZE=1
PROJECT_TS_POOL_MAX_SIZE=10
# ============================================
# InfluxDB settings (time-series data)
@@ -46,6 +66,12 @@ TIMESCALEDB_DB_PASSWORD=Tjwater@123456
# REFRESH_TOKEN_EXPIRE_DAYS=7
# ALGORITHM=HS256
# ============================================
# Keycloak JWT (optional)
# ============================================
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
# KEYCLOAK_ALGORITHM=RS256
# ============================================
# Other settings
# ============================================

.env.local (new file, +23)

@@ -0,0 +1,23 @@
NETWORK_NAME="tjwater"
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
KEYCLOAK_ALGORITHM="RS256"
KEYCLOAK_AUDIENCE="account"
DB_NAME="tjwater"
DB_HOST="192.168.1.114"
DB_PORT="5432"
DB_USER="tjwater"
DB_PASSWORD="Tjwater@123456"
TIMESCALEDB_DB_NAME="tjwater"
TIMESCALEDB_DB_HOST="192.168.1.114"
TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="192.168.1.114"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="Tjwater@123456"
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="

.github/copilot-instructions.md (new file, vendored, +198)

@@ -0,0 +1,198 @@
# TJWater Server - Copilot Instructions
This is a FastAPI-based water network management system (供水管网智能管理系统) that provides hydraulic simulation, SCADA data integration, network element management, and risk analysis capabilities.
## Running the Server
```bash
# Install dependencies
pip install -r requirements.txt
# Start the server (default: http://0.0.0.0:8000 with 2 workers)
python scripts/run_server.py
# Note: On Windows, the script automatically sets WindowsSelectorEventLoopPolicy
```
## Running Tests
```bash
# Run all tests
pytest
# Run a specific test file with verbose output
pytest tests/unit/test_pipeline_health_analyzer.py -v
# Run from conftest helper
python tests/conftest.py
```
## Architecture Overview
### Core Components
1. **Native Modules** (`app/native/`): Platform-specific compiled extensions (`.so` for Linux, `.pyd` for Windows) providing performance-critical functionality including:
- SCADA device integration
- Water distribution analysis (WDA)
- Pipe risk probability calculations
- Wrapped through `app.services.tjnetwork` interface
2. **Services Layer** (`app/services/`):
- `tjnetwork.py`: Main network API wrapper around native modules
- `simulation.py`: Hydraulic simulation orchestration (EPANET integration)
- `project_info.py`: Project configuration management
- `epanet/`: EPANET hydraulic engine integration
3. **API Layer** (`app/api/v1/`):
- **Network Elements**: Separate endpoint modules for junctions, reservoirs, tanks, pipes, pumps, valves
- **Components**: Curves, patterns, controls, options, quality, visuals
- **Network Features**: Tags, demands, geometry, regions/DMAs
- **Core Services**: Auth, project, simulation, SCADA, data query, snapshots
4. **Database Infrastructure** (`app/infra/db/`):
- **PostgreSQL**: Primary relational database (users, audit logs, project metadata)
- **TimescaleDB**: Time-series extension for historical data
- **InfluxDB**: Optional time-series database for high-frequency SCADA data
- Connection pools initialized in `main.py` lifespan context
- Database instance stored in `app.state.db` for dependency injection
5. **Domain Layer** (`app/domain/`):
- `models/`: Enums and domain objects (e.g., `UserRole`)
- `schemas/`: Pydantic models for request/response validation
6. **Algorithms** (`app/algorithms/`):
- `api_ex/`: Analysis algorithms (k-means sensor placement, sensitivity analysis, pipeline health)
- `data_cleaning.py`: Data preprocessing utilities
- `simulations.py`: Simulation helpers
### Security & Authentication
- **Authentication**: JWT-based with access tokens (30 min) and refresh tokens (7 days)
- **Authorization**: Role-based access control (RBAC) with 4 roles:
- `VIEWER`: Read-only access
- `USER`: Read-write access
- `OPERATOR`: Modify data
- `ADMIN`: Full permissions
- **Audit Logging**: `AuditMiddleware` automatically logs POST/PUT/DELETE requests
- **Encryption**: Fernet symmetric encryption for sensitive data (`app.core.encryption`)
Default admin accounts:
- `admin` / `admin123`
- `tjwater` / `tjwater@123`
### Key Files
- `app/main.py`: FastAPI app initialization, lifespan (DB pools), CORS, middleware, router mounting
- `app/api/v1/router.py`: Central router aggregating all endpoint modules
- `app/core/config.py`: Settings management using `pydantic-settings`
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
## Important Conventions
### Database Connections
- Database instances are initialized in `main.py` lifespan and stored in `app.state.db`
- Access via dependency injection:
```python
from fastapi import Depends
from app.auth.dependencies import get_db

async def endpoint(db = Depends(get_db)):
    ...  # use the injected db connection
```
### Authentication in Endpoints
Use dependency injection for auth requirements:
```python
from fastapi import APIRouter, Depends
from app.auth.dependencies import get_current_active_user
from app.auth.permissions import require_role, get_current_admin
from app.domain.models.role import UserRole

router = APIRouter()
# Require any authenticated user
@router.get("/data")
async def get_data(current_user = Depends(get_current_active_user)):
return data
# Require specific role (USER or higher)
@router.post("/data")
async def create_data(current_user = Depends(require_role(UserRole.USER))):
return result
# Admin-only access
@router.delete("/data/{id}")
async def delete_data(id: int, current_user = Depends(get_current_admin)):
return result
```
### API Routing Structure
- All v1 APIs are mounted under `/api/v1` prefix via `api_router`
- Legacy routes without version prefix are also mounted for backward compatibility
- Group related endpoints in separate router modules under `app/api/v1/endpoints/`
- Use descriptive tags in `router.py` for OpenAPI documentation grouping
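A rough sketch of these conventions (the module name, route, and tag here are illustrative, not taken from the repo):
```python
# app/api/v1/endpoints/network/example.py -- hypothetical endpoint module
from fastapi import APIRouter

router = APIRouter()

@router.get("/example/")
async def list_example(network: str):
    ...

# In app/api/v1/router.py, register it the same way as the existing modules:
# api_router.include_router(example.router, tags=["Example"])
```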
### Native Module Integration
- Native modules are pre-compiled for specific platforms
- Always import through `app.native.api` or `app.services.tjnetwork`
- The `tjnetwork` service wraps native APIs with constants like:
- Element types: `JUNCTION`, `RESERVOIR`, `TANK`, `PIPE`, `PUMP`, `VALVE`
- Operations: `API_ADD`, `API_UPDATE`, `API_DELETE`
- `ChangeSet` for batch operations
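A minimal sketch of the batch pattern. The call shapes mirror the `ChangeSet`/`add_pattern` usage in this compare's diffs; the network name and `add_pattern`'s home module are assumptions (the diffs call it without showing the import):
```python
from app.services.tjnetwork import ChangeSet, add_pattern  # assumed export location

cs = ChangeSet()
cs.append({"id": "flushing_pt", "factors": [1.0, 1.0, 1.0]})  # one change per dict
add_pattern("szh", cs)  # apply the batch to the named project
```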
### Project Initialization
- On startup, `main.py` automatically loads project from `project_info.name` if set
- Projects are opened via `open_project(name)` from `tjnetwork` service
### Audit Logging
Manual audit logging for critical operations:
```python
from app.core.audit import log_audit_event, AuditAction
await log_audit_event(
    action=AuditAction.UPDATE,
    user_id=current_user.id,
    resource_type="resource_name",
    resource_id=str(resource_id),
    ip_address=request.client.host,
    request_data=data,
)
```
### Environment Configuration
- Copy `.env.example` to `.env` before first run
- Required environment variables:
- `SECRET_KEY`: JWT signing (generate with `openssl rand -hex 32`)
- `ENCRYPTION_KEY`: Data encryption (generate with Fernet)
- Database credentials for PostgreSQL, TimescaleDB, and optionally InfluxDB
### Database Migrations
SQL migration scripts are in `migrations/`:
- `001_create_users_table.sql`: User authentication tables
- `002_create_audit_logs_table.sql`: Audit logging tables
Apply with:
```bash
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
```
## API Documentation
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
- OpenAPI schema: http://localhost:8000/openapi.json
## Additional Resources
- `SECURITY_README.md`: Comprehensive security feature documentation
- `DEPLOYMENT.md`: Integration guide for security features
- `readme.md`: Project overview and directory structure (in Chinese)

.gitignore (vendored, +1)

@@ -6,3 +6,4 @@ build/
.env
*.dump
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
.vscode/

Dockerfile (new file, +24)

@@ -0,0 +1,24 @@
FROM continuumio/miniconda3:latest
WORKDIR /app
# Install Python 3.12 and pymetis (via conda-forge to avoid build-from-source issues)
RUN conda install -y -c conda-forge python=3.12 pymetis && \
conda clean -afy
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Put the code in subdirectory 'app' and the data in subdirectory 'db_inp'
# Temp files are then created under /app by default while the code lives in /app/app, keeping them separate
COPY app ./app
COPY db_inp ./db_inp
COPY temp ./temp
COPY .env .
# Set PYTHONPATH so uvicorn can find the app module
ENV PYTHONPATH=/app
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]


@@ -103,6 +103,7 @@ def burst_analysis(
if isinstance(burst_ID, list):
if (burst_size is not None) and (type(burst_size) is not list):
return json.dumps("Type mismatch.")
# Normalize to list form
elif isinstance(burst_ID, str):
burst_ID = [burst_ID]
if burst_size is not None:
@@ -199,7 +200,7 @@ def valve_close_analysis(
modify_pattern_start_time: str,
modify_total_duration: int = 900,
modify_valve_opening: dict[str, float] = None,
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Valve-closure simulation
@@ -207,7 +208,7 @@ def valve_close_analysis(
:param modify_pattern_start_time: simulation start time, in the format '2024-11-25T09:00:00+08:00'
:param modify_total_duration: total simulation duration, in seconds
:param modify_valve_opening: dict of valve openings; the str key is the valve ID, the float value is the modified opening
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
print(
@@ -261,8 +262,8 @@ def valve_close_analysis(
modify_pattern_start_time=modify_pattern_start_time,
modify_total_duration=modify_total_duration,
modify_valve_opening=modify_valve_opening,
scheme_Type="valve_close_Analysis",
scheme_Name=scheme_Name,
scheme_type="valve_close_Analysis",
scheme_name=scheme_name,
)
# step 3. restore the base model
# for valve in valves:
@@ -284,7 +285,7 @@ def flushing_analysis(
modify_valve_opening: dict[str, float] = None,
drainage_node_ID: str = None,
flushing_flow: float = 0,
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Pipe flushing simulation
@@ -294,9 +295,15 @@ def flushing_analysis(
:param modify_valve_opening: dict of valve openings; the str key is the valve ID, the float value is the modified opening
:param drainage_node_ID: ID of the node hosting the flushing outlet
:param flushing_flow: flushing flow rate; the input unit is m3/h
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
scheme_detail: dict = {
"duration": modify_total_duration,
"valve_opening": modify_valve_opening,
"drainage_node_ID": drainage_node_ID,
"flushing_flow": flushing_flow,
}
print(
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
+ " -- Start Analysis."
@@ -338,18 +345,42 @@ def flushing_analysis(
# status['setting'] = 0.1036 * pow(valve_k, -3.105)
# cs.append(status)
# set_status(new_name,cs)
units = get_option(new_name)
options = get_option(new_name)
units = options["UNITS"]
# step 2. set the emitter coefficient of drainage node or add flush flow to the drainage node
# Create a new pattern
time_option = get_time(new_name)
hydraulic_step = time_option["HYDRAULIC TIMESTEP"]
secs = from_clock_to_seconds_2(hydraulic_step)
cs_pattern = ChangeSet()
pt = {}
factors = []
tmp_duration = modify_total_duration
while tmp_duration > 0:
factors.append(1.0)
tmp_duration = tmp_duration - secs
pt["id"] = "flushing_pt"
pt["factors"] = factors
cs_pattern.append(pt)
add_pattern(new_name, cs_pattern)
# Add the new pattern for emitter_demand
emitter_demand = get_demand(new_name, drainage_node_ID)
cs = ChangeSet()
if flushing_flow > 0:
for r in emitter_demand["demands"]:
if units == "LPS":
r["demand"] += flushing_flow / 3.6
elif units == "CMH":
r["demand"] += flushing_flow
cs.append(emitter_demand)
set_demand(new_name, cs)
if units == "LPS":
emitter_demand["demands"].append(
{
"demand": flushing_flow / 3.6,
"pattern": "flushing_pt",
"category": None,
}
)
elif units == "CMH":
emitter_demand["demands"].append(
{"demand": flushing_flow, "pattern": "flushing_pt", "category": None}
)
cs.append(emitter_demand)
set_demand(new_name, cs)
else:
pipes = get_node_links(new_name, drainage_node_ID)
flush_diameter = 50
@@ -386,14 +417,23 @@ def flushing_analysis(
modify_pattern_start_time=modify_pattern_start_time,
modify_total_duration=modify_total_duration,
modify_valve_opening=modify_valve_opening,
scheme_Type="flushing_Analysis",
scheme_Name=scheme_Name,
scheme_type="flushing_analysis",
scheme_name=scheme_name,
)
# step 4. restore the base model
if is_project_open(new_name):
close_project(new_name)
delete_project(new_name)
# return result
# Store the scheme info in the PG database
store_scheme_info(
name=name,
scheme_name=scheme_name,
scheme_type="flushing_analysis",
username="admin",
scheme_start_time=modify_pattern_start_time,
scheme_detail=scheme_detail,
)
############################################################
@@ -415,10 +455,10 @@ def contaminant_simulation(
:param modify_pattern_start_time: simulation start time, in the format '2024-11-25T09:00:00+08:00'
:param modify_total_duration: total simulation duration, in seconds
:param source: ID of the node where the contaminant source is located
:param concentration: concentration at the source, in mg/L; the default simulation source setting is concentration and should be changed to Set point booster
:param concentration: concentration at the source, in mg/L; the default source setting SOURCE_TYPE_CONCEN is changed to SOURCE_TYPE_SETPOINT
:param source_pattern: time pattern of the contaminant source; if omitted, a constant concentration is held for a length equal to duration;
if provided, it is a list of factors such as {1.0, 0.5, 1.1}, with the pattern step equal to the model's hydraulic time step
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
scheme_detail: dict = {
@@ -469,7 +509,7 @@ def contaminant_simulation(
# step 2. set pattern
if source_pattern != None:
pt = get_pattern(new_name, source_pattern)
if pt == None:
if len(pt) == 0:
str_response = str("cant find source_pattern")
return str_response
else:
@@ -490,7 +530,7 @@ def contaminant_simulation(
cs_source = ChangeSet()
source_schema = {
"node": source,
"s_type": SOURCE_TYPE_CONCEN,
"s_type": SOURCE_TYPE_SETPOINT,
"strength": concentration,
"pattern": pt["id"],
}
@@ -634,7 +674,7 @@ def pressure_regulation(
modify_tank_initial_level: dict[str, float] = None,
modify_fixed_pump_pattern: dict[str, list] = None,
modify_variable_pump_pattern: dict[str, list] = None,
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Regional pressure-regulation simulation: models how switching pumps on/off over the next 15 minutes affects regional pressure
@@ -644,7 +684,7 @@ def pressure_regulation(
:param modify_tank_initial_level: dict of tanks; the str key is the tank ID, the float value is the modified initial_level
:param modify_fixed_pump_pattern: dict of pump patterns; the str key is a fixed-speed pump ID, the list value is the modified pattern
:param modify_variable_pump_pattern: dict of pump patterns; the str key is a variable-speed pump ID, the list value is the modified pattern
:param scheme_Name: simulation scheme name
:param scheme_name: simulation scheme name
:return:
"""
print(
@@ -696,8 +736,8 @@ def pressure_regulation(
modify_tank_initial_level=modify_tank_initial_level,
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
modify_variable_pump_pattern=modify_variable_pump_pattern,
scheme_Type="pressure_regulation",
scheme_Name=scheme_Name,
scheme_type="pressure_regulation",
scheme_name=scheme_name,
)
if is_project_open(new_name):
close_project(new_name)


@@ -1,10 +1,10 @@
from collections import defaultdict, deque
from functools import lru_cache
from typing import Any
from app.services.tjnetwork import (
get_network_link_nodes,
is_node,
is_link,
get_link_properties,
)
@@ -19,48 +19,102 @@ def _parse_link_entry(link_entry: str) -> tuple[str, str, str, str]:
return parts[0], parts[1], parts[2], parts[3]
@lru_cache(maxsize=16)
def _get_network_topology(network: str):
"""
Parse and cache the network topology, greatly reducing repeated API calls and string-parsing overhead.
Returns:
- pipe_adj: adjacency map of permanently connected pipes/pumps (dict[str, set])
- all_valves: dict of all valves {id: (n1, n2)}
- link_lookup: fast link lookup table {id: (n1, n2, type)} for quickly locating accident elements
- node_set: set of all known nodes
"""
pipe_adj = defaultdict(set)
all_valves = {}
link_lookup = {}
node_set = set()
# Assumes get_network_link_nodes returns data for the whole network
for link_entry in get_network_link_nodes(network):
link_id, link_type, node1, node2 = _parse_link_entry(link_entry)
link_type_name = str(link_type).lower()
link_lookup[link_id] = (node1, node2, link_type_name)
node_set.add(node1)
node_set.add(node2)
if link_type_name == VALVE_LINK_TYPE:
all_valves[link_id] = (node1, node2)
else:
# Only non-valves (pipes/pumps) enter the permanent connectivity graph
pipe_adj[node1].add(node2)
pipe_adj[node2].add(node1)
return pipe_adj, all_valves, link_lookup, node_set
def valve_isolation_analysis(
network: str, accident_elements: str | list[str]
network: str, accident_elements: str | list[str], disabled_valves: list[str] = None
) -> dict[str, Any]:
"""
Valve-closure search/analysis: determine the valves that must be closed to isolate an accident, based on the network topology.
:param network: model name
:param accident_elements: accident elements (node or pipe/pump/valve IDs); either a single ID string or a list of IDs
:param disabled_valves: list of IDs of valves that are broken and cannot be closed
:return: dict with the affected nodes, must-close valves, optional valves, and related info
"""
if disabled_valves is None:
disabled_valves_set = set()
else:
disabled_valves_set = set(disabled_valves)
if isinstance(accident_elements, str):
target_elements = [accident_elements]
else:
target_elements = accident_elements
# 1. Fetch the cached topology (very fast, no I/O)
pipe_adj, all_valves, link_lookup, node_set = _get_network_topology(network)
# 2. Determine start nodes, preferring the lookup table to avoid API calls
start_nodes = set()
for element in target_elements:
if is_node(network, element):
if element in node_set:
start_nodes.add(element)
elif is_link(network, element):
link_props = get_link_properties(network, element)
node1 = link_props.get("node1")
node2 = link_props.get("node2")
if not node1 or not node2:
# For batch input we could skip or log bad elements; for now stay strict and raise
raise ValueError(f"Accident link {element} missing node endpoints")
start_nodes.add(node1)
start_nodes.add(node2)
elif element in link_lookup:
n1, n2, _ = link_lookup[element]
start_nodes.add(n1)
start_nodes.add(n2)
else:
raise ValueError(f"Accident element {element} not found")
# Fall back to the slow API only when the cache misses (rare)
if is_node(network, element):
start_nodes.add(element)
else:
props = get_link_properties(network, element)
n1, n2 = props.get("node1"), props.get("node2")
if n1 and n2:
start_nodes.add(n1)
start_nodes.add(n2)
else:
raise ValueError(
f"Accident element {element} invalid or missing endpoints"
)
adjacency: dict[str, set[str]] = defaultdict(set)
valve_links: dict[str, tuple[str, str]] = {}
for link_entry in get_network_link_nodes(network):
link_id, link_type, node1, node2 = _parse_link_entry(link_entry)
link_type_name = str(link_type).lower()
if link_type_name == VALVE_LINK_TYPE:
valve_links[link_id] = (node1, node2)
continue
adjacency[node1].add(node2)
adjacency[node2].add(node1)
# 3. Handle disabled valves (build a temporary incremental graph)
# We do not mutate the cached pipe_adj; instead we build an extra_adj
extra_adj = defaultdict(list)
boundary_valves = {}  # currently effective boundary valves
for vid, (n1, n2) in all_valves.items():
if vid in disabled_valves_set:
# Disabled valve: treat it as a connected pipe
extra_adj[n1].append(n2)
extra_adj[n2].append(n1)
else:
# Working valve: treat it as a potential boundary
boundary_valves[vid] = (n1, n2)
# 4. BFS over pipe_adj overlaid with extra_adj
affected_nodes: set[str] = set()
queue = deque(start_nodes)
while queue:
@@ -68,18 +122,29 @@ def valve_isolation_analysis(
if node in affected_nodes:
continue
affected_nodes.add(node)
for neighbor in adjacency.get(node, []):
if neighbor not in affected_nodes:
queue.append(neighbor)
# Visit permanent pipe neighbors
if node in pipe_adj:
for neighbor in pipe_adj[node]:
if neighbor not in affected_nodes:
queue.append(neighbor)
# Visit the extra neighbors contributed by disabled valves
if node in extra_adj:
for neighbor in extra_adj[node]:
if neighbor not in affected_nodes:
queue.append(neighbor)
# 5. Aggregate results
must_close_valves: list[str] = []
optional_valves: list[str] = []
for valve_id, (node1, node2) in valve_links.items():
in_node1 = node1 in affected_nodes
in_node2 = node2 in affected_nodes
if in_node1 and in_node2:
for valve_id, (n1, n2) in boundary_valves.items():
in_n1 = n1 in affected_nodes
in_n2 = n2 in affected_nodes
if in_n1 and in_n2:
optional_valves.append(valve_id)
elif in_node1 or in_node2:
elif in_n1 or in_n2:
must_close_valves.append(valve_id)
must_close_valves.sort()
@@ -87,6 +152,7 @@ def valve_isolation_analysis(
result = {
"accident_elements": target_elements,
"disabled_valves": disabled_valves,
"affected_nodes": sorted(affected_nodes),
"must_close_valves": must_close_valves,
"optional_valves": optional_valves,


@@ -4,33 +4,38 @@
Admin access only
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from fastapi import APIRouter, Depends, Query, Request
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
from app.domain.schemas.user import UserInDB
from fastapi import APIRouter, Depends, Query
from app.domain.schemas.audit import AuditLogResponse
from app.infra.repositories.audit_repository import AuditRepository
from app.auth.dependencies import get_user_repository, get_db
from app.auth.permissions import get_current_admin
from app.infra.db.postgresql.database import Database
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from app.infra.db.metadata.database import get_metadata_session
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
async def get_audit_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> AuditRepository:
"""获取审计日志仓储"""
return AuditRepository(db)
return AuditRepository(session)
@router.get("/logs", response_model=List[AuditLogResponse])
async def get_audit_logs(
user_id: Optional[int] = Query(None, description="Filter by user ID"),
username: Optional[str] = Query(None, description="Filter by username"),
user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
project_id: Optional[UUID] = Query(None, description="Filter by project ID"),
action: Optional[str] = Query(None, description="Filter by action type"),
resource_type: Optional[str] = Query(None, description="Filter by resource type"),
start_time: Optional[datetime] = Query(None, description="Start time"),
end_time: Optional[datetime] = Query(None, description="End time"),
skip: int = Query(0, ge=0, description="Number of records to skip"),
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
Query audit logs (admin only)
@@ -39,7 +44,7 @@ async def get_audit_logs(
"""
logs = await audit_repo.get_logs(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -51,21 +56,21 @@ async def get_audit_logs(
@router.get("/logs/count")
async def get_audit_logs_count(
user_id: Optional[int] = Query(None, description="Filter by user ID"),
username: Optional[str] = Query(None, description="Filter by username"),
user_id: Optional[UUID] = Query(None, description="Filter by user ID"),
project_id: Optional[UUID] = Query(None, description="Filter by project ID"),
action: Optional[str] = Query(None, description="Filter by action type"),
resource_type: Optional[str] = Query(None, description="Filter by resource type"),
start_time: Optional[datetime] = Query(None, description="Start time"),
end_time: Optional[datetime] = Query(None, description="End time"),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_admin),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> dict:
"""
Get the total audit log count (admin only)
"""
count = await audit_repo.get_log_count(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
start_time=start_time,
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
end_time: Optional[datetime] = Query(None, description="End time"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
current_user: UserInDB = Depends(get_current_admin),
audit_repo: AuditRepository = Depends(get_audit_repository)
current_user=Depends(get_current_metadata_user),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
Query the current user's audit logs


@@ -316,7 +316,7 @@ async def fastapi_query_all_scheme_all_records(
return loaded_dict
results = influxdb_api.query_scheme_all_record(
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
@@ -334,7 +334,7 @@ async def fastapi_query_all_scheme_all_records_property(
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
else:
all_results = influxdb_api.query_scheme_all_record(
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(all_results, default=encode_datetime)
redis_client.set(cache_key, packed)


@@ -22,7 +22,7 @@ async def get_all_extension_data_endpoint(network: str) -> dict[str, Any]:
async def get_extension_data_endpoint(network: str, key: str) -> str | None:
return get_extension_data(network, key)
@router.post("/setextensiondata", response_model=None)
@router.post("/setextensiondata/", response_model=None)
async def set_extension_data_endpoint(network: str, req: Request) -> ChangeSet:
props = await req.json()
print(props)


@@ -0,0 +1,101 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.project_dependencies import (
ProjectContext,
get_project_context,
get_project_pg_session,
get_project_timescale_connection,
get_metadata_repository,
)
from app.auth.metadata_dependencies import get_current_metadata_user
from app.core.config import settings
from app.domain.schemas.metadata import (
GeoServerConfigResponse,
ProjectMetaResponse,
ProjectSummaryResponse,
)
from app.infra.repositories.metadata_repository import MetadataRepository
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/meta/project", response_model=ProjectMetaResponse)
async def get_project_metadata(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
project = await metadata_repo.get_project_by_id(ctx.project_id)
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
geoserver_payload = (
GeoServerConfigResponse(
gs_base_url=geoserver.gs_base_url,
gs_admin_user=geoserver.gs_admin_user,
gs_datastore_name=geoserver.gs_datastore_name,
default_extent=geoserver.default_extent,
srid=geoserver.srid,
)
if geoserver
else None
)
return ProjectMetaResponse(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=ctx.project_role,
geoserver=geoserver_payload,
)
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
async def list_user_projects(
current_user=Depends(get_current_metadata_user),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
try:
projects = await metadata_repo.list_projects_for_user(current_user.id)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while listing projects for user %s",
current_user.id,
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
return [
ProjectSummaryResponse(
project_id=project.project_id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=project.project_role,
)
for project in projects
]
@router.get("/meta/db/health")
async def project_db_health(
pg_session: AsyncSession = Depends(get_project_pg_session),
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
):
await pg_session.execute(text("SELECT 1"))
async with ts_conn.cursor() as cur:
await cur.execute("SELECT 1")
return {"postgres": "ok", "timescale": "ok"}


@@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse
from typing import Any, Dict
import app.services.project_info as project_info
from app.native.api import ChangeSet
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
from app.services.tjnetwork import (
list_project,
have_project,
@@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str):
@router.post("/openproject/")
async def open_project_endpoint(network: str):
open_project(network)
# Try connecting to the specified databases
try:
# Initialize the PostgreSQL connection pool
pg_instance = await get_pg_db(network)
async with pg_instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
# Initialize the TimescaleDB connection pool
ts_instance = await get_ts_db(network)
async with ts_instance.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT 1")
except Exception as e:
# Log the error without blocking the project open, or block depending on requirements
# We just print here, since open_project originally only handled the native side
print(f"Failed to connect to databases for {network}: {str(e)}")
# If the database connection is mandatory, raise instead:
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
return network
@router.post("/closeproject/")


@@ -60,7 +60,7 @@ class BurstAnalysis(BaseModel):
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
modify_variable_pump_pattern: Optional[dict[str, list]] = None
modify_valve_opening: Optional[dict[str, float]] = None
scheme_Name: Optional[str] = None
scheme_name: Optional[str] = None
class SchedulingAnalysis(BaseModel):
@@ -78,7 +78,7 @@ class PressureRegulation(BaseModel):
pump_control: dict
tank_init_level: Optional[dict] = None
duration: Optional[int] = 900
scheme_Name: Optional[str] = None
scheme_name: Optional[str] = None
class ProjectManagement(BaseModel):
@@ -115,7 +115,15 @@ class PressureSensorPlacement(BaseModel):
def run_simulation_manually_by_date(
network_name: str, base_date: datetime, start_time: str, duration: int
) -> None:
start_hour, start_minute, start_second = map(int, start_time.split(":"))
time_parts = list(map(int, start_time.split(":")))
if len(time_parts) == 2:
start_hour, start_minute = time_parts
start_second = 0
elif len(time_parts) == 3:
start_hour, start_minute, start_second = time_parts
else:
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
start_datetime = base_date.replace(
hour=start_hour, minute=start_minute, second=start_second
)
@@ -196,10 +204,8 @@ async def burst_analysis_endpoint(
async def fastapi_burst_analysis(
network: str = Query(...),
modify_pattern_start_time: str = Query(...),
burst_ID: list | str = Query(..., alias="burst_ID[]"),  # alias added to match the URL
burst_size: list | float | int = Query(
    ..., alias="burst_size[]"
),  # alias added to match the URL
burst_ID: list[str] = Query(...),
burst_size: list[float] = Query(...),
modify_total_duration: int = Query(...),
scheme_name: str = Query(...),
) -> str:
@@ -239,9 +245,29 @@ async def fastapi_valve_close_analysis(
@router.get("/valve_isolation_analysis/")
async def valve_isolation_endpoint(
network: str, accident_element: List[str] = Query(...)
network: str,
accident_element: List[str] = Query(...),
disabled_valves: List[str] = Query(None),
):
return analyze_valve_isolation(network, accident_element)
# Sample response shape:
# {
#     "accident_elements": ["P461309"],
#     "affected_nodes": ["J316629_A", "J317037_B", "J317060_B", "J408189_B",
#                        "J499996", "J524940", "J535933", "J58841"],
#     "isolatable": True,
#     "must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
#     "optional_valves": [],
# }
result = analyze_valve_isolation(network, accident_element, disabled_valves)
return result
@router.get("/flushinganalysis/")
@@ -260,6 +286,7 @@ async def fastapi_flushing_analysis(
drainage_node_ID: str = Query(...),
flush_flow: float = 0,
duration: int | None = None,
scheme_name: str | None = None,
) -> str:
valve_opening = {
valve_id: float(valves_k[idx]) for idx, valve_id in enumerate(valves)
@@ -271,6 +298,7 @@ async def fastapi_flushing_analysis(
modify_valve_opening=valve_opening,
drainage_node_ID=drainage_node_ID,
flushing_flow=flush_flow,
scheme_name=scheme_name,
)
return result or "success"
@@ -342,7 +370,7 @@ async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
modify_tank_initial_level=item["tank_init_level"],
modify_fixed_pump_pattern=fixed_pump_pattern or None,
modify_variable_pump_pattern=variable_pump_pattern or None,
scheme_Name=item["scheme_Name"],
scheme_name=item["scheme_name"],
)
return "success"
@@ -634,13 +662,9 @@ async def fastapi_run_simulation_manually_by_date(
globals.realtime_region_pipe_flow_and_demand_id,
)
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
thread = threading.Thread(
target=lambda: run_simulation_manually_by_date(
item["name"], base_date, item["start_time"], item["duration"]
)
run_simulation_manually_by_date(
item["name"], base_date, item["start_time"], item["duration"]
)
thread.start()
thread.join()
return {"status": "success"}
except Exception as exc:
return {"status": "error", "message": str(exc)}


@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
cache,
user_management,  # new: user management
audit,  # new: audit logs
meta,
)
from app.api.v1.endpoints.network import (
general,
@@ -46,6 +47,7 @@ api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"])  # new
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"])  # new
api_router.include_router(meta.router, tags=["Metadata"])
api_router.include_router(project.router, tags=["Project"])
# Network Elements (Node/Link Types)


@@ -0,0 +1,63 @@
# import logging
from uuid import UUID
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.core.config import settings
oauth2_optional = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
)
# logger = logging.getLogger(__name__)
async def get_current_keycloak_sub(
token: str | None = Depends(oauth2_optional),
) -> UUID:
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
if settings.KEYCLOAK_PUBLIC_KEY:
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
algorithms = [settings.KEYCLOAK_ALGORITHM]
else:
key = settings.SECRET_KEY
algorithms = [settings.ALGORITHM]
try:
payload = jwt.decode(
token,
key,
algorithms=algorithms,
audience=settings.KEYCLOAK_AUDIENCE or None,
)
except JWTError as exc:
# logger.warning("Keycloak token validation failed: %s", exc)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
) from exc
sub = payload.get("sub")
if not sub:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing subject claim",
headers={"WWW-Authenticate": "Bearer"},
)
try:
return UUID(sub)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid subject claim",
headers={"WWW-Authenticate": "Bearer"},
) from exc


@@ -0,0 +1,60 @@
from dataclasses import dataclass
from uuid import UUID
import logging
from fastapi import Depends, HTTPException, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
async def get_metadata_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
return MetadataRepository(session)
async def get_current_metadata_user(
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
try:
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while resolving current user",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
return user
async def get_current_metadata_admin(
user=Depends(get_current_metadata_user),
):
if user.is_superuser or user.role == "admin":
return user
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
)
@dataclass(frozen=True)
class _AuthBypassUser:
id: UUID = UUID(int=0)
role: str = "admin"
is_superuser: bool = True
is_active: bool = True


@@ -0,0 +1,213 @@
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID
import logging
from fastapi import Depends, Header, HTTPException, status
from psycopg import AsyncConnection
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
DB_ROLE_BIZ_DATA = "biz_data"
DB_ROLE_IOT_DATA = "iot_data"
DB_TYPE_POSTGRES = "postgresql"
DB_TYPE_TIMESCALE = "timescaledb"
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ProjectContext:
project_id: UUID
user_id: UUID
project_role: str
async def get_metadata_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> MetadataRepository:
return MetadataRepository(session)
async def get_project_context(
x_project_id: str = Header(..., alias="X-Project-Id"),
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> ProjectContext:
try:
project_uuid = UUID(x_project_id)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
) from exc
try:
project = await metadata_repo.get_project_by_id(project_uuid)
if not project:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
if project.status != "active":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
)
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
if not user:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
)
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
if not membership_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
)
except SQLAlchemyError as exc:
logger.error(
"Metadata DB error while resolving project context",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Metadata database error: {exc}",
) from exc
return ProjectContext(
project_id=project.id,
user_id=user.id,
project_role=membership_role,
)
async def get_project_pg_session(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncSession, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_BIZ_DATA
)
except ValueError as exc:
logger.error(
"Invalid project PostgreSQL routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL not configured",
)
if routing.db_type != DB_TYPE_POSTGRES:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
ctx.project_id,
DB_ROLE_BIZ_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with sessionmaker() as session:
yield session
async def get_project_pg_connection(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_BIZ_DATA
)
except ValueError as exc:
logger.error(
"Invalid project PostgreSQL routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL not configured",
)
if routing.db_type != DB_TYPE_POSTGRES:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project PostgreSQL type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
pool = await project_connection_manager.get_pg_pool(
ctx.project_id,
DB_ROLE_BIZ_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with pool.connection() as conn:
yield conn
async def get_project_timescale_connection(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
) -> AsyncGenerator[AsyncConnection, None]:
try:
routing = await metadata_repo.get_project_db_routing(
ctx.project_id, DB_ROLE_IOT_DATA
)
except ValueError as exc:
logger.error(
"Invalid project TimescaleDB routing DSN configuration",
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
) from exc
if not routing:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project TimescaleDB not configured",
)
if routing.db_type != DB_TYPE_TIMESCALE:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project TimescaleDB type mismatch",
)
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
pool = await project_connection_manager.get_timescale_pool(
ctx.project_id,
DB_ROLE_IOT_DATA,
routing.dsn,
pool_min_size,
pool_max_size,
)
async with pool.connection() as conn:
yield conn


@@ -7,6 +7,7 @@
from typing import Optional
from datetime import datetime
import logging
from uuid import UUID
logger = logging.getLogger(__name__)
@@ -38,18 +39,16 @@ class AuditAction:
async def log_audit_event(
action: str,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None,
db=None, # 新增:可选的数据库实例
session=None,
):
"""
Record an audit log entry
@@ -57,67 +56,60 @@ async def log_audit_event(
Args:
action: action type
user_id: user ID
username: username
project_id: project ID
resource_type: resource type
resource_id: resource ID
ip_address: IP address
user_agent: User-Agent
request_method: request method
request_path: request path
request_data: request payload (sensitive fields must be sanitized)
response_status: response status code
error_message: error message
db: database instance (optional; resolved automatically when omitted)
session: metadata DB session (optional)
"""
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.audit_repository import AuditRepository
try:
# Sanitize sensitive data
if request_data:
request_data = sanitize_sensitive_data(request_data)
# If no database instance was provided, try the global one
if db is None:
try:
from app.infra.db.postgresql.database import db as default_db
# Use it only when the connection pool is already initialized
if default_db.pool:
db = default_db
except ImportError:
pass
# If there is still no database instance
if db is None:
# In some contexts none can be obtained; fail silently in that case
logger.warning("No database instance provided for audit logging")
return
audit_repo = AuditRepository(db)
if request_data:
request_data = sanitize_sensitive_data(request_data)
if session is None:
async with SessionLocal() as session:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
)
else:
audit_repo = AuditRepository(session)
await audit_repo.create_log(
user_id=user_id,
username=username,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
error_message=error_message,
)
logger.info(
f"Audit log created: action={action}, user={username or user_id}, "
f"resource={resource_type}:{resource_id}"
)
except Exception as e:
# Audit-log failures must not affect the business flow
logger.error(f"Failed to create audit log: {e}", exc_info=True)
logger.info(
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
action,
user_id,
project_id,
resource_type,
resource_id,
)
def sanitize_sensitive_data(data: dict) -> dict:


@@ -1,4 +1,6 @@
from pydantic_settings import BaseSettings
from pathlib import Path
from urllib.parse import quote_plus
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
@@ -15,6 +17,7 @@ class Settings(BaseSettings):
# Data encryption key (Fernet)
ENCRYPTION_KEY: str = ""  # must be set from environment variables
DATABASE_ENCRYPTION_KEY: str = ""  # dedicated to project_databases.dsn_encrypted
# Database Config (PostgreSQL)
DB_NAME: str = "tjwater"
@@ -35,13 +38,45 @@ class Settings(BaseSettings):
INFLUXDB_ORG: str = "org"
INFLUXDB_BUCKET: str = "bucket"
# Metadata Database Config (PostgreSQL)
METADATA_DB_NAME: str = "system_hub"
METADATA_DB_HOST: str = "localhost"
METADATA_DB_PORT: str = "5432"
METADATA_DB_USER: str = "postgres"
METADATA_DB_PASSWORD: str = "password"
METADATA_DB_POOL_SIZE: int = 5
METADATA_DB_MAX_OVERFLOW: int = 10
PROJECT_PG_CACHE_SIZE: int = 50
PROJECT_TS_CACHE_SIZE: int = 50
PROJECT_PG_POOL_SIZE: int = 5
PROJECT_PG_MAX_OVERFLOW: int = 10
PROJECT_TS_POOL_MIN_SIZE: int = 1
PROJECT_TS_POOL_MAX_SIZE: int = 10
# Keycloak JWT (optional override)
KEYCLOAK_PUBLIC_KEY: str = ""
KEYCLOAK_ALGORITHM: str = "RS256"
KEYCLOAK_AUDIENCE: str = ""
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
db_password = quote_plus(self.DB_PASSWORD)
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
class Config:
env_file = ".env"
extra = "ignore"
@property
def METADATA_DATABASE_URI(self) -> str:
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
return (
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
)
model_config = SettingsConfigDict(
env_file=Path(__file__).resolve().parents[2] / ".env",
extra="ignore",
)
settings = Settings()
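The quote_plus change above is what the "password contains '@'" commit fixes; a quick sanity check, using the TimescaleDB password from the .env.local in this compare:
```python
from urllib.parse import quote_plus

# An unescaped '@' would prematurely terminate the userinfo part of the URL
assert quote_plus("Tjwater@123456") == "Tjwater%40123456"
```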


@@ -3,75 +3,94 @@ from typing import Optional
import base64
import os
from app.core.config import settings
class Encryptor:
"""
Encrypt/decrypt data using Fernet (symmetric encryption)
Suitable for encrypting sensitive configuration, user data, etc.
"""
def __init__(self, key: Optional[bytes] = None):
"""
Initialize the encryptor
Args:
key: encryption key; read from environment variables when None
"""
if key is None:
key_str = os.getenv("ENCRYPTION_KEY")
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
if not key_str:
raise ValueError(
"ENCRYPTION_KEY not found in environment variables. "
"ENCRYPTION_KEY not found in environment variables or .env. "
"Generate one using: Encryptor.generate_key()"
)
key = key_str.encode()
self.fernet = Fernet(key)
def encrypt(self, data: str) -> str:
"""
Encrypt a string
Args:
data: plaintext string to encrypt
Returns:
Base64-encoded encrypted string
"""
if not data:
return data
encrypted_bytes = self.fernet.encrypt(data.encode())
return encrypted_bytes.decode()
def decrypt(self, data: str) -> str:
"""
Decrypt a string
Args:
data: Base64-encoded encrypted string
Returns:
decrypted plaintext string
"""
if not data:
return data
decrypted_bytes = self.fernet.decrypt(data.encode())
return decrypted_bytes.decode()
@staticmethod
def generate_key() -> str:
"""
Generate a new Fernet encryption key
Returns:
Base64-encoded key string
"""
key = Fernet.generate_key()
return key.decode()
# Global encryptor instances (lazy-loaded)
_encryptor: Optional[Encryptor] = None
_database_encryptor: Optional[Encryptor] = None
def is_encryption_configured() -> bool:
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
def is_database_encryption_configured() -> bool:
return bool(
os.getenv("DATABASE_ENCRYPTION_KEY")
or settings.DATABASE_ENCRYPTION_KEY
or os.getenv("ENCRYPTION_KEY")
or settings.ENCRYPTION_KEY
)
def get_encryptor() -> Encryptor:
"""获取全局加密器实例"""
@@ -80,6 +99,26 @@ def get_encryptor() -> Encryptor:
_encryptor = Encryptor()
return _encryptor
def get_database_encryptor() -> Encryptor:
"""获取 project DB DSN 专用加密器实例"""
global _database_encryptor
if _database_encryptor is None:
key_str = (
os.getenv("DATABASE_ENCRYPTION_KEY")
or settings.DATABASE_ENCRYPTION_KEY
or os.getenv("ENCRYPTION_KEY")
or settings.ENCRYPTION_KEY
)
if not key_str:
raise ValueError(
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
"Generate one using: Encryptor.generate_key()"
)
_database_encryptor = Encryptor(key=key_str.encode())
return _database_encryptor
# Backward compatibility (lazy loading)
def __getattr__(name):
if name == "encryptor":

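A minimal usage sketch of the two-key setup above, assuming both ENCRYPTION_KEY and DATABASE_ENCRYPTION_KEY are configured:
```python
from app.core.encryption import get_encryptor, get_database_encryptor

enc = get_encryptor()
token = enc.encrypt("postgresql://user:secret@host:5432/db")
assert enc.decrypt(token) == "postgresql://user:secret@host:5432/db"

# DSNs in project_databases.dsn_encrypted use the dedicated key,
# falling back to ENCRYPTION_KEY when it is not set
db_enc = get_database_encryptor()
```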

@@ -1,45 +1,42 @@
from datetime import datetime
from typing import Optional, Any
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class AuditLogCreate(BaseModel):
"""创建审计日志"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: str
resource_type: Optional[str] = None
resource_id: Optional[str] = None
ip_address: Optional[str] = None
user_agent: Optional[str] = None
request_method: Optional[str] = None
request_path: Optional[str] = None
request_data: Optional[dict] = None
response_status: Optional[int] = None
error_message: Optional[str] = None
class AuditLogResponse(BaseModel):
"""审计日志响应"""
id: int
user_id: Optional[int]
username: Optional[str]
id: UUID
user_id: Optional[UUID]
project_id: Optional[UUID]
action: str
resource_type: Optional[str]
resource_id: Optional[str]
ip_address: Optional[str]
user_agent: Optional[str]
request_method: Optional[str]
request_path: Optional[str]
request_data: Optional[dict]
response_status: Optional[int]
error_message: Optional[str]
timestamp: datetime
model_config = ConfigDict(from_attributes=True)
class AuditLogQuery(BaseModel):
"""审计日志查询参数"""
user_id: Optional[int] = None
username: Optional[str] = None
user_id: Optional[UUID] = None
project_id: Optional[UUID] = None
action: Optional[str] = None
resource_type: Optional[str] = None
start_time: Optional[datetime] = None


@@ -0,0 +1,33 @@
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class GeoServerConfigResponse(BaseModel):
gs_base_url: Optional[str]
gs_admin_user: Optional[str]
gs_datastore_name: str
default_extent: Optional[dict]
srid: int
class ProjectMetaResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str
geoserver: Optional[GeoServerConfigResponse]
class ProjectSummaryResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str


@@ -6,12 +6,17 @@
import time
import json
from uuid import UUID
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.infra.db.postgresql.database import db as default_db
from app.core.audit import log_audit_event, AuditAction
import logging
from jose import JWTError, jwt
from app.core.config import settings
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
# 4. Extract the information needed for auditing
user_id = None
username = None
# Try to get the current user from the request state
if hasattr(request.state, "user"):
user = request.state.user
user_id = getattr(user, "id", None)
username = getattr(user, "username", None)
user_id = await self._resolve_user_id(request)
project_id = self._resolve_project_id(request)
# Collect client info
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
# Determine the action type
action = self._determine_action(request)
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
await log_audit_event(
action=action,
user_id=user_id,
username=username,
project_id=project_id,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_method=request.method,
request_path=str(request.url.path),
request_data=request_data,
response_status=response.status_code,
error_message=(
None
if response.status_code < 400
else f"HTTP {response.status_code}"
),
db=default_db,
)
except Exception as e:
# Audit failures must not affect the response
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id")
if not project_header:
return None
try:
return UUID(project_header)
except ValueError:
return None
async def _resolve_user_id(self, request: Request) -> UUID | None:
auth_header = request.headers.get("authorization")
if not auth_header or not auth_header.lower().startswith("bearer "):
return None
token = auth_header.split(" ", 1)[1].strip()
if not token:
return None
try:
key = (
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
if settings.KEYCLOAK_PUBLIC_KEY
else settings.SECRET_KEY
)
algorithms = (
[settings.KEYCLOAK_ALGORITHM]
if settings.KEYCLOAK_PUBLIC_KEY
else [settings.ALGORITHM]
)
payload = jwt.decode(token, key, algorithms=algorithms)
sub = payload.get("sub")
if not sub:
return None
keycloak_id = UUID(sub)
except (JWTError, ValueError):
return None
async with SessionLocal() as session:
repo = MetadataRepository(session)
user = await repo.get_user_by_keycloak_id(keycloak_id)
if user and user.is_active:
return user.id
return None
def _determine_action(self, request: Request) -> str:
"""根据请求路径和方法确定操作类型"""
path = request.url.path.lower()
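_resolve_user_id falls back from the Keycloak public key to the local SECRET_KEY. A self-contained sketch of that decode path with python-jose, using placeholder values for the HS256 fallback:

    from jose import JWTError, jwt

    SECRET_KEY = "dev-only-secret"  # placeholder for settings.SECRET_KEY
    ALGORITHM = "HS256"             # placeholder for settings.ALGORITHM

    def resolve_sub(auth_header: str | None) -> str | None:
        # Mirrors the middleware: require a Bearer token, decode it, read "sub".
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None
        token = auth_header.split(" ", 1)[1].strip()
        if not token:
            return None
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            return None
        return payload.get("sub")

    demo_token = jwt.encode(
        {"sub": "00000000-0000-0000-0000-000000000000"},  # placeholder user id
        SECRET_KEY,
        algorithm=ALGORITHM,
    )
    print(resolve_sub("Bearer " + demo_token))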

View File

@@ -0,0 +1,211 @@
import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict
from uuid import UUID
from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.core.config import settings
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class PgEngineEntry:
engine: AsyncEngine
sessionmaker: async_sessionmaker[AsyncSession]
@dataclass(frozen=True)
class CacheKey:
project_id: UUID
db_role: str
class ProjectConnectionManager:
def __init__(self) -> None:
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._pg_lock = asyncio.Lock()
self._ts_lock = asyncio.Lock()
self._pg_raw_lock = asyncio.Lock()
def _normalize_pg_url(self, url: str) -> str:
parsed = make_url(url)
if parsed.drivername == "postgresql":
parsed = parsed.set(drivername="postgresql+psycopg")
return str(parsed)
async def get_pg_sessionmaker(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> async_sessionmaker[AsyncSession]:
async with self._pg_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._pg_cache.get(key)
if entry:
self._pg_cache.move_to_end(key)
return entry.sessionmaker
normalized_url = self._normalize_pg_url(connection_url)
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
engine = create_async_engine(
normalized_url,
pool_size=pool_min_size,
max_overflow=max(0, pool_max_size - pool_min_size),
pool_pre_ping=True,
)
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
self._pg_cache[key] = PgEngineEntry(
engine=engine,
sessionmaker=sessionmaker,
)
await self._evict_pg_if_needed()
logger.info(
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
)
return sessionmaker
async def get_timescale_pool(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._ts_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._ts_cache.get(key)
if pool:
self._ts_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
max_size=pool_max_size,
open=False,
kwargs={"row_factory": dict_row},
)
await pool.open()
self._ts_cache[key] = pool
await self._evict_ts_if_needed()
logger.info(
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
)
return pool
async def get_pg_pool(
self,
project_id: UUID,
db_role: str,
connection_url: str,
pool_min_size: int,
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._pg_raw_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._pg_raw_cache.get(key)
if pool:
self._pg_raw_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
max_size=pool_max_size,
open=False,
kwargs={"row_factory": dict_row},
)
await pool.open()
self._pg_raw_cache[key] = pool
await self._evict_pg_raw_if_needed()
logger.info(
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
)
return pool
async def _evict_pg_if_needed(self) -> None:
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
key, entry = self._pg_cache.popitem(last=False)
await entry.engine.dispose()
logger.info(
"Evicted PostgreSQL engine for project %s (%s)",
key.project_id,
key.db_role,
)
async def _evict_ts_if_needed(self) -> None:
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
key, pool = self._ts_cache.popitem(last=False)
await pool.close()
logger.info(
"Evicted TimescaleDB pool for project %s (%s)",
key.project_id,
key.db_role,
)
async def _evict_pg_raw_if_needed(self) -> None:
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
key, pool = self._pg_raw_cache.popitem(last=False)
await pool.close()
logger.info(
"Evicted PostgreSQL pool for project %s (%s)",
key.project_id,
key.db_role,
)
async def close_all(self) -> None:
async with self._pg_lock:
for key, entry in list(self._pg_cache.items()):
await entry.engine.dispose()
logger.info(
"Closed PostgreSQL engine for project %s (%s)",
key.project_id,
key.db_role,
)
self._pg_cache.clear()
async with self._ts_lock:
for key, pool in list(self._ts_cache.items()):
await pool.close()
logger.info(
"Closed TimescaleDB pool for project %s (%s)",
key.project_id,
key.db_role,
)
self._ts_cache.clear()
async with self._pg_raw_lock:
for key, pool in list(self._pg_raw_cache.items()):
await pool.close()
logger.info(
"Closed PostgreSQL pool for project %s (%s)",
key.project_id,
key.db_role,
)
self._pg_raw_cache.clear()
project_connection_manager = ProjectConnectionManager()
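A usage sketch for the manager above; the project id, role name, and DSN are placeholders, so running it needs a reachable database. Repeated calls with the same (project_id, db_role) key reuse the cached engine, and LRU eviction only starts once the cache exceeds PROJECT_PG_CACHE_SIZE entries:

    import asyncio
    from uuid import uuid4

    from sqlalchemy import text

    async def demo() -> None:
        maker = await project_connection_manager.get_pg_sessionmaker(
            project_id=uuid4(),  # placeholder project id
            db_role="business",  # placeholder role name
            connection_url="postgresql://user:pass@localhost:5432/demo",
            pool_min_size=1,
            pool_max_size=5,
        )
        async with maker() as session:
            print(await session.scalar(text("SELECT 1")))

    asyncio.run(demo())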

View File

@@ -404,8 +404,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
Point("link")
.tag("date", None)
.tag("ID", None)
.tag("scheme_Type", None)
.tag("scheme_Name", None)
.tag("scheme_type", None)
.tag("scheme_name", None)
.field("flow", 0.0)
.field("leakage", 0.0)
.field("velocity", 0.0)
@@ -420,8 +420,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
Point("node")
.tag("date", None)
.tag("ID", None)
.tag("scheme_Type", None)
.tag("scheme_Name", None)
.tag("scheme_type", None)
.tag("scheme_name", None)
.field("head", 0.0)
.field("pressure", 0.0)
.field("actualdemand", 0.0)
@@ -436,8 +436,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
.tag("date", None)
.tag("description", None)
.tag("device_ID", None)
.tag("scheme_Type", None)
.tag("scheme_Name", None)
.tag("scheme_type", None)
.tag("scheme_name", None)
.field("monitored_value", 0.0)
.field("datacleaning_value", 0.0)
.field("scheme_simulation_value", 0.0)
@@ -1811,8 +1811,8 @@ def query_SCADA_data_by_device_ID_and_time(
def query_scheme_SCADA_data_by_device_ID_and_time(
query_ids_list: List[str],
query_time: str,
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
bucket: str = "scheme_simulation_result",
) -> Dict[str, float]:
"""
@@ -1843,7 +1843,7 @@ def query_scheme_SCADA_data_by_device_ID_and_time(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
"""
# 执行查询
try:
@@ -2585,7 +2585,7 @@ def query_all_scheme_record_by_time_property(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
"""
# 执行查询
tables = query_api.query(flux_query)
@@ -2635,7 +2635,7 @@ def query_scheme_simulation_result_by_ID_time(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|> pivot(
rowKey:["_time"],
columnKey:["_field"],
@@ -2660,7 +2660,7 @@ def query_scheme_simulation_result_by_ID_time(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|> pivot(
rowKey:["_time"],
columnKey:["_field"],
@@ -3227,8 +3227,8 @@ def store_scheme_simulation_result_to_influxdb(
link_result_list: List[Dict[str, any]],
scheme_start_time: str,
num_periods: int = 1,
scheme_Type: str = None,
scheme_Name: str = None,
scheme_type: str = None,
scheme_name: str = None,
bucket: str = "scheme_simulation_result",
):
"""
@@ -3237,8 +3237,8 @@ def store_scheme_simulation_result_to_influxdb(
:param link_result_list: (List[Dict[str, any]]): 包含连接和结果数据的字典列表。
:param scheme_start_time: (str): 方案模拟开始时间。
:param num_periods: (int): 方案模拟的周期数
:param scheme_Type: (str): 方案类型
:param scheme_Name: (str): 方案名称
:param scheme_type: (str): 方案类型
:param scheme_name: (str): 方案名称
:param bucket: (str): InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"
:return:
"""
@@ -3298,8 +3298,8 @@ def store_scheme_simulation_result_to_influxdb(
Point("node")
.tag("date", date_str)
.tag("ID", node_id)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("head", data.get("head", 0.0))
.field("pressure", data.get("pressure", 0.0))
.field("actualdemand", data.get("demand", 0.0))
@@ -3322,8 +3322,8 @@ def store_scheme_simulation_result_to_influxdb(
Point("link")
.tag("date", date_str)
.tag("ID", link_id)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("flow", data.get("flow", 0.0))
.field("velocity", data.get("velocity", 0.0))
.field("headloss", data.get("headloss", 0.0))
@@ -3409,14 +3409,14 @@ def query_corresponding_query_id_and_element_id(name: str) -> None:
# 2025/03/11
def fill_scheme_simulation_result_to_SCADA(
scheme_Type: str = None,
scheme_Name: str = None,
scheme_type: str = None,
scheme_name: str = None,
query_date: str = None,
bucket: str = "scheme_simulation_result",
):
"""
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
:param bucket: InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"
:return:
@@ -3457,8 +3457,8 @@ def fill_scheme_simulation_result_to_SCADA(
# 查找associated_element_id的对应值
for key, value in globals.scheme_source_outflow_ids.items():
scheme_source_outflow_result = query_scheme_curve_by_ID_property(
scheme_Type=scheme_Type,
scheme_Name=scheme_Name,
scheme_type=scheme_type,
scheme_name=scheme_name,
query_date=query_date,
ID=value,
type="link",
@@ -3470,8 +3470,8 @@ def fill_scheme_simulation_result_to_SCADA(
Point("scheme_source_outflow")
.tag("date", query_date)
.tag("device_ID", key)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("monitored_value", data["value"])
.time(data["time"], write_precision="s")
)
@@ -3480,8 +3480,8 @@ def fill_scheme_simulation_result_to_SCADA(
for key, value in globals.scheme_pipe_flow_ids.items():
scheme_pipe_flow_result = query_scheme_curve_by_ID_property(
scheme_Type=scheme_Type,
scheme_Name=scheme_Name,
scheme_type=scheme_type,
scheme_name=scheme_name,
query_date=query_date,
ID=value,
type="link",
@@ -3492,8 +3492,8 @@ def fill_scheme_simulation_result_to_SCADA(
Point("scheme_pipe_flow")
.tag("date", query_date)
.tag("device_ID", key)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("monitored_value", data["value"])
.time(data["time"], write_precision="s")
)
@@ -3502,8 +3502,8 @@ def fill_scheme_simulation_result_to_SCADA(
for key, value in globals.scheme_pressure_ids.items():
scheme_pressure_result = query_scheme_curve_by_ID_property(
scheme_Type=scheme_Type,
scheme_Name=scheme_Name,
scheme_type=scheme_type,
scheme_name=scheme_name,
query_date=query_date,
ID=value,
type="node",
@@ -3514,8 +3514,8 @@ def fill_scheme_simulation_result_to_SCADA(
Point("scheme_pressure")
.tag("date", query_date)
.tag("device_ID", key)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("monitored_value", data["value"])
.time(data["time"], write_precision="s")
)
@@ -3524,8 +3524,8 @@ def fill_scheme_simulation_result_to_SCADA(
for key, value in globals.scheme_demand_ids.items():
scheme_demand_result = query_scheme_curve_by_ID_property(
scheme_Type=scheme_Type,
scheme_Name=scheme_Name,
scheme_type=scheme_type,
scheme_name=scheme_name,
query_date=query_date,
ID=value,
type="node",
@@ -3536,8 +3536,8 @@ def fill_scheme_simulation_result_to_SCADA(
Point("scheme_demand")
.tag("date", query_date)
.tag("device_ID", key)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("monitored_value", data["value"])
.time(data["time"], write_precision="s")
)
@@ -3546,8 +3546,8 @@ def fill_scheme_simulation_result_to_SCADA(
for key, value in globals.scheme_quality_ids.items():
scheme_quality_result = query_scheme_curve_by_ID_property(
scheme_Type=scheme_Type,
scheme_Name=scheme_Name,
scheme_type=scheme_type,
scheme_name=scheme_name,
query_date=query_date,
ID=value,
type="node",
@@ -3558,8 +3558,8 @@ def fill_scheme_simulation_result_to_SCADA(
Point("scheme_quality")
.tag("date", query_date)
.tag("device_ID", key)
.tag("scheme_Type", scheme_Type)
.tag("scheme_Name", scheme_Name)
.tag("scheme_type", scheme_type)
.tag("scheme_name", scheme_name)
.field("monitored_value", data["value"])
.time(data["time"], write_precision="s")
)
@@ -3629,15 +3629,15 @@ def query_SCADA_data_curve(
# 2025/02/18
def query_scheme_all_record_by_time(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_time: str,
bucket: str = "scheme_simulation_result",
) -> tuple:
"""
查询指定方案某一时刻的所有记录(包括 node、link),分别以指定格式返回。
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'
:param bucket: 数据存储的 bucket 名称。
:return: dict: tuple: (node_records, link_records)
@@ -3660,7 +3660,7 @@ def query_scheme_all_record_by_time(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|> pivot(
rowKey:["_time"],
columnKey:["_field"],
@@ -3710,8 +3710,8 @@ def query_scheme_all_record_by_time(
# 2025/03/04
def query_scheme_all_record_by_time_property(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
@@ -3719,8 +3719,8 @@ def query_scheme_all_record_by_time_property(
) -> list:
"""
查询指定方案某一时刻 node、link 某一属性值,以指定格式返回。
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'
:param type: 查询的类型(决定 measurement)
:param property: 查询的字段名称(field)
@@ -3752,7 +3752,7 @@ def query_scheme_all_record_by_time_property(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
"""
# 执行查询
tables = query_api.query(flux_query)
@@ -3767,8 +3767,8 @@ def query_scheme_all_record_by_time_property(
# 2025/02/19
def query_scheme_curve_by_ID_property(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_date: str,
ID: str,
type: str,
@@ -3776,9 +3776,9 @@ def query_scheme_curve_by_ID_property(
bucket: str = "scheme_simulation_result",
) -> list:
"""
根据scheme_Type和scheme_Name,查询该模拟方案中某一node或link的某一属性值的所有时间的结果
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
根据scheme_type和scheme_name,查询该模拟方案中某一node或link的某一属性值的所有时间的结果
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
:param ID: 元素的ID
:param type: 元素的类型node或link
@@ -3817,7 +3817,7 @@ def query_scheme_curve_by_ID_property(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {start_time}, stop: {stop_time})
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
"""
# 执行查询
tables = query_api.query(flux_query)
@@ -3832,15 +3832,15 @@ def query_scheme_curve_by_ID_property(
# 2025/02/21
def query_scheme_all_record(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_date: str,
bucket: str = "scheme_simulation_result",
) -> tuple:
"""
查询指定方案的所有记录(包括 node、link),分别以指定格式返回。
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
:param bucket: 数据存储的 bucket 名称。
:return: dict: tuple: (node_records, link_records)
@@ -3867,7 +3867,7 @@ def query_scheme_all_record(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|> pivot(
rowKey:["_time"],
columnKey:["_field"],
@@ -3917,8 +3917,8 @@ def query_scheme_all_record(
# 2025/03/04
def query_scheme_all_record_property(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_date: str,
type: str,
property: str,
@@ -3926,8 +3926,8 @@ def query_scheme_all_record_property(
) -> list:
"""
查询指定方案的 node、link 的某一属性值,以指定格式返回。
:param scheme_Type: 方案类型
:param scheme_Name: 方案名称
:param scheme_type: 方案类型
:param scheme_name: 方案名称
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
:param type: 查询的类型(决定 measurement)
:param property: 查询的字段名称(field)
@@ -3964,7 +3964,7 @@ def query_scheme_all_record_property(
flux_query = f"""
from(bucket: "{bucket}")
|> range(start: {start_time}, stop: {stop_time})
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
"""
# 执行查询
tables = query_api.query(flux_query)
@@ -4245,8 +4245,8 @@ def export_scheme_simulation_result_to_csv_time(
link_data[key][field] = record.get_value()
link_data[key]["measurement"] = record.get_measurement()
link_data[key]["date"] = record.values.get("date", None)
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
# 构建 Flux 查询语句,查询指定时间范围内的数据
flux_query_node = f"""
from(bucket: "{bucket}")
@@ -4267,8 +4267,8 @@ def export_scheme_simulation_result_to_csv_time(
node_data[key][field] = record.get_value()
node_data[key]["measurement"] = record.get_measurement()
node_data[key]["date"] = record.values.get("date", None)
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
for key in set(link_data.keys()):
row = {"time": key[0], "ID": key[1]}
row.update(link_data.get(key, {}))
@@ -4288,8 +4288,8 @@ def export_scheme_simulation_result_to_csv_time(
"time",
"measurement",
"date",
"scheme_Type",
"scheme_Name",
"scheme_type",
"scheme_name",
"ID",
"flow",
"leakage",
@@ -4311,8 +4311,8 @@ def export_scheme_simulation_result_to_csv_time(
"time",
"measurement",
"date",
"scheme_Type",
"scheme_Name",
"scheme_type",
"scheme_name",
"ID",
"head",
"pressure",
@@ -4330,15 +4330,15 @@ def export_scheme_simulation_result_to_csv_time(
# 2025/02/18
def export_scheme_simulation_result_to_csv_scheme(
scheme_Type: str,
scheme_Name: str,
scheme_type: str,
scheme_name: str,
query_date: str,
bucket: str = "scheme_simulation_result",
) -> None:
"""
导出influxdb中scheme_simulation_result这个bucket的数据到csv中
:param scheme_Type: 查询的方案类型
:param scheme_Name: 查询的方案名
:param scheme_type: 查询的方案类型
:param scheme_name: 查询的方案名
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
:param bucket: 数据存储的 bucket 名称,默认值为 "SCADA_data"
:return:
@@ -4366,7 +4366,7 @@ def export_scheme_simulation_result_to_csv_scheme(
flux_query_link = f"""
from(bucket: "{bucket}")
|> range(start: {start_time}, stop: {stop_time})
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
"""
# 执行查询
link_tables = query_api.query(flux_query_link)
@@ -4382,13 +4382,13 @@ def export_scheme_simulation_result_to_csv_scheme(
link_data[key][field] = record.get_value()
link_data[key]["measurement"] = record.get_measurement()
link_data[key]["date"] = record.values.get("date", None)
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
# 构建 Flux 查询语句,查询指定时间范围内的数据
flux_query_node = f"""
from(bucket: "{bucket}")
|> range(start: {start_time}, stop: {stop_time})
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
"""
# 执行查询
node_tables = query_api.query(flux_query_node)
@@ -4404,8 +4404,8 @@ def export_scheme_simulation_result_to_csv_scheme(
node_data[key][field] = record.get_value()
node_data[key]["measurement"] = record.get_measurement()
node_data[key]["date"] = record.values.get("date", None)
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
for key in set(link_data.keys()):
row = {"time": key[0], "ID": key[1]}
row.update(link_data.get(key, {}))
@@ -4416,10 +4416,10 @@ def export_scheme_simulation_result_to_csv_scheme(
node_rows.append(row)
# 动态生成 CSV 文件名
csv_filename_link = (
f"scheme_simulation_link_result_{scheme_Name}_of_{scheme_Type}.csv"
f"scheme_simulation_link_result_{scheme_name}_of_{scheme_type}.csv"
)
csv_filename_node = (
f"scheme_simulation_node_result_{scheme_Name}_of_{scheme_Type}.csv"
f"scheme_simulation_node_result_{scheme_name}_of_{scheme_type}.csv"
)
# 写入到 CSV 文件
with open(csv_filename_link, mode="w", newline="") as file:
@@ -4429,8 +4429,8 @@ def export_scheme_simulation_result_to_csv_scheme(
"time",
"measurement",
"date",
"scheme_Type",
"scheme_Name",
"scheme_type",
"scheme_name",
"ID",
"flow",
"leakage",
@@ -4452,8 +4452,8 @@ def export_scheme_simulation_result_to_csv_scheme(
"time",
"measurement",
"date",
"scheme_Type",
"scheme_Name",
"scheme_type",
"scheme_name",
"ID",
"head",
"pressure",
@@ -4878,15 +4878,15 @@ if __name__ == "__main__":
# export_scheme_simulation_result_to_csv_time(start_date='2025-02-13', end_date='2025-02-15')
# 示例9:export_scheme_simulation_result_to_csv_scheme
# export_scheme_simulation_result_to_csv_scheme(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
# export_scheme_simulation_result_to_csv_scheme(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
# 示例10:query_scheme_all_record_by_time
# node_records, link_records = query_scheme_all_record_by_time(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_time="2025-02-14T10:30:00+08:00")
# node_records, link_records = query_scheme_all_record_by_time(scheme_type='burst_Analysis', scheme_name='scheme1', query_time="2025-02-14T10:30:00+08:00")
# print("Node 数据:", node_records)
# print("Link 数据:", link_records)
# 示例11:query_scheme_curve_by_ID_property
# curve_result = query_scheme_curve_by_ID_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', ID='ZBBDTZDP000022',
# curve_result = query_scheme_curve_by_ID_property(scheme_type='burst_Analysis', scheme_name='scheme1', ID='ZBBDTZDP000022',
# type='node', property='head')
# print(curve_result)
@@ -4896,7 +4896,7 @@ if __name__ == "__main__":
# print("Link 数据:", link_records)
# 示例13:query_scheme_all_record
# node_records, link_records = query_scheme_all_record(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
# node_records, link_records = query_scheme_all_record(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
# print("Node 数据:", node_records)
# print("Link 数据:", link_records)
@@ -4909,16 +4909,16 @@ if __name__ == "__main__":
# print(result_records)
# 示例16:query_scheme_all_record_by_time_property
# result_records = query_scheme_all_record_by_time_property(scheme_Type='burst_Analysis', scheme_Name='scheme1',
# result_records = query_scheme_all_record_by_time_property(scheme_type='burst_Analysis', scheme_name='scheme1',
# query_time='2025-02-14T10:30:00+08:00', type='node', property='head')
# print(result_records)
# 示例17:query_scheme_all_record_property
# result_records = query_scheme_all_record_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10', type='node', property='head')
# result_records = query_scheme_all_record_property(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10', type='node', property='head')
# print(result_records)
# 示例18:fill_scheme_simulation_result_to_SCADA
# fill_scheme_simulation_result_to_SCADA(scheme_Type='burst_Analysis', scheme_Name='burst0330', query_date='2025-03-30')
# fill_scheme_simulation_result_to_SCADA(scheme_type='burst_Analysis', scheme_name='burst0330', query_date='2025-03-30')
# 示例19:query_SCADA_data_by_device_ID_and_timerange
# result = query_SCADA_data_by_device_ID_and_timerange(query_ids_list=globals.pressure_non_realtime_ids, start_time='2025-04-16T00:00:00+08:00',
@@ -4926,7 +4926,7 @@ if __name__ == "__main__":
# print(result)
# 示例:manually_get_burst_flow
# leakage = manually_get_burst_flow(scheme_Type='burst_Analysis', scheme_Name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
# leakage = manually_get_burst_flow(scheme_type='burst_Analysis', scheme_name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
# print(leakage)
# 示例:upload_cleaned_SCADA_data_to_influxdb
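The rename is mechanical but load-bearing: Flux filters written against the old scheme_Type/scheme_Name tags silently match nothing once points are written with the lowercase tags. A standalone query against the new tag names, where the URL, token, and org are placeholders:

    from influxdb_client import InfluxDBClient

    flux = '''
    from(bucket: "scheme_simulation_result")
      |> range(start: -1d)
      |> filter(fn: (r) => r["scheme_type"] == "burst_Analysis" and r["scheme_name"] == "scheme1")
    '''

    with InfluxDBClient(url="http://localhost:8086", token="dev-token", org="tjwater") as client:
        for table in client.query_api().query(flux):
            for record in table.records:
                print(record.get_time(), record.get_field(), record.get_value())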

View File

@@ -0,0 +1,3 @@
from .database import get_metadata_session, close_metadata_engine
__all__ = ["get_metadata_session", "close_metadata_engine"]

View File

@@ -0,0 +1,27 @@
import logging
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
logger = logging.getLogger(__name__)
engine = create_async_engine(
settings.METADATA_DATABASE_URI,
pool_size=settings.METADATA_DB_POOL_SIZE,
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
pool_pre_ping=True,
)
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
async with SessionLocal() as session:
yield session
async def close_metadata_engine() -> None:
await engine.dispose()
logger.info("Metadata database engine disposed.")

View File

@@ -0,0 +1,115 @@
from datetime import datetime
from uuid import UUID
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class User(Base):
__tablename__ = "users"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
keycloak_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), unique=True, index=True
)
username: Mapped[str] = mapped_column(String(50), unique=True)
email: Mapped[str] = mapped_column(String(100), unique=True)
role: Mapped[str] = mapped_column(String(20), default="user")
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
last_login_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class Project(Base):
__tablename__ = "projects"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
name: Mapped[str] = mapped_column(String(100))
code: Mapped[str] = mapped_column(String(50), unique=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
status: Mapped[str] = mapped_column(String(20), default="active")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
class ProjectDatabase(Base):
__tablename__ = "project_databases"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
db_role: Mapped[str] = mapped_column(String(20))
db_type: Mapped[str] = mapped_column(String(20))
dsn_encrypted: Mapped[str] = mapped_column(Text)
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
class ProjectGeoServerConfig(Base):
__tablename__ = "project_geoserver_configs"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
project_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), unique=True, index=True
)
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
Text, nullable=True
)
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
srid: Mapped[int] = mapped_column(Integer, default=4326)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
class UserProjectMembership(Base):
__tablename__ = "user_project_membership"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), nullable=True, index=True
)
project_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), nullable=True, index=True
)
action: Mapped[str] = mapped_column(String(50))
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
timestamp: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
)
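No migration accompanies these models in this diff; for a scratch environment, a one-off bootstrap along these lines would create the tables (placeholder DSN, and a real deployment would presumably manage the schema with migrations):

    import asyncio
    from sqlalchemy.ext.asyncio import create_async_engine

    async def bootstrap() -> None:
        engine = create_async_engine(
            "postgresql+psycopg://user:pass@localhost:5432/system_hub"  # placeholder
        )
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        await engine.dispose()

    asyncio.run(bootstrap())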

View File

@@ -17,7 +17,14 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = postgresql_info.get_pgconn_string()
target_db_name = db_name or self.db_name
# Get connection string, handling default case where target_db_name might be None
if target_db_name:
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
else:
conn_string = postgresql_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -26,7 +33,7 @@ class Database:
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(f"PostgreSQL connection pool initialized for database: default")
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")
raise

View File

@@ -1,24 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
from app.auth.project_dependencies import get_project_pg_connection
router = APIRouter()
# Connection dependency supporting database selection
# Dynamic per-project PostgreSQL connection dependency
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
@router.get("/scada-info")

View File

@@ -17,7 +17,14 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = timescaledb_info.get_pgconn_string()
target_db_name = db_name or self.db_name
# Get connection string, handling default case where target_db_name might be None
if target_db_name:
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
else:
conn_string = timescaledb_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -27,7 +34,7 @@ class Database:
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"TimescaleDB connection pool initialized for database: default"
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
)
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
@@ -46,7 +53,9 @@ class Database:
def get_pgconn_string(self, db_name=None):
"""Get the TimescaleDB connection string."""
target_db_name = db_name or self.db_name
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
if target_db_name:
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
return timescaledb_info.get_pgconn_string()
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:

View File

@@ -1,40 +1,32 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from app.infra.db.postgresql.database import get_database_instance as get_postgres_database_instance
from app.auth.project_dependencies import (
get_project_pg_connection,
get_project_timescale_connection,
)
router = APIRouter()
# Connection dependency supporting database selection
# Dynamic per-project TimescaleDB connection dependency
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_timescale_connection),
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# PostgreSQL connection dependency function
# Dynamic per-project PostgreSQL connection dependency
async def get_postgres_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
)
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_postgres_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
yield conn
# --- Realtime Endpoints ---
@@ -152,7 +144,7 @@ async def query_realtime_simulation_by_id_time(
results = await RealtimeRepository.query_simulation_result_by_id_time(
conn, id, type, query_time
)
return {"results": results}
return {"result": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -1,220 +1,112 @@
from typing import Optional, List
from datetime import datetime
import json
from app.infra.db.postgresql.database import Database
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
import logging
from typing import Optional, List
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.schemas.audit import AuditLogResponse
from app.infra.db.metadata import models
logger = logging.getLogger(__name__)
class AuditRepository:
"""审计日志数据访问层"""
def __init__(self, db: Database):
self.db = db
"""审计日志数据访问层system_hub"""
def __init__(self, session: AsyncSession):
self.session = session
async def create_log(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
action: str = "",
action: str,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_method: Optional[str] = None,
request_path: Optional[str] = None,
request_data: Optional[dict] = None,
response_status: Optional[int] = None,
error_message: Optional[str] = None
) -> Optional[AuditLogResponse]:
"""
创建审计日志
Args:
参数说明见 AuditLogCreate
Returns:
创建的审计日志对象
"""
query = """
INSERT INTO audit_logs (
user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message
)
VALUES (
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
%(request_data)s, %(response_status)s, %(error_message)s
)
RETURNING id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, {
'user_id': user_id,
'username': username,
'action': action,
'resource_type': resource_type,
'resource_id': resource_id,
'ip_address': ip_address,
'user_agent': user_agent,
'request_method': request_method,
'request_path': request_path,
'request_data': json.dumps(request_data) if request_data else None,
'response_status': response_status,
'error_message': error_message
})
row = await cur.fetchone()
if row:
return AuditLogResponse(**row)
except Exception as e:
logger.error(f"Error creating audit log: {e}")
raise
return None
) -> AuditLogResponse:
log = models.AuditLog(
user_id=user_id,
project_id=project_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
request_method=request_method,
request_path=request_path,
request_data=request_data,
response_status=response_status,
timestamp=datetime.utcnow(),
)
self.session.add(log)
await self.session.commit()
await self.session.refresh(log)
return AuditLogResponse.model_validate(log)
async def get_logs(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
skip: int = 0,
limit: int = 100
limit: int = 100,
) -> List[AuditLogResponse]:
"""
查询审计日志
Args:
user_id: 用户ID过滤
username: 用户名过滤
action: 操作类型过滤
resource_type: 资源类型过滤
start_time: 开始时间
end_time: 结束时间
skip: 跳过记录数
limit: 限制记录数
Returns:
审计日志列表
"""
# 构建动态查询
conditions = []
params = {'skip': skip, 'limit': limit}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT id, user_id, username, action, resource_type, resource_id,
ip_address, user_agent, request_method, request_path,
request_data, response_status, error_message, timestamp
FROM audit_logs
{where_clause}
ORDER BY timestamp DESC
LIMIT %(limit)s OFFSET %(skip)s
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
rows = await cur.fetchall()
return [AuditLogResponse(**row) for row in rows]
except Exception as e:
logger.error(f"Error querying audit logs: {e}")
raise
conditions.append(models.AuditLog.timestamp <= end_time)
stmt = (
select(models.AuditLog)
.where(*conditions)
.order_by(models.AuditLog.timestamp.desc())
.offset(skip)
.limit(limit)
)
result = await self.session.execute(stmt)
return [
AuditLogResponse.model_validate(log)
for log in result.scalars().all()
]
async def get_log_count(
self,
user_id: Optional[int] = None,
username: Optional[str] = None,
user_id: Optional[UUID] = None,
project_id: Optional[UUID] = None,
action: Optional[str] = None,
resource_type: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
end_time: Optional[datetime] = None,
) -> int:
"""
获取审计日志数量
Args:
参数同 get_logs
Returns:
日志总数
"""
conditions = []
params = {}
if user_id is not None:
conditions.append("user_id = %(user_id)s")
params['user_id'] = user_id
if username:
conditions.append("username = %(username)s")
params['username'] = username
conditions.append(models.AuditLog.user_id == user_id)
if project_id is not None:
conditions.append(models.AuditLog.project_id == project_id)
if action:
conditions.append("action = %(action)s")
params['action'] = action
conditions.append(models.AuditLog.action == action)
if resource_type:
conditions.append("resource_type = %(resource_type)s")
params['resource_type'] = resource_type
conditions.append(models.AuditLog.resource_type == resource_type)
if start_time:
conditions.append("timestamp >= %(start_time)s")
params['start_time'] = start_time
conditions.append(models.AuditLog.timestamp >= start_time)
if end_time:
conditions.append("timestamp <= %(end_time)s")
params['end_time'] = end_time
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT COUNT(*) as count
FROM audit_logs
{where_clause}
"""
try:
async with self.db.get_connection() as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
result = await cur.fetchone()
return result['count'] if result else 0
except Exception as e:
logger.error(f"Error counting audit logs: {e}")
return 0
conditions.append(models.AuditLog.timestamp <= end_time)
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
result = await self.session.execute(stmt)
return int(result.scalar() or 0)

View File

@@ -0,0 +1,197 @@
from dataclasses import dataclass
from typing import Optional, List
from uuid import UUID
from cryptography.fernet import InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.encryption import (
get_database_encryptor,
get_encryptor,
is_database_encryption_configured,
is_encryption_configured,
)
from app.infra.db.metadata import models
def _normalize_postgres_dsn(dsn: str) -> str:
if not dsn or "://" not in dsn:
return dsn
scheme, rest = dsn.split("://", 1)
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
return dsn
if "@" not in rest:
return dsn
userinfo, hostinfo = rest.rsplit("@", 1)
if ":" not in userinfo:
return dsn
username, password = userinfo.split(":", 1)
if "@" not in password:
return dsn
password = password.replace("@", "%40")
return f"{scheme}://{username}:{password}@{hostinfo}"
@dataclass(frozen=True)
class ProjectDbRouting:
project_id: UUID
db_role: str
db_type: str
dsn: str
pool_min_size: int
pool_max_size: int
@dataclass(frozen=True)
class ProjectGeoServerInfo:
project_id: UUID
gs_base_url: Optional[str]
gs_admin_user: Optional[str]
gs_admin_password: Optional[str]
gs_datastore_name: str
default_extent: Optional[dict]
srid: int
@dataclass(frozen=True)
class ProjectSummary:
project_id: UUID
name: str
code: str
description: Optional[str]
gs_workspace: str
status: str
project_role: str
class MetadataRepository:
"""元数据访问层system_hub"""
def __init__(self, session: AsyncSession):
self.session = session
async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
result = await self.session.execute(
select(models.User).where(models.User.keycloak_id == keycloak_id)
)
return result.scalar_one_or_none()
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
result = await self.session.execute(
select(models.Project).where(models.Project.id == project_id)
)
return result.scalar_one_or_none()
async def get_membership_role(
self, project_id: UUID, user_id: UUID
) -> Optional[str]:
result = await self.session.execute(
select(models.UserProjectMembership.project_role).where(
models.UserProjectMembership.project_id == project_id,
models.UserProjectMembership.user_id == user_id,
)
)
return result.scalar_one_or_none()
async def get_project_db_routing(
self, project_id: UUID, db_role: str
) -> Optional[ProjectDbRouting]:
result = await self.session.execute(
select(models.ProjectDatabase).where(
models.ProjectDatabase.project_id == project_id,
models.ProjectDatabase.db_role == db_role,
)
)
record = result.scalar_one_or_none()
if not record:
return None
if not is_database_encryption_configured():
raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
encryptor = get_database_encryptor()
try:
dsn = encryptor.decrypt(record.dsn_encrypted)
except InvalidToken:
raise ValueError(
"Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
"or invalid dsn_encrypted value"
)
dsn = _normalize_postgres_dsn(dsn)
return ProjectDbRouting(
project_id=record.project_id,
db_role=record.db_role,
db_type=record.db_type,
dsn=dsn,
pool_min_size=record.pool_min_size,
pool_max_size=record.pool_max_size,
)
async def get_geoserver_config(
self, project_id: UUID
) -> Optional[ProjectGeoServerInfo]:
result = await self.session.execute(
select(models.ProjectGeoServerConfig).where(
models.ProjectGeoServerConfig.project_id == project_id
)
)
record = result.scalar_one_or_none()
if not record:
return None
if record.gs_admin_password_encrypted:
if is_encryption_configured():
encryptor = get_encryptor()
password = encryptor.decrypt(record.gs_admin_password_encrypted)
else:
password = record.gs_admin_password_encrypted
else:
password = None
return ProjectGeoServerInfo(
project_id=record.project_id,
gs_base_url=record.gs_base_url,
gs_admin_user=record.gs_admin_user,
gs_admin_password=password,
gs_datastore_name=record.gs_datastore_name,
default_extent=record.default_extent,
srid=record.srid,
)
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
stmt = (
select(models.Project, models.UserProjectMembership.project_role)
.join(
models.UserProjectMembership,
models.UserProjectMembership.project_id == models.Project.id,
)
.where(models.UserProjectMembership.user_id == user_id)
.order_by(models.Project.name)
)
result = await self.session.execute(stmt)
return [
ProjectSummary(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role=role,
)
for project, role in result.all()
]
async def list_all_projects(self) -> List[ProjectSummary]:
result = await self.session.execute(
select(models.Project).order_by(models.Project.name)
)
return [
ProjectSummary(
project_id=project.id,
name=project.name,
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
status=project.status,
project_role="owner",
)
for project in result.scalars().all()
]
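Putting the new layers together, a sketch (the ids and role name are placeholders) that resolves a project's decrypted DSN from system_hub and feeds it to the connection manager introduced earlier in this diff:

    from uuid import UUID

    from app.infra.db.dynamic_manager import project_connection_manager
    from app.infra.db.metadata.database import SessionLocal

    async def open_project_sessionmaker(project_id: UUID):
        async with SessionLocal() as session:
            routing = await MetadataRepository(session).get_project_db_routing(
                project_id, db_role="business"  # placeholder role name
            )
        if routing is None:
            raise ValueError("no database routing configured for this project")
        return await project_connection_manager.get_pg_sessionmaker(
            project_id=routing.project_id,
            db_role=routing.db_role,
            connection_url=routing.dsn,
            pool_min_size=routing.pool_min_size,
            pool_max_size=routing.pool_max_size,
        )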

View File

@@ -9,6 +9,8 @@ import app.services.project_info as project_info
from app.api.v1.router import api_router
from app.infra.db.timescaledb.database import db as tsdb
from app.infra.db.postgresql.database import db as pgdb
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import close_metadata_engine
from app.services.tjnetwork import open_project
from app.core.config import settings
@@ -33,7 +35,7 @@ async def lifespan(app: FastAPI):
await tsdb.open()
await pgdb.open()
# Store database instances in app.state for use by dependencies
app.state.db = pgdb
logger.info("Database connection pool initialized and stored in app.state")
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
# Clean up resources
await tsdb.close()
await pgdb.close()
await project_connection_manager.close_all()
await close_metadata_engine()
logger.info("Database connections closed")
@@ -58,22 +62,25 @@ app = FastAPI(
redoc_url="/redoc",
)
# Configure CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],  # allow all origins
allow_credentials=True,  # allow credentials (cookies, HTTP headers, etc.)
allow_methods=["*"],  # allow all HTTP methods
allow_headers=["*"],  # allow all HTTP headers
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add audit middleware (optional; records key operations)
# To enable audit logging, uncomment the line below
app.add_middleware(AuditMiddleware)
# Include Routers
app.include_router(api_router, prefix="/api/v1")
# Legacy routers without version prefix
app.include_router(api_router)
# Configure middleware
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Add audit middleware (optional; records key operations)
app.add_middleware(AuditMiddleware)
# Configure CORS middleware
# Make sure this stays the last app.add_middleware call
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:3000",  # must be listed explicitly
"http://127.0.0.1:3000",  # recommended alongside localhost
],
allow_credentials=True,  # credentials are enabled, so origins must be explicit
allow_methods=["*"],
allow_headers=["*"],
)
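The reordering matters because Starlette applies middleware LIFO: the last add_middleware call becomes the outermost layer, so CORS now wraps the audit and gzip middleware and answers preflight requests first. A tiny demonstration of the ordering:

    from fastapi import FastAPI
    from starlette.middleware.base import BaseHTTPMiddleware

    demo_app = FastAPI()

    class Tag(BaseHTTPMiddleware):
        def __init__(self, app, label: str) -> None:
            super().__init__(app)
            self.label = label

        async def dispatch(self, request, call_next):
            print("enter", self.label)  # outermost layer prints first
            response = await call_next(request)
            print("exit", self.label)
            return response

    demo_app.add_middleware(Tag, label="inner")  # added first, runs closest to the app
    demo_app.add_middleware(Tag, label="outer")  # added last, sees the request first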

View File

@@ -1,22 +1,4 @@
import os
import yaml
# 获取当前项目根目录的路径
_current_file = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(os.path.dirname(_current_file)))
# 尝试读取 .yml 或 .yaml 文件
config_file = os.path.join(project_root, "configs", "project_info.yml")
if not os.path.exists(config_file):
config_file = os.path.join(project_root, "configs", "project_info.yaml")
if not os.path.exists(config_file):
raise FileNotFoundError(f"未找到项目配置文件 (project_info.yaml 或 .yml): {os.path.dirname(config_file)}")
with open(config_file, 'r', encoding='utf-8') as f:
_config = yaml.safe_load(f)
if not _config or 'name' not in _config:
raise KeyError(f"项目配置文件中缺少 'name' 配置: {config_file}")
name = _config['name']
# Read from the NETWORK_NAME environment variable
name = os.getenv("NETWORK_NAME")

View File

@@ -1190,12 +1190,12 @@ def run_simulation(
if modify_valve_opening[valve_name] == 0:
valve_status["status"] = "CLOSED"
valve_status["setting"] = 0
if modify_valve_opening[valve_name] < 1:
elif modify_valve_opening[valve_name] < 1:
valve_status["status"] = "OPEN"
valve_status["setting"] = 0.1036 * pow(
modify_valve_opening[valve_name], -3.105
)
if modify_valve_opening[valve_name] == 1:
elif modify_valve_opening[valve_name] == 1:
valve_status["status"] = "OPEN"
valve_status["setting"] = 0
cs = ChangeSet()
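# Numeric check of the valve curve above: with the elif chain, a partially
# open valve (0 < x < 1) takes exactly one branch and maps to a throttle
# setting of 0.1036 * x ** -3.105; opening 1.0 maps to setting 0 (fully
# open) and opening 0 closes the valve.
for opening in (0.25, 0.5, 0.75):
    print(opening, round(0.1036 * opening ** -3.105, 3))
# -> 0.25 7.669, 0.5 0.891, 0.75 0.253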
@@ -1235,7 +1235,7 @@ def run_simulation(
starttime = time.time()
if simulation_type.upper() == "REALTIME":
TimescaleInternalStorage.store_realtime_simulation(
node_result, link_result, modify_pattern_start_time
node_result, link_result, modify_pattern_start_time, db_name=name
)
elif simulation_type.upper() == "EXTENDED":
TimescaleInternalStorage.store_scheme_simulation(
@@ -1245,11 +1245,12 @@ def run_simulation(
link_result,
modify_pattern_start_time,
num_periods_result,
db_name=name,
)
endtime = time.time()
logging.info("store time: %f", endtime - starttime)
# 暂不需要再次存储 SCADA 模拟信息
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
# if simulation_type.upper() == "REALTIME":
# influxdb_api.store_realtime_simulation_result_to_influxdb(
@@ -1261,11 +1262,11 @@ def run_simulation(
# link_result,
# modify_pattern_start_time,
# num_periods_result,
# scheme_Type,
# scheme_Name,
# scheme_type,
# scheme_name,
# )
# 暂不需要再次存储 SCADA 模拟信息
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
print("after store result")
@@ -1345,7 +1346,7 @@ if __name__ == "__main__":
# run_simulation(name='bb', simulation_type="realtime", modify_pattern_start_time='2025-02-25T23:45:00+08:00')
# Simulation example 2
# run_simulation(name='bb', simulation_type="extended", modify_pattern_start_time='2025-03-10T12:00:00+08:00',
# modify_total_duration=1800, scheme_Type="burst_Analysis", scheme_Name="scheme1")
# modify_total_duration=1800, scheme_type="burst_Analysis", scheme_name="scheme1")
# Query example 1: query_SCADA_ID_corresponding_info
# result = query_SCADA_ID_corresponding_info(name='bb', SCADA_ID='P10755')

View File

@@ -4,6 +4,8 @@ from app.algorithms.valve_isolation import valve_isolation_analysis
def analyze_valve_isolation(
network: str, accident_element: str | list[str]
network: str,
accident_element: str | list[str],
disabled_valves: list[str] | None = None,
) -> dict[str, Any]:
return valve_isolation_analysis(network, accident_element)
return valve_isolation_analysis(network, accident_element, disabled_valves)
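
A hedged usage sketch of the extended wrapper; the network and element IDs below are placeholders reused from examples elsewhere in this changeset, and the module path is assumed:

from app.api.valve_isolation import analyze_valve_isolation  # import path assumed

result = analyze_valve_isolation(
    network="szh",
    accident_element="GSD230112144241FA18292A84CB",  # burst pipe ID (placeholder)
    disabled_valves=["V101", "V102"],  # valves that cannot be operated (hypothetical IDs)
)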

View File

@@ -1 +0,0 @@
name: szh

View File

@@ -1,13 +0,0 @@
FROM python:3.12-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY app ./app
COPY resources ./resources
ENV PYTHONPATH=/app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -95,7 +95,6 @@ prometheus_client==0.24.1
psycopg==3.2.5
psycopg-binary==3.2.5
psycopg-pool==3.3.0
psycopg2==2.9.10
PuLP==3.1.1
py-key-value-aio==0.3.0
py-key-value-shared==0.3.0
@@ -157,8 +156,6 @@ starlette==0.50.0
threadpoolctl==3.6.0
tqdm==4.67.1
typer==0.21.1
typing-inspection==0.4.0
typing_extensions==4.12.2
tzdata==2025.2
urllib3==2.2.3
uvicorn==0.34.0

33
scripts/encrypt_string.py Normal file
View File

@@ -0,0 +1,33 @@
import os
import sys
# Add the project root to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from app.core.encryption import get_database_encryptor
def main() -> int:
plaintext = None
if not sys.stdin.isatty():
stdin_text = sys.stdin.read()
if stdin_text != "":
plaintext = stdin_text.rstrip("\r\n")
if plaintext is None and len(sys.argv) >= 2:
plaintext = sys.argv[1]
if plaintext is None:
try:
plaintext = input("Enter the text to encrypt: ")
except EOFError:
plaintext = ""
if not plaintext.strip():
print("Error: plaintext string cannot be empty.", file=sys.stderr)
return 1
token = get_database_encryptor().encrypt(plaintext)
print(token)
return 0
if __name__ == "__main__":
raise SystemExit(main())
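
The script resolves its plaintext from stdin first, then argv, then an interactive prompt, so a pipeline such as echo "secret" | python scripts/encrypt_string.py works unattended. Judging by the InvalidToken handling in the unit tests later in this changeset, the encryptor is Fernet-based; a minimal standalone sketch of the same roundtrip, independent of the app module (an assumption, not the app's actual wiring):

from cryptography.fernet import Fernet

key = Fernet.generate_key()  # in the app this would come from configuration
f = Fernet(key)
token = f.encrypt(b"postgresql://user:password@localhost:5432/db").decode()
print(token)  # opaque token suitable for storing as an encrypted DSN
print(f.decrypt(token.encode()).decode())  # roundtrips back to the plaintext DSN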

View File

@@ -3233,7 +3233,7 @@ async def fastapi_query_all_scheme_all_records(
return loaded_dict
results = influxdb_api.query_scheme_all_record(
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
@@ -3257,7 +3257,7 @@ async def fastapi_query_all_scheme_all_records_property(
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
else:
all_results = influxdb_api.query_scheme_all_record(
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(all_results, default=encode_datetime)
redis_client.set(cache_key, packed)
@@ -3585,7 +3585,7 @@ class BurstAnalysis(BaseModel):
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
modify_variable_pump_pattern: Optional[dict[str, list]] = None
modify_valve_opening: Optional[dict[str, float]] = None
scheme_Name: Optional[str] = None
scheme_name: Optional[str] = None
@app.post("/burst_analysis/")
@@ -3608,7 +3608,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
modify_fixed_pump_pattern=item["modify_fixed_pump_pattern"],
modify_variable_pump_pattern=item["modify_variable_pump_pattern"],
modify_valve_opening=item["modify_valve_opening"],
scheme_Name=item["scheme_Name"],
scheme_name=item["scheme_name"],
)
# os.rename(filename2, filename)
@@ -3616,7 +3616,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
# Convert the timestamp to a date, then cache this computed result
# Cache key: burst_analysis_<name>_<modify_pattern_start_time>
global redis_client
schemename = data.scheme_Name
schemename = data.scheme_name
print(data.modify_pattern_start_time)
@@ -3627,7 +3627,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
cache_key = f"queryallschemeallrecords_burst_Analysis_{schemename}_{querydate}"
data = redis_client.get(cache_key)
if not data:
results = influxdb_api.query_scheme_all_record("burst_Analysis", scheme_Name=schemename, query_date=querydate)
results = influxdb_api.query_scheme_all_record("burst_Analysis", scheme_name=schemename, query_date=querydate)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
"""
@@ -3712,7 +3712,7 @@ async def fastapi_contaminant_simulation(
concentration: float,
duration: int,
pattern: str = None,
scheme_Name: str = None,
scheme_name: str = None,
) -> str:
filename = "c:/lock.simulation"
filename2 = "c:/lock.simulation2"

View File

@@ -68,7 +68,7 @@ def burst_analysis(
:param modify_fixed_pump_pattern: dict of pump patterns; each str key is a fixed-speed pump id, each list is the modified pattern
:param modify_variable_pump_pattern: dict of pump patterns; each str key is a variable-speed pump id, each list is the modified pattern
:param modify_valve_opening: dict of valve openings; each str key is a valve id, each float is the modified opening
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
scheme_detail: dict = {
@@ -294,7 +294,7 @@ def flushing_analysis(
modify_valve_opening: dict[str, float] = None,
drainage_node_ID: str = None,
flushing_flow: float = 0,
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Pipe flushing simulation
@@ -304,7 +304,7 @@ def flushing_analysis(
:param modify_valve_opening: dict of valve openings; each str key is a valve id, each float is the modified opening
:param drainage_node_ID: node ID of the flushing discharge outlet
:param flushing_flow: flushing flow rate; the input parameter unit is m3/h
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
print(
@@ -396,8 +396,8 @@ def flushing_analysis(
modify_pattern_start_time=modify_pattern_start_time,
modify_total_duration=modify_total_duration,
modify_valve_opening=modify_valve_opening,
scheme_Type="flushing_Analysis",
scheme_Name=scheme_Name,
scheme_type="flushing_Analysis",
scheme_name=scheme_name,
)
# step 4. restore the base model
if is_project_open(new_name):
@@ -417,7 +417,7 @@ def contaminant_simulation(
source: str = None,  # contamination source node ID
concentration: float = None,  # source concentration, in mg/L
source_pattern: str = None,  # name of the source's time-variation pattern
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Contaminant simulation
@@ -428,7 +428,7 @@ def contaminant_simulation(
:param concentration: concentration at the source location, in mg/L; the default contaminant-simulation setting is concentration, which should be changed to Set point booster
:param source_pattern: time-variation pattern of the contamination source; if omitted, a constant concentration is simulated for a length equal to duration;
if provided, the format is a coefficient list such as {1.0, 0.5, 1.1}, with the pattern step equal to the model's hydraulic time step
:param scheme_Name: scheme name
:param scheme_name: scheme name
:return:
"""
print(
@@ -533,8 +533,8 @@ def contaminant_simulation(
simulation_type="extended",
modify_pattern_start_time=modify_pattern_start_time,
modify_total_duration=modify_total_duration,
scheme_Type="contaminant_Analysis",
scheme_Name=scheme_Name,
scheme_type="contaminant_Analysis",
scheme_name=scheme_name,
)
# for i in range(1,operation_step):
@@ -630,7 +630,7 @@ def pressure_regulation(
modify_tank_initial_level: dict[str, float] = None,
modify_fixed_pump_pattern: dict[str, list] = None,
modify_variable_pump_pattern: dict[str, list] = None,
scheme_Name: str = None,
scheme_name: str = None,
) -> None:
"""
Zone pressure-regulation simulation, used to model how switching pumps on or off affects zone pressure over the next 15 minutes
@@ -640,7 +640,7 @@ def pressure_regulation(
:param modify_tank_initial_level: dict of tanks; each str key is a tank id, each float is the modified initial_level
:param modify_fixed_pump_pattern: dict of pump patterns; each str key is a fixed-speed pump id, each list is the modified pattern
:param modify_variable_pump_pattern: dict of pump patterns; each str key is a variable-speed pump id, each list is the modified pattern
:param scheme_Name: simulation scheme name
:param scheme_name: simulation scheme name
:return:
"""
print(
@@ -692,8 +692,8 @@ def pressure_regulation(
modify_tank_initial_level=modify_tank_initial_level,
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
modify_variable_pump_pattern=modify_variable_pump_pattern,
scheme_Type="pressure_regulation",
scheme_Name=scheme_Name,
scheme_type="pressure_regulation",
scheme_name=scheme_name,
)
if is_project_open(new_name):
close_project(new_name)
@@ -1536,7 +1536,7 @@ if __name__ == "__main__":
# Example 1: burst_analysis
# burst_analysis(name='bb', modify_pattern_start_time='2025-04-17T00:00:00+08:00',
# burst_ID='GSD230112144241FA18292A84CB', burst_size=400, modify_total_duration=1800, scheme_Name='GSD230112144241FA18292A84CB_400')
# burst_ID='GSD230112144241FA18292A84CB', burst_size=400, modify_total_duration=1800, scheme_name='GSD230112144241FA18292A84CB_400')
# Example: create_user
# create_user(name=project_info.name, username='tjwater dev', password='123456')

View File

@@ -16,6 +16,6 @@ if __name__ == "__main__":
"app.main:app",
host="0.0.0.0",
port=8000,
workers=2,  # the number of worker processes can be set here
workers=4,  # the number of worker processes can be set here
loop="asyncio",
)
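
Rather than hard-coding the worker count, it can be derived from the host; a hedged sketch using the common two-per-core heuristic (WEB_CONCURRENCY is a widespread convention, not something this repo necessarily reads):

import os

def default_workers() -> int:
    # common heuristic: two workers per CPU core plus one, overridable via env
    env = os.getenv("WEB_CONCURRENCY")
    return int(env) if env else (os.cpu_count() or 1) * 2 + 1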

View File

@@ -0,0 +1,119 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from cryptography.fernet import InvalidToken
from app.infra.repositories.metadata_repository import MetadataRepository
class _DummyResult:
def __init__(self, record):
self._record = record
def scalar_one_or_none(self):
return self._record
class _DummyEncryptor:
def __init__(self, decrypted=None, raise_invalid_token=False):
self._decrypted = decrypted
self._raise_invalid_token = raise_invalid_token
self.encrypted_values = []
def decrypt(self, _value):
if self._raise_invalid_token:
raise InvalidToken()
return self._decrypted
def _build_record(dsn_encrypted: str):
return SimpleNamespace(
project_id=uuid4(),
db_role="biz_data",
db_type="postgresql",
dsn_encrypted=dsn_encrypted,
pool_min_size=1,
pool_max_size=5,
)
def test_invalid_token_with_plaintext_dsn_value_raises_clear_error(monkeypatch):
record = _build_record("postgresql://user:p@ss@localhost:5432/db")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
encryptor = _DummyEncryptor(raise_invalid_token=True)
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: encryptor,
)
with pytest.raises(
ValueError,
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
):
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
session.commit.assert_not_awaited()
def test_invalid_token_with_non_dsn_value_raises_clear_error(monkeypatch):
record = _build_record("gAAAAABinvalidciphertext")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: _DummyEncryptor(raise_invalid_token=True),
)
with pytest.raises(
ValueError,
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
):
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
session.commit.assert_not_awaited()
def test_encrypted_dsn_decrypts_without_migration(monkeypatch):
record = _build_record("encrypted-value")
session = SimpleNamespace(
execute=None,
commit=None,
)
session.execute = AsyncMock(return_value=_DummyResult(record))
session.commit = AsyncMock()
repo = MetadataRepository(session)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
lambda: True,
)
monkeypatch.setattr(
"app.infra.repositories.metadata_repository.get_database_encryptor",
lambda: _DummyEncryptor(decrypted="postgresql://u:p@ss@host/db"),
)
routing = asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
assert routing.dsn == "postgresql://u:p%40ss@host/db"
session.commit.assert_not_awaited()
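
The final assertion documents the fix for passwords containing "@": the repository is expected to percent-encode the password before the DSN reaches the driver, since a raw "@" makes URL parsing split userinfo from host at the wrong place ("p@ss" becomes "p%40ss"). A minimal sketch of that normalization, assuming a standard postgresql:// URL (the helper name is hypothetical):

from urllib.parse import quote

def build_dsn(user: str, password: str, host: str, db: str) -> str:
    # percent-encode the password so characters like "@" cannot be
    # mistaken for the userinfo/host separator
    return f"postgresql://{user}:{quote(password, safe='')}@{host}/{db}"

print(build_dsn("u", "p@ss", "host", "db"))  # postgresql://u:p%40ss@host/db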