Compare commits
26 Commits
e893c7db5f
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
| 80b6970970 | |||
| 364a8c8ec2 | |||
| 52ccb8abf1 | |||
| 0bc4058f23 | |||
| 0d3e6ca4fa | |||
| 6fc3aa5209 | |||
| 1b1b0a3697 | |||
| 2826999ddc | |||
| efc05f7278 | |||
| 29209f5c63 | |||
| 020432ad0e | |||
| 780a48d927 | |||
| ff2011ae24 | |||
| f5069a5606 | |||
| eb45e4aaa5 | |||
| a472639b8a | |||
| a0987105dc | |||
| a41be9c362 | |||
| 63b31b46b9 | |||
| e4f864a28c | |||
| dc38313cdc | |||
| f19962510a | |||
| 6434cae21c | |||
| a85ff8e215 | |||
| 2794114000 | |||
| 4c208abe55 |
48
.env.example
48
.env.example
@@ -1,6 +1,6 @@
|
||||
# TJWater Server 环境变量配置模板
|
||||
# 复制此文件为 .env 并填写实际值
|
||||
|
||||
NETWORK_NAME="szh"
|
||||
# ============================================
|
||||
# 安全配置 (必填)
|
||||
# ============================================
|
||||
@@ -12,24 +12,44 @@ SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
|
||||
# 数据加密密钥 - 用于敏感数据加密
|
||||
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
ENCRYPTION_KEY=
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (PostgreSQL)
|
||||
# ============================================
|
||||
DB_NAME=tjwater
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=password
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="localhost"
|
||||
DB_PORT="5432"
|
||||
DB_USER="postgres"
|
||||
DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (TimescaleDB)
|
||||
# ============================================
|
||||
TIMESCALEDB_DB_NAME=szh
|
||||
TIMESCALEDB_DB_HOST=localhost
|
||||
TIMESCALEDB_DB_PORT=5433
|
||||
TIMESCALEDB_DB_USER=tjwater
|
||||
TIMESCALEDB_DB_PASSWORD=Tjwater@123456
|
||||
TIMESCALEDB_DB_NAME="szh"
|
||||
TIMESCALEDB_DB_HOST="localhost"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
# ============================================
|
||||
# 元数据数据库配置 (Metadata DB)
|
||||
# ============================================
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="localhost"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 项目连接缓存与连接池配置
|
||||
# ============================================
|
||||
PROJECT_PG_CACHE_SIZE=50
|
||||
PROJECT_TS_CACHE_SIZE=50
|
||||
PROJECT_PG_POOL_SIZE=5
|
||||
PROJECT_PG_MAX_OVERFLOW=10
|
||||
PROJECT_TS_POOL_MIN_SIZE=1
|
||||
PROJECT_TS_POOL_MAX_SIZE=10
|
||||
|
||||
# ============================================
|
||||
# InfluxDB 配置 (时序数据)
|
||||
@@ -46,6 +66,12 @@ TIMESCALEDB_DB_PASSWORD=Tjwater@123456
|
||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
# ALGORITHM=HS256
|
||||
|
||||
# ============================================
|
||||
# Keycloak JWT (可选)
|
||||
# ============================================
|
||||
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
# KEYCLOAK_ALGORITHM=RS256
|
||||
|
||||
# ============================================
|
||||
# 其他配置
|
||||
# ============================================
|
||||
|
||||
23
.env.local
Normal file
23
.env.local
Normal file
@@ -0,0 +1,23 @@
|
||||
NETWORK_NAME="tjwater"
|
||||
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
|
||||
KEYCLOAK_ALGORITHM="RS256"
|
||||
KEYCLOAK_AUDIENCE="account"
|
||||
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="192.168.1.114"
|
||||
DB_PORT="5432"
|
||||
DB_USER="tjwater"
|
||||
DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
TIMESCALEDB_DB_NAME="tjwater"
|
||||
TIMESCALEDB_DB_HOST="192.168.1.114"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="192.168.1.114"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="Tjwater@123456"
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
198
.github/copilot-instructions.md
vendored
Normal file
198
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
# TJWater Server - Copilot Instructions
|
||||
|
||||
This is a FastAPI-based water network management system (供水管网智能管理系统) that provides hydraulic simulation, SCADA data integration, network element management, and risk analysis capabilities.
|
||||
|
||||
## Running the Server
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Start the server (default: http://0.0.0.0:8000 with 2 workers)
|
||||
python scripts/run_server.py
|
||||
|
||||
# Note: On Windows, the script automatically sets WindowsSelectorEventLoopPolicy
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run a specific test file with verbose output
|
||||
pytest tests/unit/test_pipeline_health_analyzer.py -v
|
||||
|
||||
# Run from conftest helper
|
||||
python tests/conftest.py
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Native Modules** (`app/native/`): Platform-specific compiled extensions (`.so` for Linux, `.pyd` for Windows) providing performance-critical functionality including:
|
||||
- SCADA device integration
|
||||
- Water distribution analysis (WDA)
|
||||
- Pipe risk probability calculations
|
||||
- Wrapped through `app.services.tjnetwork` interface
|
||||
|
||||
2. **Services Layer** (`app/services/`):
|
||||
- `tjnetwork.py`: Main network API wrapper around native modules
|
||||
- `simulation.py`: Hydraulic simulation orchestration (EPANET integration)
|
||||
- `project_info.py`: Project configuration management
|
||||
- `epanet/`: EPANET hydraulic engine integration
|
||||
|
||||
3. **API Layer** (`app/api/v1/`):
|
||||
- **Network Elements**: Separate endpoint modules for junctions, reservoirs, tanks, pipes, pumps, valves
|
||||
- **Components**: Curves, patterns, controls, options, quality, visuals
|
||||
- **Network Features**: Tags, demands, geometry, regions/DMAs
|
||||
- **Core Services**: Auth, project, simulation, SCADA, data query, snapshots
|
||||
|
||||
4. **Database Infrastructure** (`app/infra/db/`):
|
||||
- **PostgreSQL**: Primary relational database (users, audit logs, project metadata)
|
||||
- **TimescaleDB**: Time-series extension for historical data
|
||||
- **InfluxDB**: Optional time-series database for high-frequency SCADA data
|
||||
- Connection pools initialized in `main.py` lifespan context
|
||||
- Database instance stored in `app.state.db` for dependency injection
|
||||
|
||||
5. **Domain Layer** (`app/domain/`):
|
||||
- `models/`: Enums and domain objects (e.g., `UserRole`)
|
||||
- `schemas/`: Pydantic models for request/response validation
|
||||
|
||||
6. **Algorithms** (`app/algorithms/`):
|
||||
- `api_ex/`: Analysis algorithms (k-means sensor placement, sensitivity analysis, pipeline health)
|
||||
- `data_cleaning.py`: Data preprocessing utilities
|
||||
- `simulations.py`: Simulation helpers
|
||||
|
||||
### Security & Authentication
|
||||
|
||||
- **Authentication**: JWT-based with access tokens (30 min) and refresh tokens (7 days)
|
||||
- **Authorization**: Role-based access control (RBAC) with 4 roles:
|
||||
- `VIEWER`: Read-only access
|
||||
- `USER`: Read-write access
|
||||
- `OPERATOR`: Modify data
|
||||
- `ADMIN`: Full permissions
|
||||
- **Audit Logging**: `AuditMiddleware` automatically logs POST/PUT/DELETE requests
|
||||
- **Encryption**: Fernet symmetric encryption for sensitive data (`app.core.encryption`)
|
||||
|
||||
Default admin accounts:
|
||||
- `admin` / `admin123`
|
||||
- `tjwater` / `tjwater@123`
|
||||
|
||||
### Key Files
|
||||
|
||||
- `app/main.py`: FastAPI app initialization, lifespan (DB pools), CORS, middleware, router mounting
|
||||
- `app/api/v1/router.py`: Central router aggregating all endpoint modules
|
||||
- `app/core/config.py`: Settings management using `pydantic-settings`
|
||||
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
||||
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
||||
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
||||
|
||||
## Important Conventions
|
||||
|
||||
### Database Connections
|
||||
|
||||
- Database instances are initialized in `main.py` lifespan and stored in `app.state.db`
|
||||
- Access via dependency injection:
|
||||
```python
|
||||
from app.auth.dependencies import get_db
|
||||
|
||||
async def endpoint(db = Depends(get_db)):
|
||||
# Use db connection
|
||||
```
|
||||
|
||||
### Authentication in Endpoints
|
||||
|
||||
Use dependency injection for auth requirements:
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# Require any authenticated user
|
||||
@router.get("/data")
|
||||
async def get_data(current_user = Depends(get_current_active_user)):
|
||||
return data
|
||||
|
||||
# Require specific role (USER or higher)
|
||||
@router.post("/data")
|
||||
async def create_data(current_user = Depends(require_role(UserRole.USER))):
|
||||
return result
|
||||
|
||||
# Admin-only access
|
||||
@router.delete("/data/{id}")
|
||||
async def delete_data(id: int, current_user = Depends(get_current_admin)):
|
||||
return result
|
||||
```
|
||||
|
||||
### API Routing Structure
|
||||
|
||||
- All v1 APIs are mounted under `/api/v1` prefix via `api_router`
|
||||
- Legacy routes without version prefix are also mounted for backward compatibility
|
||||
- Group related endpoints in separate router modules under `app/api/v1/endpoints/`
|
||||
- Use descriptive tags in `router.py` for OpenAPI documentation grouping
|
||||
|
||||
### Native Module Integration
|
||||
|
||||
- Native modules are pre-compiled for specific platforms
|
||||
- Always import through `app.native.api` or `app.services.tjnetwork`
|
||||
- The `tjnetwork` service wraps native APIs with constants like:
|
||||
- Element types: `JUNCTION`, `RESERVOIR`, `TANK`, `PIPE`, `PUMP`, `VALVE`
|
||||
- Operations: `API_ADD`, `API_UPDATE`, `API_DELETE`
|
||||
- `ChangeSet` for batch operations
|
||||
|
||||
### Project Initialization
|
||||
|
||||
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
||||
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
||||
|
||||
### Audit Logging
|
||||
|
||||
Manual audit logging for critical operations:
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="resource_name",
|
||||
resource_id=str(resource_id),
|
||||
ip_address=request.client.host,
|
||||
request_data=data
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
- Copy `.env.example` to `.env` before first run
|
||||
- Required environment variables:
|
||||
- `SECRET_KEY`: JWT signing (generate with `openssl rand -hex 32`)
|
||||
- `ENCRYPTION_KEY`: Data encryption (generate with Fernet)
|
||||
- Database credentials for PostgreSQL, TimescaleDB, and optionally InfluxDB
|
||||
|
||||
### Database Migrations
|
||||
|
||||
SQL migration scripts are in `migrations/`:
|
||||
- `001_create_users_table.sql`: User authentication tables
|
||||
- `002_create_audit_logs_table.sql`: Audit logging tables
|
||||
|
||||
Apply with:
|
||||
```bash
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI schema: http://localhost:8000/openapi.json
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `SECURITY_README.md`: Comprehensive security feature documentation
|
||||
- `DEPLOYMENT.md`: Integration guide for security features
|
||||
- `readme.md`: Project overview and directory structure (in Chinese)
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,3 +6,4 @@ build/
|
||||
.env
|
||||
*.dump
|
||||
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
||||
.vscode/
|
||||
|
||||
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装 Python 3.12 和 pymetis (通过 conda-forge 避免编译问题)
|
||||
RUN conda install -y -c conda-forge python=3.12 pymetis && \
|
||||
conda clean -afy
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 将代码放入子目录 'app',将数据放入子目录 'db_inp'
|
||||
# 这样临时文件默认会生成在 /app 下,而代码在 /app/app 下,实现了分离
|
||||
COPY app ./app
|
||||
COPY db_inp ./db_inp
|
||||
COPY temp ./temp
|
||||
COPY .env .
|
||||
|
||||
# 设置 PYTHONPATH 以便 uvicorn 找到 app 模块
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
@@ -103,6 +103,7 @@ def burst_analysis(
|
||||
if isinstance(burst_ID, list):
|
||||
if (burst_size is not None) and (type(burst_size) is not list):
|
||||
return json.dumps("Type mismatch.")
|
||||
# 转化为列表形式
|
||||
elif isinstance(burst_ID, str):
|
||||
burst_ID = [burst_ID]
|
||||
if burst_size is not None:
|
||||
@@ -199,7 +200,7 @@ def valve_close_analysis(
|
||||
modify_pattern_start_time: str,
|
||||
modify_total_duration: int = 900,
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
关阀模拟
|
||||
@@ -207,7 +208,7 @@ def valve_close_analysis(
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
@@ -261,8 +262,8 @@ def valve_close_analysis(
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_Type="valve_close_Analysis",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="valve_close_Analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 3. restore the base model
|
||||
# for valve in valves:
|
||||
@@ -284,7 +285,7 @@ def flushing_analysis(
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
drainage_node_ID: str = None,
|
||||
flushing_flow: float = 0,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
管道冲洗模拟
|
||||
@@ -294,9 +295,15 @@ def flushing_analysis(
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param drainage_node_ID: 冲洗排放口所在节点ID
|
||||
:param flushing_flow: 冲洗水量,传入参数单位为m3/h
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
"duration": modify_total_duration,
|
||||
"valve_opening": modify_valve_opening,
|
||||
"drainage_node_ID": drainage_node_ID,
|
||||
"flushing_flow": flushing_flow,
|
||||
}
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
@@ -338,18 +345,42 @@ def flushing_analysis(
|
||||
# status['setting'] = 0.1036 * pow(valve_k, -3.105)
|
||||
# cs.append(status)
|
||||
# set_status(new_name,cs)
|
||||
units = get_option(new_name)
|
||||
options = get_option(new_name)
|
||||
units = options["UNITS"]
|
||||
# step 2. set the emitter coefficient of drainage node or add flush flow to the drainage node
|
||||
# 新建 pattern
|
||||
time_option = get_time(new_name)
|
||||
hydraulic_step = time_option["HYDRAULIC TIMESTEP"]
|
||||
secs = from_clock_to_seconds_2(hydraulic_step)
|
||||
cs_pattern = ChangeSet()
|
||||
pt = {}
|
||||
factors = []
|
||||
tmp_duration = modify_total_duration
|
||||
while tmp_duration > 0:
|
||||
factors.append(1.0)
|
||||
tmp_duration = tmp_duration - secs
|
||||
pt["id"] = "flushing_pt"
|
||||
pt["factors"] = factors
|
||||
cs_pattern.append(pt)
|
||||
add_pattern(new_name, cs_pattern)
|
||||
# 为 emitter_demand 添加新的 pattern
|
||||
emitter_demand = get_demand(new_name, drainage_node_ID)
|
||||
cs = ChangeSet()
|
||||
if flushing_flow > 0:
|
||||
for r in emitter_demand["demands"]:
|
||||
if units == "LPS":
|
||||
r["demand"] += flushing_flow / 3.6
|
||||
elif units == "CMH":
|
||||
r["demand"] += flushing_flow
|
||||
cs.append(emitter_demand)
|
||||
set_demand(new_name, cs)
|
||||
if units == "LPS":
|
||||
emitter_demand["demands"].append(
|
||||
{
|
||||
"demand": flushing_flow / 3.6,
|
||||
"pattern": "flushing_pt",
|
||||
"category": None,
|
||||
}
|
||||
)
|
||||
elif units == "CMH":
|
||||
emitter_demand["demands"].append(
|
||||
{"demand": flushing_flow, "pattern": "flushing_pt", "category": None}
|
||||
)
|
||||
cs.append(emitter_demand)
|
||||
set_demand(new_name, cs)
|
||||
else:
|
||||
pipes = get_node_links(new_name, drainage_node_ID)
|
||||
flush_diameter = 50
|
||||
@@ -386,14 +417,23 @@ def flushing_analysis(
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_Type="flushing_Analysis",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="flushing_analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 4. restore the base model
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# return result
|
||||
# 存储方案信息到 PG 数据库
|
||||
store_scheme_info(
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
scheme_type="flushing_analysis",
|
||||
username="admin",
|
||||
scheme_start_time=modify_pattern_start_time,
|
||||
scheme_detail=scheme_detail,
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
@@ -415,10 +455,10 @@ def contaminant_simulation(
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param source: 污染源所在的节点ID
|
||||
:param concentration: 污染源位置处的浓度,单位mg/L,即默认的污染模拟setting为concentration(应改为 Set point booster)
|
||||
:param concentration: 污染源位置处的浓度,单位mg/L。默认的污染模拟setting为SOURCE_TYPE_CONCEN(改为SOURCE_TYPE_SETPOINT)
|
||||
:param source_pattern: 污染源的时间变化模式,若不传入则默认以恒定浓度持续模拟,时间长度等于duration;
|
||||
若传入,则格式为{1.0,0.5,1.1}等系数列表pattern_step模拟等于模型的hydraulic time step
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
@@ -469,7 +509,7 @@ def contaminant_simulation(
|
||||
# step 2. set pattern
|
||||
if source_pattern != None:
|
||||
pt = get_pattern(new_name, source_pattern)
|
||||
if pt == None:
|
||||
if len(pt) == 0:
|
||||
str_response = str("cant find source_pattern")
|
||||
return str_response
|
||||
else:
|
||||
@@ -490,7 +530,7 @@ def contaminant_simulation(
|
||||
cs_source = ChangeSet()
|
||||
source_schema = {
|
||||
"node": source,
|
||||
"s_type": SOURCE_TYPE_CONCEN,
|
||||
"s_type": SOURCE_TYPE_SETPOINT,
|
||||
"strength": concentration,
|
||||
"pattern": pt["id"],
|
||||
}
|
||||
@@ -634,7 +674,7 @@ def pressure_regulation(
|
||||
modify_tank_initial_level: dict[str, float] = None,
|
||||
modify_fixed_pump_pattern: dict[str, list] = None,
|
||||
modify_variable_pump_pattern: dict[str, list] = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
区域调压模拟,用来模拟未来15分钟内,开关水泵对区域压力的影响
|
||||
@@ -644,7 +684,7 @@ def pressure_regulation(
|
||||
:param modify_tank_initial_level: dict中包含多个水塔,str为水塔的id,float为修改后的initial_level
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param scheme_Name: 模拟方案名称
|
||||
:param scheme_name: 模拟方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
@@ -696,8 +736,8 @@ def pressure_regulation(
|
||||
modify_tank_initial_level=modify_tank_initial_level,
|
||||
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
|
||||
modify_variable_pump_pattern=modify_variable_pump_pattern,
|
||||
scheme_Type="pressure_regulation",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="pressure_regulation",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from collections import defaultdict, deque
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from app.services.tjnetwork import (
|
||||
get_network_link_nodes,
|
||||
is_node,
|
||||
is_link,
|
||||
get_link_properties,
|
||||
)
|
||||
|
||||
@@ -19,48 +19,102 @@ def _parse_link_entry(link_entry: str) -> tuple[str, str, str, str]:
|
||||
return parts[0], parts[1], parts[2], parts[3]
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def _get_network_topology(network: str):
|
||||
"""
|
||||
解析并缓存网络拓扑,大幅减少重复的 API 调用和字符串解析开销。
|
||||
返回:
|
||||
- pipe_adj: 永久连通的管道/泵邻接表 (dict[str, set])
|
||||
- all_valves: 所有阀门字典 {id: (n1, n2)}
|
||||
- link_lookup: 链路快速查表 {id: (n1, n2, type)} 用于快速定位事故点
|
||||
- node_set: 所有已知节点集合
|
||||
"""
|
||||
pipe_adj = defaultdict(set)
|
||||
all_valves = {}
|
||||
link_lookup = {}
|
||||
node_set = set()
|
||||
|
||||
# 此处假设 get_network_link_nodes 获取全网数据
|
||||
for link_entry in get_network_link_nodes(network):
|
||||
link_id, link_type, node1, node2 = _parse_link_entry(link_entry)
|
||||
link_type_name = str(link_type).lower()
|
||||
|
||||
link_lookup[link_id] = (node1, node2, link_type_name)
|
||||
node_set.add(node1)
|
||||
node_set.add(node2)
|
||||
|
||||
if link_type_name == VALVE_LINK_TYPE:
|
||||
all_valves[link_id] = (node1, node2)
|
||||
else:
|
||||
# 只有非阀门(管道/泵)才进入永久连通图
|
||||
pipe_adj[node1].add(node2)
|
||||
pipe_adj[node2].add(node1)
|
||||
|
||||
return pipe_adj, all_valves, link_lookup, node_set
|
||||
|
||||
|
||||
def valve_isolation_analysis(
|
||||
network: str, accident_elements: str | list[str]
|
||||
network: str, accident_elements: str | list[str], disabled_valves: list[str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
关阀搜索/分析:基于拓扑结构确定事故隔离所需关阀。
|
||||
:param network: 模型名称
|
||||
:param accident_elements: 事故点(节点或管道/泵/阀门ID),可以是单个ID字符串或ID列表
|
||||
:param disabled_valves: 故障/无法关闭的阀门ID列表
|
||||
:return: dict,包含受影响节点、必须关闭阀门、可选阀门等信息
|
||||
"""
|
||||
if disabled_valves is None:
|
||||
disabled_valves_set = set()
|
||||
else:
|
||||
disabled_valves_set = set(disabled_valves)
|
||||
|
||||
if isinstance(accident_elements, str):
|
||||
target_elements = [accident_elements]
|
||||
else:
|
||||
target_elements = accident_elements
|
||||
|
||||
# 1. 获取缓存拓扑 (极快,无 IO)
|
||||
pipe_adj, all_valves, link_lookup, node_set = _get_network_topology(network)
|
||||
|
||||
# 2. 确定起点,优先查表避免 API 调用
|
||||
start_nodes = set()
|
||||
|
||||
for element in target_elements:
|
||||
if is_node(network, element):
|
||||
if element in node_set:
|
||||
start_nodes.add(element)
|
||||
elif is_link(network, element):
|
||||
link_props = get_link_properties(network, element)
|
||||
node1 = link_props.get("node1")
|
||||
node2 = link_props.get("node2")
|
||||
if not node1 or not node2:
|
||||
# 如果是批量处理,可以选择跳过错误或记录错误,这里暂时保持严谨抛出异常
|
||||
raise ValueError(f"Accident link {element} missing node endpoints")
|
||||
start_nodes.add(node1)
|
||||
start_nodes.add(node2)
|
||||
elif element in link_lookup:
|
||||
n1, n2, _ = link_lookup[element]
|
||||
start_nodes.add(n1)
|
||||
start_nodes.add(n2)
|
||||
else:
|
||||
raise ValueError(f"Accident element {element} not found")
|
||||
# 仅当缓存中没找到时(极少见),才回退到慢速 API
|
||||
if is_node(network, element):
|
||||
start_nodes.add(element)
|
||||
else:
|
||||
props = get_link_properties(network, element)
|
||||
n1, n2 = props.get("node1"), props.get("node2")
|
||||
if n1 and n2:
|
||||
start_nodes.add(n1)
|
||||
start_nodes.add(n2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Accident element {element} invalid or missing endpoints"
|
||||
)
|
||||
|
||||
adjacency: dict[str, set[str]] = defaultdict(set)
|
||||
valve_links: dict[str, tuple[str, str]] = {}
|
||||
for link_entry in get_network_link_nodes(network):
|
||||
link_id, link_type, node1, node2 = _parse_link_entry(link_entry)
|
||||
link_type_name = str(link_type).lower()
|
||||
if link_type_name == VALVE_LINK_TYPE:
|
||||
valve_links[link_id] = (node1, node2)
|
||||
continue
|
||||
adjacency[node1].add(node2)
|
||||
adjacency[node2].add(node1)
|
||||
# 3. 处理故障阀门 (构建临时增量图)
|
||||
# 我们不修改 cached pipe_adj,而是建立一个 extra_adj
|
||||
extra_adj = defaultdict(list)
|
||||
boundary_valves = {} # 当前有效的边界阀门
|
||||
|
||||
for vid, (n1, n2) in all_valves.items():
|
||||
if vid in disabled_valves_set:
|
||||
# 故障阀门:视为连通管道
|
||||
extra_adj[n1].append(n2)
|
||||
extra_adj[n2].append(n1)
|
||||
else:
|
||||
# 正常阀门:视为潜在边界
|
||||
boundary_valves[vid] = (n1, n2)
|
||||
|
||||
# 4. BFS 搜索 (叠加 pipe_adj 和 extra_adj)
|
||||
affected_nodes: set[str] = set()
|
||||
queue = deque(start_nodes)
|
||||
while queue:
|
||||
@@ -68,18 +122,29 @@ def valve_isolation_analysis(
|
||||
if node in affected_nodes:
|
||||
continue
|
||||
affected_nodes.add(node)
|
||||
for neighbor in adjacency.get(node, []):
|
||||
if neighbor not in affected_nodes:
|
||||
queue.append(neighbor)
|
||||
|
||||
# 遍历永久管道邻居
|
||||
if node in pipe_adj:
|
||||
for neighbor in pipe_adj[node]:
|
||||
if neighbor not in affected_nodes:
|
||||
queue.append(neighbor)
|
||||
|
||||
# 遍历故障阀门带来的额外邻居
|
||||
if node in extra_adj:
|
||||
for neighbor in extra_adj[node]:
|
||||
if neighbor not in affected_nodes:
|
||||
queue.append(neighbor)
|
||||
|
||||
# 5. 结果聚合
|
||||
must_close_valves: list[str] = []
|
||||
optional_valves: list[str] = []
|
||||
for valve_id, (node1, node2) in valve_links.items():
|
||||
in_node1 = node1 in affected_nodes
|
||||
in_node2 = node2 in affected_nodes
|
||||
if in_node1 and in_node2:
|
||||
|
||||
for valve_id, (n1, n2) in boundary_valves.items():
|
||||
in_n1 = n1 in affected_nodes
|
||||
in_n2 = n2 in affected_nodes
|
||||
if in_n1 and in_n2:
|
||||
optional_valves.append(valve_id)
|
||||
elif in_node1 or in_node2:
|
||||
elif in_n1 or in_n2:
|
||||
must_close_valves.append(valve_id)
|
||||
|
||||
must_close_valves.sort()
|
||||
@@ -87,6 +152,7 @@ def valve_isolation_analysis(
|
||||
|
||||
result = {
|
||||
"accident_elements": target_elements,
|
||||
"disabled_valves": disabled_valves,
|
||||
"affected_nodes": sorted(affected_nodes),
|
||||
"must_close_valves": must_close_valves,
|
||||
"optional_valves": optional_valves,
|
||||
|
||||
@@ -4,33 +4,38 @@
|
||||
仅管理员可访问
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from app.domain.schemas.audit import AuditLogResponse, AuditLogQuery
|
||||
from app.domain.schemas.user import UserInDB
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
from app.auth.dependencies import get_user_repository, get_db
|
||||
from app.auth.permissions import get_current_admin
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
async def get_audit_repository(db: Database = Depends(get_db)) -> AuditRepository:
|
||||
async def get_audit_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> AuditRepository:
|
||||
"""获取审计日志仓储"""
|
||||
return AuditRepository(db)
|
||||
return AuditRepository(session)
|
||||
|
||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志(仅管理员)
|
||||
@@ -39,7 +44,7 @@ async def get_audit_logs(
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
@@ -51,21 +56,21 @@ async def get_audit_logs(
|
||||
|
||||
@router.get("/logs/count")
|
||||
async def get_audit_logs_count(
|
||||
user_id: Optional[int] = Query(None, description="按用户ID过滤"),
|
||||
username: Optional[str] = Query(None, description="按用户名过滤"),
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> dict:
|
||||
"""
|
||||
获取审计日志总数(仅管理员)
|
||||
"""
|
||||
count = await audit_repo.get_log_count(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
@@ -80,8 +85,8 @@ async def get_my_audit_logs(
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository)
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询当前用户的审计日志
|
||||
|
||||
@@ -316,7 +316,7 @@ async def fastapi_query_all_scheme_all_records(
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_scheme_all_record(
|
||||
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
@@ -334,7 +334,7 @@ async def fastapi_query_all_scheme_all_records_property(
|
||||
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
all_results = influxdb_api.query_scheme_all_record(
|
||||
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(all_results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
@@ -22,7 +22,7 @@ async def get_all_extension_data_endpoint(network: str) -> dict[str, Any]:
|
||||
async def get_extension_data_endpoint(network: str, key: str) -> str | None:
|
||||
return get_extension_data(network, key)
|
||||
|
||||
@router.post("/setextensiondata", response_model=None)
|
||||
@router.post("/setextensiondata/", response_model=None)
|
||||
async def set_extension_data_endpoint(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
print(props)
|
||||
|
||||
101
app/api/v1/endpoints/meta.py
Normal file
101
app/api/v1/endpoints/meta.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.project_dependencies import (
|
||||
ProjectContext,
|
||||
get_project_context,
|
||||
get_project_pg_session,
|
||||
get_project_timescale_connection,
|
||||
get_metadata_repository,
|
||||
)
|
||||
from app.auth.metadata_dependencies import get_current_metadata_user
|
||||
from app.core.config import settings
|
||||
from app.domain.schemas.metadata import (
|
||||
GeoServerConfigResponse,
|
||||
ProjectMetaResponse,
|
||||
ProjectSummaryResponse,
|
||||
)
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||
async def get_project_metadata(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
|
||||
geoserver_payload = (
|
||||
GeoServerConfigResponse(
|
||||
gs_base_url=geoserver.gs_base_url,
|
||||
gs_admin_user=geoserver.gs_admin_user,
|
||||
gs_datastore_name=geoserver.gs_datastore_name,
|
||||
default_extent=geoserver.default_extent,
|
||||
srid=geoserver.srid,
|
||||
)
|
||||
if geoserver
|
||||
else None
|
||||
)
|
||||
return ProjectMetaResponse(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=ctx.project_role,
|
||||
geoserver=geoserver_payload,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||
async def list_user_projects(
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while listing projects for user %s",
|
||||
current_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
return [
|
||||
ProjectSummaryResponse(
|
||||
project_id=project.project_id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=project.project_role,
|
||||
)
|
||||
for project in projects
|
||||
]
|
||||
|
||||
|
||||
@router.get("/meta/db/health")
|
||||
async def project_db_health(
|
||||
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
await pg_session.execute(text("SELECT 1"))
|
||||
async with ts_conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
return {"postgres": "ok", "timescale": "ok"}
|
||||
@@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse
|
||||
from typing import Any, Dict
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api import ChangeSet
|
||||
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
|
||||
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
|
||||
from app.services.tjnetwork import (
|
||||
list_project,
|
||||
have_project,
|
||||
@@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str):
|
||||
@router.post("/openproject/")
|
||||
async def open_project_endpoint(network: str):
|
||||
open_project(network)
|
||||
|
||||
# 尝试连接指定数据库
|
||||
try:
|
||||
# 初始化 PostgreSQL 连接池
|
||||
pg_instance = await get_pg_db(network)
|
||||
async with pg_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
# 初始化 TimescaleDB 连接池
|
||||
ts_instance = await get_ts_db(network)
|
||||
async with ts_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误但不阻断项目打开,或者根据需求决定是否阻断
|
||||
# 这里选择打印错误,因为 open_project 原本只负责原生部分
|
||||
print(f"Failed to connect to databases for {network}: {str(e)}")
|
||||
# 如果数据库连接是必须的,可以抛出异常:
|
||||
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
|
||||
|
||||
return network
|
||||
|
||||
@router.post("/closeproject/")
|
||||
|
||||
@@ -60,7 +60,7 @@ class BurstAnalysis(BaseModel):
|
||||
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_variable_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_valve_opening: Optional[dict[str, float]] = None
|
||||
scheme_Name: Optional[str] = None
|
||||
scheme_name: Optional[str] = None
|
||||
|
||||
|
||||
class SchedulingAnalysis(BaseModel):
|
||||
@@ -78,7 +78,7 @@ class PressureRegulation(BaseModel):
|
||||
pump_control: dict
|
||||
tank_init_level: Optional[dict] = None
|
||||
duration: Optional[int] = 900
|
||||
scheme_Name: Optional[str] = None
|
||||
scheme_name: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectManagement(BaseModel):
|
||||
@@ -115,7 +115,15 @@ class PressureSensorPlacement(BaseModel):
|
||||
def run_simulation_manually_by_date(
|
||||
network_name: str, base_date: datetime, start_time: str, duration: int
|
||||
) -> None:
|
||||
start_hour, start_minute, start_second = map(int, start_time.split(":"))
|
||||
time_parts = list(map(int, start_time.split(":")))
|
||||
if len(time_parts) == 2:
|
||||
start_hour, start_minute = time_parts
|
||||
start_second = 0
|
||||
elif len(time_parts) == 3:
|
||||
start_hour, start_minute, start_second = time_parts
|
||||
else:
|
||||
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
|
||||
|
||||
start_datetime = base_date.replace(
|
||||
hour=start_hour, minute=start_minute, second=start_second
|
||||
)
|
||||
@@ -196,10 +204,8 @@ async def burst_analysis_endpoint(
|
||||
async def fastapi_burst_analysis(
|
||||
network: str = Query(...),
|
||||
modify_pattern_start_time: str = Query(...),
|
||||
burst_ID: list | str = Query(..., alias="burst_ID[]"), # 添加别名以匹配 URL
|
||||
burst_size: list | float | int = Query(
|
||||
..., alias="burst_size[]"
|
||||
), # 添加别名以匹配 URL
|
||||
burst_ID: list[str] = Query(...),
|
||||
burst_size: list[float] = Query(...),
|
||||
modify_total_duration: int = Query(...),
|
||||
scheme_name: str = Query(...),
|
||||
) -> str:
|
||||
@@ -239,9 +245,29 @@ async def fastapi_valve_close_analysis(
|
||||
|
||||
@router.get("/valve_isolation_analysis/")
|
||||
async def valve_isolation_endpoint(
|
||||
network: str, accident_element: List[str] = Query(...)
|
||||
network: str,
|
||||
accident_element: List[str] = Query(...),
|
||||
disabled_valves: List[str] = Query(None),
|
||||
):
|
||||
return analyze_valve_isolation(network, accident_element)
|
||||
result = {
|
||||
"accident_element": "P461309",
|
||||
"accident_elements": ["P461309"],
|
||||
"affected_nodes": [
|
||||
"J316629_A",
|
||||
"J317037_B",
|
||||
"J317060_B",
|
||||
"J408189_B",
|
||||
"J499996",
|
||||
"J524940",
|
||||
"J535933",
|
||||
"J58841",
|
||||
],
|
||||
"isolatable": True,
|
||||
"must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
|
||||
"optional_valves": [],
|
||||
}
|
||||
result = analyze_valve_isolation(network, accident_element, disabled_valves)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/flushinganalysis/")
|
||||
@@ -260,6 +286,7 @@ async def fastapi_flushing_analysis(
|
||||
drainage_node_ID: str = Query(...),
|
||||
flush_flow: float = 0,
|
||||
duration: int | None = None,
|
||||
scheme_name: str | None = None,
|
||||
) -> str:
|
||||
valve_opening = {
|
||||
valve_id: float(valves_k[idx]) for idx, valve_id in enumerate(valves)
|
||||
@@ -271,6 +298,7 @@ async def fastapi_flushing_analysis(
|
||||
modify_valve_opening=valve_opening,
|
||||
drainage_node_ID=drainage_node_ID,
|
||||
flushing_flow=flush_flow,
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
return result or "success"
|
||||
|
||||
@@ -342,7 +370,7 @@ async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
|
||||
modify_tank_initial_level=item["tank_init_level"],
|
||||
modify_fixed_pump_pattern=fixed_pump_pattern or None,
|
||||
modify_variable_pump_pattern=variable_pump_pattern or None,
|
||||
scheme_Name=item["scheme_Name"],
|
||||
scheme_name=item["scheme_name"],
|
||||
)
|
||||
return "success"
|
||||
|
||||
@@ -634,13 +662,9 @@ async def fastapi_run_simulation_manually_by_date(
|
||||
globals.realtime_region_pipe_flow_and_demand_id,
|
||||
)
|
||||
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
|
||||
thread = threading.Thread(
|
||||
target=lambda: run_simulation_manually_by_date(
|
||||
item["name"], base_date, item["start_time"], item["duration"]
|
||||
)
|
||||
run_simulation_manually_by_date(
|
||||
item["name"], base_date, item["start_time"], item["duration"]
|
||||
)
|
||||
thread.start()
|
||||
thread.join()
|
||||
return {"status": "success"}
|
||||
except Exception as exc:
|
||||
return {"status": "error", "message": str(exc)}
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.api.v1.endpoints import (
|
||||
cache,
|
||||
user_management, # 新增:用户管理
|
||||
audit, # 新增:审计日志
|
||||
meta,
|
||||
)
|
||||
from app.api.v1.endpoints.network import (
|
||||
general,
|
||||
@@ -46,6 +47,7 @@ api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||
api_router.include_router(meta.router, tags=["Metadata"])
|
||||
api_router.include_router(project.router, tags=["Project"])
|
||||
|
||||
# Network Elements (Node/Link Types)
|
||||
|
||||
63
app/auth/keycloak_dependencies.py
Normal file
63
app/auth/keycloak_dependencies.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
oauth2_optional = OAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
|
||||
)
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_keycloak_sub(
|
||||
token: str | None = Depends(oauth2_optional),
|
||||
) -> UUID:
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||
else:
|
||||
key = settings.SECRET_KEY
|
||||
algorithms = [settings.ALGORITHM]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=algorithms,
|
||||
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||
)
|
||||
except JWTError as exc:
|
||||
# logger.warning("Keycloak token validation failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
return UUID(sub)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
60
app/auth/metadata_dependencies.py
Normal file
60
app/auth/metadata_dependencies.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_current_metadata_user(
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving current user",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_metadata_admin(
|
||||
user=Depends(get_current_metadata_user),
|
||||
):
|
||||
if user.is_superuser or user.role == "admin":
|
||||
return user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _AuthBypassUser:
|
||||
id: UUID = UUID(int=0)
|
||||
role: str = "admin"
|
||||
is_superuser: bool = True
|
||||
is_active: bool = True
|
||||
213
app/auth/project_dependencies.py
Normal file
213
app/auth/project_dependencies.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
DB_ROLE_BIZ_DATA = "biz_data"
|
||||
DB_ROLE_IOT_DATA = "iot_data"
|
||||
DB_TYPE_POSTGRES = "postgresql"
|
||||
DB_TYPE_TIMESCALE = "timescaledb"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectContext:
|
||||
project_id: UUID
|
||||
user_id: UUID
|
||||
project_role: str
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_project_context(
|
||||
x_project_id: str = Header(..., alias="X-Project-Id"),
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> ProjectContext:
|
||||
try:
|
||||
project_uuid = UUID(x_project_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
project = await metadata_repo.get_project_by_id(project_uuid)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
if project.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
|
||||
)
|
||||
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
|
||||
if not membership_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
|
||||
)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving project context",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
|
||||
return ProjectContext(
|
||||
project_id=project.id,
|
||||
user_id=user.id,
|
||||
project_role=membership_role,
|
||||
)
|
||||
|
||||
|
||||
async def get_project_pg_session(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with sessionmaker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_project_pg_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool = await project_connection_manager.get_pg_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_project_timescale_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_IOT_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project TimescaleDB routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_TIMESCALE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
|
||||
pool = await project_connection_manager.get_timescale_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_IOT_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
@@ -7,6 +7,7 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,18 +39,16 @@ class AuditAction:
|
||||
|
||||
async def log_audit_event(
|
||||
action: str,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None,
|
||||
db=None, # 新增:可选的数据库实例
|
||||
session=None,
|
||||
):
|
||||
"""
|
||||
记录审计日志
|
||||
@@ -57,67 +56,60 @@ async def log_audit_event(
|
||||
Args:
|
||||
action: 操作类型
|
||||
user_id: 用户ID
|
||||
username: 用户名
|
||||
project_id: 项目ID
|
||||
resource_type: 资源类型
|
||||
resource_id: 资源ID
|
||||
ip_address: IP地址
|
||||
user_agent: User-Agent
|
||||
request_method: 请求方法
|
||||
request_path: 请求路径
|
||||
request_data: 请求数据(敏感字段需脱敏)
|
||||
response_status: 响应状态码
|
||||
error_message: 错误消息
|
||||
db: 数据库实例(可选,如果不提供则尝试获取)
|
||||
session: 元数据库会话(可选)
|
||||
"""
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
|
||||
try:
|
||||
# 脱敏敏感数据
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
# 如果没有提供数据库实例,尝试从全局获取
|
||||
if db is None:
|
||||
try:
|
||||
from app.infra.db.postgresql.database import db as default_db
|
||||
|
||||
# 仅当连接池已初始化时使用
|
||||
if default_db.pool:
|
||||
db = default_db
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 如果仍然没有数据库实例
|
||||
if db is None:
|
||||
# 在某些上下文中可能无法获取,此时静默失败
|
||||
logger.warning("No database instance provided for audit logging")
|
||||
return
|
||||
|
||||
audit_repo = AuditRepository(db)
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
if session is None:
|
||||
async with SessionLocal() as session:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
)
|
||||
else:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Audit log created: action={action}, user={username or user_id}, "
|
||||
f"resource={resource_type}:{resource_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 审计日志失败不应影响业务流程
|
||||
logger.error(f"Failed to create audit log: {e}", exc_info=True)
|
||||
logger.info(
|
||||
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
|
||||
action,
|
||||
user_id,
|
||||
project_id,
|
||||
resource_type,
|
||||
resource_id,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_sensitive_data(data: dict) -> dict:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
@@ -15,6 +17,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# 数据加密密钥 (使用 Fernet)
|
||||
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
|
||||
DATABASE_ENCRYPTION_KEY: str = "" # project_databases.dsn_encrypted 专用
|
||||
|
||||
# Database Config (PostgreSQL)
|
||||
DB_NAME: str = "tjwater"
|
||||
@@ -35,13 +38,45 @@ class Settings(BaseSettings):
|
||||
INFLUXDB_ORG: str = "org"
|
||||
INFLUXDB_BUCKET: str = "bucket"
|
||||
|
||||
# Metadata Database Config (PostgreSQL)
|
||||
METADATA_DB_NAME: str = "system_hub"
|
||||
METADATA_DB_HOST: str = "localhost"
|
||||
METADATA_DB_PORT: str = "5432"
|
||||
METADATA_DB_USER: str = "postgres"
|
||||
METADATA_DB_PASSWORD: str = "password"
|
||||
|
||||
METADATA_DB_POOL_SIZE: int = 5
|
||||
METADATA_DB_MAX_OVERFLOW: int = 10
|
||||
|
||||
PROJECT_PG_CACHE_SIZE: int = 50
|
||||
PROJECT_TS_CACHE_SIZE: int = 50
|
||||
PROJECT_PG_POOL_SIZE: int = 5
|
||||
PROJECT_PG_MAX_OVERFLOW: int = 10
|
||||
PROJECT_TS_POOL_MIN_SIZE: int = 1
|
||||
PROJECT_TS_POOL_MAX_SIZE: int = 10
|
||||
|
||||
# Keycloak JWT (optional override)
|
||||
KEYCLOAK_PUBLIC_KEY: str = ""
|
||||
KEYCLOAK_ALGORITHM: str = "RS256"
|
||||
KEYCLOAK_AUDIENCE: str = ""
|
||||
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
db_password = quote_plus(self.DB_PASSWORD)
|
||||
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
@property
|
||||
def METADATA_DATABASE_URI(self) -> str:
|
||||
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
|
||||
return (
|
||||
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
|
||||
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=Path(__file__).resolve().parents[2] / ".env",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -3,75 +3,94 @@ from typing import Optional
|
||||
import base64
|
||||
import os
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Encryptor:
|
||||
"""
|
||||
使用 Fernet (对称加密) 实现数据加密/解密
|
||||
适用于加密敏感配置、用户数据等
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, key: Optional[bytes] = None):
|
||||
"""
|
||||
初始化加密器
|
||||
|
||||
|
||||
Args:
|
||||
key: 加密密钥,如果为 None 则从环境变量读取
|
||||
"""
|
||||
if key is None:
|
||||
key_str = os.getenv("ENCRYPTION_KEY")
|
||||
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"ENCRYPTION_KEY not found in environment variables. "
|
||||
"ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
key = key_str.encode()
|
||||
|
||||
|
||||
self.fernet = Fernet(key)
|
||||
|
||||
|
||||
def encrypt(self, data: str) -> str:
|
||||
"""
|
||||
加密字符串
|
||||
|
||||
|
||||
Args:
|
||||
data: 待加密的明文字符串
|
||||
|
||||
|
||||
Returns:
|
||||
Base64 编码的加密字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
|
||||
encrypted_bytes = self.fernet.encrypt(data.encode())
|
||||
return encrypted_bytes.decode()
|
||||
|
||||
|
||||
def decrypt(self, data: str) -> str:
|
||||
"""
|
||||
解密字符串
|
||||
|
||||
|
||||
Args:
|
||||
data: Base64 编码的加密字符串
|
||||
|
||||
|
||||
Returns:
|
||||
解密后的明文字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
|
||||
decrypted_bytes = self.fernet.decrypt(data.encode())
|
||||
return decrypted_bytes.decode()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_key() -> str:
|
||||
"""
|
||||
生成新的 Fernet 加密密钥
|
||||
|
||||
|
||||
Returns:
|
||||
Base64 编码的密钥字符串
|
||||
"""
|
||||
key = Fernet.generate_key()
|
||||
return key.decode()
|
||||
|
||||
|
||||
# 全局加密器实例(懒加载)
|
||||
_encryptor: Optional[Encryptor] = None
|
||||
_database_encryptor: Optional[Encryptor] = None
|
||||
|
||||
|
||||
def is_encryption_configured() -> bool:
|
||||
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
|
||||
|
||||
|
||||
def is_database_encryption_configured() -> bool:
|
||||
return bool(
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
|
||||
|
||||
def get_encryptor() -> Encryptor:
|
||||
"""获取全局加密器实例"""
|
||||
@@ -80,6 +99,26 @@ def get_encryptor() -> Encryptor:
|
||||
_encryptor = Encryptor()
|
||||
return _encryptor
|
||||
|
||||
|
||||
def get_database_encryptor() -> Encryptor:
|
||||
"""获取 project DB DSN 专用加密器实例"""
|
||||
global _database_encryptor
|
||||
if _database_encryptor is None:
|
||||
key_str = (
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
_database_encryptor = Encryptor(key=key_str.encode())
|
||||
return _database_encryptor
|
||||
|
||||
|
||||
# 向后兼容(延迟加载)
|
||||
def __getattr__(name):
|
||||
if name == "encryptor":
|
||||
|
||||
@@ -1,45 +1,42 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class AuditLogCreate(BaseModel):
|
||||
"""创建审计日志"""
|
||||
user_id: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: str
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
request_method: Optional[str] = None
|
||||
request_path: Optional[str] = None
|
||||
request_data: Optional[dict] = None
|
||||
response_status: Optional[int] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
class AuditLogResponse(BaseModel):
|
||||
"""审计日志响应"""
|
||||
id: int
|
||||
user_id: Optional[int]
|
||||
username: Optional[str]
|
||||
id: UUID
|
||||
user_id: Optional[UUID]
|
||||
project_id: Optional[UUID]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
request_method: Optional[str]
|
||||
request_path: Optional[str]
|
||||
request_data: Optional[dict]
|
||||
response_status: Optional[int]
|
||||
error_message: Optional[str]
|
||||
timestamp: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class AuditLogQuery(BaseModel):
|
||||
"""审计日志查询参数"""
|
||||
user_id: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
start_time: Optional[datetime] = None
|
||||
|
||||
33
app/domain/schemas/metadata.py
Normal file
33
app/domain/schemas/metadata.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeoServerConfigResponse(BaseModel):
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
class ProjectMetaResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
geoserver: Optional[GeoServerConfigResponse]
|
||||
|
||||
|
||||
class ProjectSummaryResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
@@ -6,12 +6,17 @@
|
||||
|
||||
import time
|
||||
import json
|
||||
from uuid import UUID
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.infra.db.postgresql.database import db as default_db
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,18 +105,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = None
|
||||
username = None
|
||||
|
||||
# 尝试从请求状态获取当前用户
|
||||
if hasattr(request.state, "user"):
|
||||
user = request.state.user
|
||||
user_id = getattr(user, "id", None)
|
||||
username = getattr(user, "username", None)
|
||||
user_id = await self._resolve_user_id(request)
|
||||
project_id = self._resolve_project_id(request)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
@@ -122,21 +120,14 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
project_id=project_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
error_message=(
|
||||
None
|
||||
if response.status_code < 400
|
||||
else f"HTTP {response.status_code}"
|
||||
),
|
||||
db=default_db,
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
@@ -148,6 +139,48 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return response
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
return None
|
||||
try:
|
||||
return UUID(project_header)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
key = (
|
||||
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else settings.SECRET_KEY
|
||||
)
|
||||
algorithms = (
|
||||
[settings.KEYCLOAK_ALGORITHM]
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else [settings.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
return None
|
||||
keycloak_id = UUID(sub)
|
||||
except (JWTError, ValueError):
|
||||
return None
|
||||
|
||||
async with SessionLocal() as session:
|
||||
repo = MetadataRepository(session)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
if user and user.is_active:
|
||||
return user.id
|
||||
return None
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
|
||||
211
app/infra/db/dynamic_manager.py
Normal file
211
app/infra/db/dynamic_manager.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PgEngineEntry:
|
||||
engine: AsyncEngine
|
||||
sessionmaker: async_sessionmaker[AsyncSession]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheKey:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
|
||||
|
||||
class ProjectConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
|
||||
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_lock = asyncio.Lock()
|
||||
self._ts_lock = asyncio.Lock()
|
||||
self._pg_raw_lock = asyncio.Lock()
|
||||
|
||||
def _normalize_pg_url(self, url: str) -> str:
|
||||
parsed = make_url(url)
|
||||
if parsed.drivername == "postgresql":
|
||||
parsed = parsed.set(drivername="postgresql+psycopg")
|
||||
return str(parsed)
|
||||
|
||||
async def get_pg_sessionmaker(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> async_sessionmaker[AsyncSession]:
|
||||
async with self._pg_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_cache.get(key)
|
||||
if entry:
|
||||
self._pg_cache.move_to_end(key)
|
||||
return entry.sessionmaker
|
||||
|
||||
normalized_url = self._normalize_pg_url(connection_url)
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
engine = create_async_engine(
|
||||
normalized_url,
|
||||
pool_size=pool_min_size,
|
||||
max_overflow=max(0, pool_max_size - pool_min_size),
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
self._pg_cache[key] = PgEngineEntry(
|
||||
engine=engine,
|
||||
sessionmaker=sessionmaker,
|
||||
)
|
||||
await self._evict_pg_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return sessionmaker
|
||||
|
||||
async def get_timescale_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._ts_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._ts_cache.get(key)
|
||||
if pool:
|
||||
self._ts_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._ts_cache[key] = pool
|
||||
await self._evict_ts_if_needed()
|
||||
logger.info(
|
||||
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def get_pg_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._pg_raw_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._pg_raw_cache.get(key)
|
||||
if pool:
|
||||
self._pg_raw_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._pg_raw_cache[key] = pool
|
||||
await self._evict_pg_raw_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def _evict_pg_if_needed(self) -> None:
|
||||
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, entry = self._pg_cache.popitem(last=False)
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_ts_if_needed(self) -> None:
|
||||
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
|
||||
key, pool = self._ts_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_pg_raw_if_needed(self) -> None:
|
||||
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, pool = self._pg_raw_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
async with self._pg_lock:
|
||||
for key, entry in list(self._pg_cache.items()):
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Closed PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_cache.clear()
|
||||
|
||||
async with self._ts_lock:
|
||||
for key, pool in list(self._ts_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._ts_cache.clear()
|
||||
|
||||
async with self._pg_raw_lock:
|
||||
for key, pool in list(self._pg_raw_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_raw_cache.clear()
|
||||
|
||||
|
||||
project_connection_manager = ProjectConnectionManager()
|
||||
@@ -404,8 +404,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
Point("link")
|
||||
.tag("date", None)
|
||||
.tag("ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("flow", 0.0)
|
||||
.field("leakage", 0.0)
|
||||
.field("velocity", 0.0)
|
||||
@@ -420,8 +420,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
Point("node")
|
||||
.tag("date", None)
|
||||
.tag("ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("head", 0.0)
|
||||
.field("pressure", 0.0)
|
||||
.field("actualdemand", 0.0)
|
||||
@@ -436,8 +436,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
.tag("date", None)
|
||||
.tag("description", None)
|
||||
.tag("device_ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("monitored_value", 0.0)
|
||||
.field("datacleaning_value", 0.0)
|
||||
.field("scheme_simulation_value", 0.0)
|
||||
@@ -1811,8 +1811,8 @@ def query_SCADA_data_by_device_ID_and_time(
|
||||
def query_scheme_SCADA_data_by_device_ID_and_time(
|
||||
query_ids_list: List[str],
|
||||
query_time: str,
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
@@ -1843,7 +1843,7 @@ def query_scheme_SCADA_data_by_device_ID_and_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
try:
|
||||
@@ -2585,7 +2585,7 @@ def query_all_scheme_record_by_time_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -2635,7 +2635,7 @@ def query_scheme_simulation_result_by_ID_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -2660,7 +2660,7 @@ def query_scheme_simulation_result_by_ID_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3227,8 +3227,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
link_result_list: List[Dict[str, any]],
|
||||
scheme_start_time: str,
|
||||
num_periods: int = 1,
|
||||
scheme_Type: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_type: str = None,
|
||||
scheme_name: str = None,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
):
|
||||
"""
|
||||
@@ -3237,8 +3237,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
:param link_result_list: (List[Dict[str, any]]): 包含连接和结果数据的字典列表。
|
||||
:param scheme_start_time: (str): 方案模拟开始时间。
|
||||
:param num_periods: (int): 方案模拟的周期数
|
||||
:param scheme_Type: (str): 方案类型
|
||||
:param scheme_Name: (str): 方案名称
|
||||
:param scheme_type: (str): 方案类型
|
||||
:param scheme_name: (str): 方案名称
|
||||
:param bucket: (str): InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"。
|
||||
:return:
|
||||
"""
|
||||
@@ -3298,8 +3298,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
Point("node")
|
||||
.tag("date", date_str)
|
||||
.tag("ID", node_id)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("head", data.get("head", 0.0))
|
||||
.field("pressure", data.get("pressure", 0.0))
|
||||
.field("actualdemand", data.get("demand", 0.0))
|
||||
@@ -3322,8 +3322,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
Point("link")
|
||||
.tag("date", date_str)
|
||||
.tag("ID", link_id)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("flow", data.get("flow", 0.0))
|
||||
.field("velocity", data.get("velocity", 0.0))
|
||||
.field("headloss", data.get("headloss", 0.0))
|
||||
@@ -3409,14 +3409,14 @@ def query_corresponding_query_id_and_element_id(name: str) -> None:
|
||||
|
||||
# 2025/03/11
|
||||
def fill_scheme_simulation_result_to_SCADA(
|
||||
scheme_Type: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_type: str = None,
|
||||
scheme_name: str = None,
|
||||
query_date: str = None,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
):
|
||||
"""
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"
|
||||
:return:
|
||||
@@ -3457,8 +3457,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
# 查找associated_element_id的对应值
|
||||
for key, value in globals.scheme_source_outflow_ids.items():
|
||||
scheme_source_outflow_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="link",
|
||||
@@ -3470,8 +3470,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_source_outflow")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3480,8 +3480,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_pipe_flow_ids.items():
|
||||
scheme_pipe_flow_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="link",
|
||||
@@ -3492,8 +3492,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_pipe_flow")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3502,8 +3502,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_pressure_ids.items():
|
||||
scheme_pressure_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3514,8 +3514,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_pressure")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3524,8 +3524,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_demand_ids.items():
|
||||
scheme_demand_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3536,8 +3536,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_demand")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3546,8 +3546,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_quality_ids.items():
|
||||
scheme_quality_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3558,8 +3558,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_quality")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3629,15 +3629,15 @@ def query_SCADA_data_curve(
|
||||
|
||||
# 2025/02/18
|
||||
def query_scheme_all_record_by_time(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_time: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> tuple:
|
||||
"""
|
||||
查询指定方案某一时刻的所有记录,包括‘node'和‘link’,分别以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'。
|
||||
:param bucket: 数据存储的 bucket 名称。
|
||||
:return: dict: tuple: (node_records, link_records)
|
||||
@@ -3660,7 +3660,7 @@ def query_scheme_all_record_by_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3710,8 +3710,8 @@ def query_scheme_all_record_by_time(
|
||||
|
||||
# 2025/03/04
|
||||
def query_scheme_all_record_by_time_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_time: str,
|
||||
type: str,
|
||||
property: str,
|
||||
@@ -3719,8 +3719,8 @@ def query_scheme_all_record_by_time_property(
|
||||
) -> list:
|
||||
"""
|
||||
查询指定方案某一时刻‘node'或‘link’某一属性值,以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'。
|
||||
:param type: 查询的类型(决定 measurement)
|
||||
:param property: 查询的字段名称(field)
|
||||
@@ -3752,7 +3752,7 @@ def query_scheme_all_record_by_time_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -3767,8 +3767,8 @@ def query_scheme_all_record_by_time_property(
|
||||
|
||||
# 2025/02/19
|
||||
def query_scheme_curve_by_ID_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
ID: str,
|
||||
type: str,
|
||||
@@ -3776,9 +3776,9 @@ def query_scheme_curve_by_ID_property(
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> list:
|
||||
"""
|
||||
根据scheme_Type和scheme_Name,查询该模拟方案中,某一node或link的某一属性值的所有时间的结果
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
根据scheme_Type和scheme_name,查询该模拟方案中,某一node或link的某一属性值的所有时间的结果
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param ID: 元素的ID
|
||||
:param type: 元素的类型,node或link
|
||||
@@ -3817,7 +3817,7 @@ def query_scheme_curve_by_ID_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -3832,15 +3832,15 @@ def query_scheme_curve_by_ID_property(
|
||||
|
||||
# 2025/02/21
|
||||
def query_scheme_all_record(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> tuple:
|
||||
"""
|
||||
查询指定方案的所有记录,包括‘node'和‘link’,分别以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: 数据存储的 bucket 名称。
|
||||
:return: dict: tuple: (node_records, link_records)
|
||||
@@ -3867,7 +3867,7 @@ def query_scheme_all_record(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3917,8 +3917,8 @@ def query_scheme_all_record(
|
||||
|
||||
# 2025/03/04
|
||||
def query_scheme_all_record_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
type: str,
|
||||
property: str,
|
||||
@@ -3926,8 +3926,8 @@ def query_scheme_all_record_property(
|
||||
) -> list:
|
||||
"""
|
||||
查询指定方案的‘node'或‘link’的某一属性值,以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param type: 查询的类型(决定 measurement)
|
||||
:param property: 查询的字段名称(field)
|
||||
@@ -3964,7 +3964,7 @@ def query_scheme_all_record_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -4245,8 +4245,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
link_data[key][field] = record.get_value()
|
||||
link_data[key]["measurement"] = record.get_measurement()
|
||||
link_data[key]["date"] = record.values.get("date", None)
|
||||
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
# 构建 Flux 查询语句,查询指定时间范围内的数据
|
||||
flux_query_node = f"""
|
||||
from(bucket: "{bucket}")
|
||||
@@ -4267,8 +4267,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
node_data[key][field] = record.get_value()
|
||||
node_data[key]["measurement"] = record.get_measurement()
|
||||
node_data[key]["date"] = record.values.get("date", None)
|
||||
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
for key in set(link_data.keys()):
|
||||
row = {"time": key[0], "ID": key[1]}
|
||||
row.update(link_data.get(key, {}))
|
||||
@@ -4288,8 +4288,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"flow",
|
||||
"leakage",
|
||||
@@ -4311,8 +4311,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"head",
|
||||
"pressure",
|
||||
@@ -4330,15 +4330,15 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
|
||||
# 2025/02/18
|
||||
def export_scheme_simulation_result_to_csv_scheme(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> None:
|
||||
"""
|
||||
导出influxdb中scheme_simulation_result这个bucket的数据到csv中
|
||||
:param scheme_Type: 查询的方案类型
|
||||
:param scheme_Name: 查询的方案名
|
||||
:param scheme_type: 查询的方案类型
|
||||
:param scheme_name: 查询的方案名
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: 数据存储的 bucket 名称,默认值为 "SCADA_data"
|
||||
:return:
|
||||
@@ -4366,7 +4366,7 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
flux_query_link = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
link_tables = query_api.query(flux_query_link)
|
||||
@@ -4382,13 +4382,13 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
link_data[key][field] = record.get_value()
|
||||
link_data[key]["measurement"] = record.get_measurement()
|
||||
link_data[key]["date"] = record.values.get("date", None)
|
||||
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
# 构建 Flux 查询语句,查询指定时间范围内的数据
|
||||
flux_query_node = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
node_tables = query_api.query(flux_query_node)
|
||||
@@ -4404,8 +4404,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
node_data[key][field] = record.get_value()
|
||||
node_data[key]["measurement"] = record.get_measurement()
|
||||
node_data[key]["date"] = record.values.get("date", None)
|
||||
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
for key in set(link_data.keys()):
|
||||
row = {"time": key[0], "ID": key[1]}
|
||||
row.update(link_data.get(key, {}))
|
||||
@@ -4416,10 +4416,10 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
node_rows.append(row)
|
||||
# 动态生成 CSV 文件名
|
||||
csv_filename_link = (
|
||||
f"scheme_simulation_link_result_{scheme_Name}_of_{scheme_Type}.csv"
|
||||
f"scheme_simulation_link_result_{scheme_name}_of_{scheme_type}.csv"
|
||||
)
|
||||
csv_filename_node = (
|
||||
f"scheme_simulation_node_result_{scheme_Name}_of_{scheme_Type}.csv"
|
||||
f"scheme_simulation_node_result_{scheme_name}_of_{scheme_type}.csv"
|
||||
)
|
||||
# 写入到 CSV 文件
|
||||
with open(csv_filename_link, mode="w", newline="") as file:
|
||||
@@ -4429,8 +4429,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"flow",
|
||||
"leakage",
|
||||
@@ -4452,8 +4452,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"head",
|
||||
"pressure",
|
||||
@@ -4878,15 +4878,15 @@ if __name__ == "__main__":
|
||||
# export_scheme_simulation_result_to_csv_time(start_date='2025-02-13', end_date='2025-02-15')
|
||||
|
||||
# 示例9:export_scheme_simulation_result_to_csv_scheme
|
||||
# export_scheme_simulation_result_to_csv_scheme(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
|
||||
# export_scheme_simulation_result_to_csv_scheme(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
|
||||
|
||||
# 示例10:query_scheme_all_record_by_time
|
||||
# node_records, link_records = query_scheme_all_record_by_time(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_time="2025-02-14T10:30:00+08:00")
|
||||
# node_records, link_records = query_scheme_all_record_by_time(scheme_type='burst_Analysis', scheme_name='scheme1', query_time="2025-02-14T10:30:00+08:00")
|
||||
# print("Node 数据:", node_records)
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
# 示例11:query_scheme_curve_by_ID_property
|
||||
# curve_result = query_scheme_curve_by_ID_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', ID='ZBBDTZDP000022',
|
||||
# curve_result = query_scheme_curve_by_ID_property(scheme_type='burst_Analysis', scheme_name='scheme1', ID='ZBBDTZDP000022',
|
||||
# type='node', property='head')
|
||||
# print(curve_result)
|
||||
|
||||
@@ -4896,7 +4896,7 @@ if __name__ == "__main__":
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
# 示例13:query_scheme_all_record
|
||||
# node_records, link_records = query_scheme_all_record(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
|
||||
# node_records, link_records = query_scheme_all_record(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
|
||||
# print("Node 数据:", node_records)
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
@@ -4909,16 +4909,16 @@ if __name__ == "__main__":
|
||||
# print(result_records)
|
||||
|
||||
# 示例16:query_scheme_all_record_by_time_property
|
||||
# result_records = query_scheme_all_record_by_time_property(scheme_Type='burst_Analysis', scheme_Name='scheme1',
|
||||
# result_records = query_scheme_all_record_by_time_property(scheme_type='burst_Analysis', scheme_name='scheme1',
|
||||
# query_time='2025-02-14T10:30:00+08:00', type='node', property='head')
|
||||
# print(result_records)
|
||||
|
||||
# 示例17:query_scheme_all_record_property
|
||||
# result_records = query_scheme_all_record_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10', type='node', property='head')
|
||||
# result_records = query_scheme_all_record_property(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10', type='node', property='head')
|
||||
# print(result_records)
|
||||
|
||||
# 示例18:fill_scheme_simulation_result_to_SCADA
|
||||
# fill_scheme_simulation_result_to_SCADA(scheme_Type='burst_Analysis', scheme_Name='burst0330', query_date='2025-03-30')
|
||||
# fill_scheme_simulation_result_to_SCADA(scheme_type='burst_Analysis', scheme_name='burst0330', query_date='2025-03-30')
|
||||
|
||||
# 示例19:query_SCADA_data_by_device_ID_and_timerange
|
||||
# result = query_SCADA_data_by_device_ID_and_timerange(query_ids_list=globals.pressure_non_realtime_ids, start_time='2025-04-16T00:00:00+08:00',
|
||||
@@ -4926,7 +4926,7 @@ if __name__ == "__main__":
|
||||
# print(result)
|
||||
|
||||
# 示例:manually_get_burst_flow
|
||||
# leakage = manually_get_burst_flow(scheme_Type='burst_Analysis', scheme_Name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
|
||||
# leakage = manually_get_burst_flow(scheme_type='burst_Analysis', scheme_name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
|
||||
# print(leakage)
|
||||
|
||||
# 示例:upload_cleaned_SCADA_data_to_influxdb
|
||||
|
||||
3
app/infra/db/metadata/__init__.py
Normal file
3
app/infra/db/metadata/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .database import get_metadata_session, close_metadata_engine
|
||||
|
||||
__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||
27
app/infra/db/metadata/database.py
Normal file
27
app/infra/db/metadata/database.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.METADATA_DATABASE_URI,
|
||||
pool_size=settings.METADATA_DB_POOL_SIZE,
|
||||
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_metadata_engine() -> None:
|
||||
await engine.dispose()
|
||||
logger.info("Metadata database engine disposed.")
|
||||
115
app/infra/db/metadata/models.py
Normal file
115
app/infra/db/metadata/models.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
keycloak_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class ProjectDatabase(Base):
|
||||
__tablename__ = "project_databases"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
db_role: Mapped[str] = mapped_column(String(20))
|
||||
db_type: Mapped[str] = mapped_column(String(20))
|
||||
dsn_encrypted: Mapped[str] = mapped_column(Text)
|
||||
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
|
||||
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
|
||||
|
||||
|
||||
class ProjectGeoServerConfig(Base):
|
||||
__tablename__ = "project_geoserver_configs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
|
||||
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
srid: Mapped[int] = mapped_column(Integer, default=4326)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class UserProjectMembership(Base):
|
||||
__tablename__ = "user_project_membership"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
project_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50))
|
||||
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -26,7 +33,7 @@ class Database:
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: default")
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||
raise
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
from app.auth.project_dependencies import get_project_pg_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -27,7 +34,7 @@ class Database:
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(
|
||||
f"TimescaleDB connection pool initialized for database: default"
|
||||
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
||||
@@ -46,7 +53,9 @@ class Database:
|
||||
def get_pgconn_string(self, db_name=None):
|
||||
"""Get the TimescaleDB connection string."""
|
||||
target_db_name = db_name or self.db_name
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
if target_db_name:
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
return timescaledb_info.get_pgconn_string()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
|
||||
@@ -1,40 +1,32 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .schemas.realtime import RealtimeRepository
|
||||
from .schemas.scheme import SchemeRepository
|
||||
from .schemas.scada import ScadaRepository
|
||||
from .composite_queries import CompositeQueries
|
||||
from app.infra.db.postgresql.database import get_database_instance as get_postgres_database_instance
|
||||
from app.auth.project_dependencies import (
|
||||
get_project_pg_connection,
|
||||
get_project_timescale_connection,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 TimescaleDB 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# PostgreSQL 数据库连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_postgres_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_postgres_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# --- Realtime Endpoints ---
|
||||
@@ -152,7 +144,7 @@ async def query_realtime_simulation_by_id_time(
|
||||
results = await RealtimeRepository.query_simulation_result_by_id_time(
|
||||
conn, id, type, query_time
|
||||
)
|
||||
return {"results": results}
|
||||
return {"result": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@@ -1,220 +1,112 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
"""审计日志数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: str = "",
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None
|
||||
) -> Optional[AuditLogResponse]:
|
||||
"""
|
||||
创建审计日志
|
||||
|
||||
Args:
|
||||
参数说明见 AuditLogCreate
|
||||
|
||||
Returns:
|
||||
创建的审计日志对象
|
||||
"""
|
||||
query = """
|
||||
INSERT INTO audit_logs (
|
||||
user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message
|
||||
)
|
||||
VALUES (
|
||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
||||
%(request_data)s, %(response_status)s, %(error_message)s
|
||||
)
|
||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'action': action,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'ip_address': ip_address,
|
||||
'user_agent': user_agent,
|
||||
'request_method': request_method,
|
||||
'request_path': request_path,
|
||||
'request_data': json.dumps(request_data) if request_data else None,
|
||||
'response_status': response_status,
|
||||
'error_message': error_message
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return AuditLogResponse(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating audit log: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
) -> AuditLogResponse:
|
||||
log = models.AuditLog(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return AuditLogResponse.model_validate(log)
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
limit: int = 100,
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志
|
||||
|
||||
Args:
|
||||
user_id: 用户ID过滤
|
||||
username: 用户名过滤
|
||||
action: 操作类型过滤
|
||||
resource_type: 资源类型过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
skip: 跳过记录数
|
||||
limit: 限制记录数
|
||||
|
||||
Returns:
|
||||
审计日志列表
|
||||
"""
|
||||
# 构建动态查询
|
||||
conditions = []
|
||||
params = {'skip': skip, 'limit': limit}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
rows = await cur.fetchall()
|
||||
return [AuditLogResponse(**row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying audit logs: {e}")
|
||||
raise
|
||||
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = (
|
||||
select(models.AuditLog)
|
||||
.where(*conditions)
|
||||
.order_by(models.AuditLog.timestamp.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
AuditLogResponse.model_validate(log)
|
||||
for log in result.scalars().all()
|
||||
]
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
获取审计日志数量
|
||||
|
||||
Args:
|
||||
参数同 get_logs
|
||||
|
||||
Returns:
|
||||
日志总数
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['count'] if result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting audit logs: {e}")
|
||||
return 0
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
197
app/infra/repositories/metadata_repository.py
Normal file
197
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.encryption import (
|
||||
get_database_encryptor,
|
||||
get_encryptor,
|
||||
is_database_encryption_configured,
|
||||
is_encryption_configured,
|
||||
)
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
|
||||
def _normalize_postgres_dsn(dsn: str) -> str:
|
||||
if not dsn or "://" not in dsn:
|
||||
return dsn
|
||||
scheme, rest = dsn.split("://", 1)
|
||||
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
|
||||
return dsn
|
||||
if "@" not in rest:
|
||||
return dsn
|
||||
userinfo, hostinfo = rest.rsplit("@", 1)
|
||||
if ":" not in userinfo:
|
||||
return dsn
|
||||
username, password = userinfo.split(":", 1)
|
||||
if "@" not in password:
|
||||
return dsn
|
||||
password = password.replace("@", "%40")
|
||||
return f"{scheme}://{username}:{password}@{hostinfo}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectDbRouting:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
db_type: str
|
||||
dsn: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectGeoServerInfo:
|
||||
project_id: UUID
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_admin_password: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectSummary:
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
|
||||
|
||||
class MetadataRepository:
|
||||
"""元数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
|
||||
result = await self.session.execute(
|
||||
select(models.User).where(models.User.keycloak_id == keycloak_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).where(models.Project.id == project_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_membership_role(
|
||||
self, project_id: UUID, user_id: UUID
|
||||
) -> Optional[str]:
|
||||
result = await self.session.execute(
|
||||
select(models.UserProjectMembership.project_role).where(
|
||||
models.UserProjectMembership.project_id == project_id,
|
||||
models.UserProjectMembership.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_db_routing(
|
||||
self, project_id: UUID, db_role: str
|
||||
) -> Optional[ProjectDbRouting]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectDatabase).where(
|
||||
models.ProjectDatabase.project_id == project_id,
|
||||
models.ProjectDatabase.db_role == db_role,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if not is_database_encryption_configured():
|
||||
raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
|
||||
encryptor = get_database_encryptor()
|
||||
try:
|
||||
dsn = encryptor.decrypt(record.dsn_encrypted)
|
||||
except InvalidToken:
|
||||
raise ValueError(
|
||||
"Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
|
||||
"or invalid dsn_encrypted value"
|
||||
)
|
||||
dsn = _normalize_postgres_dsn(dsn)
|
||||
return ProjectDbRouting(
|
||||
project_id=record.project_id,
|
||||
db_role=record.db_role,
|
||||
db_type=record.db_type,
|
||||
dsn=dsn,
|
||||
pool_min_size=record.pool_min_size,
|
||||
pool_max_size=record.pool_max_size,
|
||||
)
|
||||
|
||||
async def get_geoserver_config(
|
||||
self, project_id: UUID
|
||||
) -> Optional[ProjectGeoServerInfo]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectGeoServerConfig).where(
|
||||
models.ProjectGeoServerConfig.project_id == project_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if record.gs_admin_password_encrypted:
|
||||
if is_encryption_configured():
|
||||
encryptor = get_encryptor()
|
||||
password = encryptor.decrypt(record.gs_admin_password_encrypted)
|
||||
else:
|
||||
password = record.gs_admin_password_encrypted
|
||||
else:
|
||||
password = None
|
||||
return ProjectGeoServerInfo(
|
||||
project_id=record.project_id,
|
||||
gs_base_url=record.gs_base_url,
|
||||
gs_admin_user=record.gs_admin_user,
|
||||
gs_admin_password=password,
|
||||
gs_datastore_name=record.gs_datastore_name,
|
||||
default_extent=record.default_extent,
|
||||
srid=record.srid,
|
||||
)
|
||||
|
||||
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
|
||||
stmt = (
|
||||
select(models.Project, models.UserProjectMembership.project_role)
|
||||
.join(
|
||||
models.UserProjectMembership,
|
||||
models.UserProjectMembership.project_id == models.Project.id,
|
||||
)
|
||||
.where(models.UserProjectMembership.user_id == user_id)
|
||||
.order_by(models.Project.name)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=role,
|
||||
)
|
||||
for project, role in result.all()
|
||||
]
|
||||
|
||||
async def list_all_projects(self) -> List[ProjectSummary]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).order_by(models.Project.name)
|
||||
)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role="owner",
|
||||
)
|
||||
for project in result.scalars().all()
|
||||
]
|
||||
37
app/main.py
37
app/main.py
@@ -9,6 +9,8 @@ import app.services.project_info as project_info
|
||||
from app.api.v1.router import api_router
|
||||
from app.infra.db.timescaledb.database import db as tsdb
|
||||
from app.infra.db.postgresql.database import db as pgdb
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import close_metadata_engine
|
||||
from app.services.tjnetwork import open_project
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -33,7 +35,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
await tsdb.open()
|
||||
await pgdb.open()
|
||||
|
||||
|
||||
# 将数据库实例存储到 app.state,供依赖项使用
|
||||
app.state.db = pgdb
|
||||
logger.info("Database connection pool initialized and stored in app.state")
|
||||
@@ -46,6 +48,8 @@ async def lifespan(app: FastAPI):
|
||||
# 清理资源
|
||||
await tsdb.close()
|
||||
await pgdb.close()
|
||||
await project_connection_manager.close_all()
|
||||
await close_metadata_engine()
|
||||
logger.info("Database connections closed")
|
||||
|
||||
|
||||
@@ -58,22 +62,25 @@ app = FastAPI(
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 允许所有来源
|
||||
allow_credentials=True, # 允许传递凭证(Cookie、HTTP 头等)
|
||||
allow_methods=["*"], # 允许所有 HTTP 方法
|
||||
allow_headers=["*"], # 允许所有 HTTP 头
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# 添加审计中间件(可选,记录关键操作)
|
||||
# 如果需要启用审计日志,取消下面的注释
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# Include Routers
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
# Legcy Routers without version prefix
|
||||
app.include_router(api_router)
|
||||
|
||||
# 配置中间件
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
# 添加审计中间件(可选,记录关键操作)
|
||||
app.add_middleware(AuditMiddleware)
|
||||
# 配置 CORS 中间件
|
||||
# 确保这是你最后一个添加的 app.add_middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:3000", # 必须明确指定
|
||||
"http://127.0.0.1:3000", # 建议同时加上这个
|
||||
],
|
||||
allow_credentials=True, # 既然这里是 True
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@@ -1,22 +1,4 @@
|
||||
import os
|
||||
import yaml
|
||||
|
||||
# 获取当前项目根目录的路径
|
||||
_current_file = os.path.abspath(__file__)
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(_current_file)))
|
||||
|
||||
# 尝试读取 .yml 或 .yaml 文件
|
||||
config_file = os.path.join(project_root, "configs", "project_info.yml")
|
||||
if not os.path.exists(config_file):
|
||||
config_file = os.path.join(project_root, "configs", "project_info.yaml")
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError(f"未找到项目配置文件 (project_info.yaml 或 .yml): {os.path.dirname(config_file)}")
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
_config = yaml.safe_load(f)
|
||||
|
||||
if not _config or 'name' not in _config:
|
||||
raise KeyError(f"项目配置文件中缺少 'name' 配置: {config_file}")
|
||||
|
||||
name = _config['name']
|
||||
# 从环境变量 NETWORK_NAME 读取
|
||||
name = os.getenv("NETWORK_NAME")
|
||||
|
||||
@@ -1190,12 +1190,12 @@ def run_simulation(
|
||||
if modify_valve_opening[valve_name] == 0:
|
||||
valve_status["status"] = "CLOSED"
|
||||
valve_status["setting"] = 0
|
||||
if modify_valve_opening[valve_name] < 1:
|
||||
elif modify_valve_opening[valve_name] < 1:
|
||||
valve_status["status"] = "OPEN"
|
||||
valve_status["setting"] = 0.1036 * pow(
|
||||
modify_valve_opening[valve_name], -3.105
|
||||
)
|
||||
if modify_valve_opening[valve_name] == 1:
|
||||
elif modify_valve_opening[valve_name] == 1:
|
||||
valve_status["status"] = "OPEN"
|
||||
valve_status["setting"] = 0
|
||||
cs = ChangeSet()
|
||||
@@ -1235,7 +1235,7 @@ def run_simulation(
|
||||
starttime = time.time()
|
||||
if simulation_type.upper() == "REALTIME":
|
||||
TimescaleInternalStorage.store_realtime_simulation(
|
||||
node_result, link_result, modify_pattern_start_time
|
||||
node_result, link_result, modify_pattern_start_time, db_name=name
|
||||
)
|
||||
elif simulation_type.upper() == "EXTENDED":
|
||||
TimescaleInternalStorage.store_scheme_simulation(
|
||||
@@ -1245,11 +1245,12 @@ def run_simulation(
|
||||
link_result,
|
||||
modify_pattern_start_time,
|
||||
num_periods_result,
|
||||
db_name=name,
|
||||
)
|
||||
endtime = time.time()
|
||||
logging.info("store time: %f", endtime - starttime)
|
||||
# 暂不需要再次存储 SCADA 模拟信息
|
||||
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
|
||||
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
|
||||
|
||||
# if simulation_type.upper() == "REALTIME":
|
||||
# influxdb_api.store_realtime_simulation_result_to_influxdb(
|
||||
@@ -1261,11 +1262,11 @@ def run_simulation(
|
||||
# link_result,
|
||||
# modify_pattern_start_time,
|
||||
# num_periods_result,
|
||||
# scheme_Type,
|
||||
# scheme_Name,
|
||||
# scheme_type,
|
||||
# scheme_name,
|
||||
# )
|
||||
# 暂不需要再次存储 SCADA 模拟信息
|
||||
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
|
||||
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
|
||||
|
||||
print("after store result")
|
||||
|
||||
@@ -1345,7 +1346,7 @@ if __name__ == "__main__":
|
||||
# run_simulation(name='bb', simulation_type="realtime", modify_pattern_start_time='2025-02-25T23:45:00+08:00')
|
||||
# 模拟示例2
|
||||
# run_simulation(name='bb', simulation_type="extended", modify_pattern_start_time='2025-03-10T12:00:00+08:00',
|
||||
# modify_total_duration=1800, scheme_Type="burst_Analysis", scheme_Name="scheme1")
|
||||
# modify_total_duration=1800, scheme_type="burst_Analysis", scheme_name="scheme1")
|
||||
|
||||
# 查询示例1:query_SCADA_ID_corresponding_info
|
||||
# result = query_SCADA_ID_corresponding_info(name='bb', SCADA_ID='P10755')
|
||||
|
||||
@@ -4,6 +4,8 @@ from app.algorithms.valve_isolation import valve_isolation_analysis
|
||||
|
||||
|
||||
def analyze_valve_isolation(
|
||||
network: str, accident_element: str | list[str]
|
||||
network: str,
|
||||
accident_element: str | list[str],
|
||||
disabled_valves: list[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
return valve_isolation_analysis(network, accident_element)
|
||||
return valve_isolation_analysis(network, accident_element, disabled_valves)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
name: szh
|
||||
@@ -1,13 +0,0 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY app ./app
|
||||
COPY resources ./resources
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -95,7 +95,6 @@ prometheus_client==0.24.1
|
||||
psycopg==3.2.5
|
||||
psycopg-binary==3.2.5
|
||||
psycopg-pool==3.3.0
|
||||
psycopg2==2.9.10
|
||||
PuLP==3.1.1
|
||||
py-key-value-aio==0.3.0
|
||||
py-key-value-shared==0.3.0
|
||||
@@ -157,8 +156,6 @@ starlette==0.50.0
|
||||
threadpoolctl==3.6.0
|
||||
tqdm==4.67.1
|
||||
typer==0.21.1
|
||||
typing-inspection==0.4.0
|
||||
typing_extensions==4.12.2
|
||||
tzdata==2025.2
|
||||
urllib3==2.2.3
|
||||
uvicorn==0.34.0
|
||||
|
||||
33
scripts/encrypt_string.py
Normal file
33
scripts/encrypt_string.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录添加到 python 路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from app.core.encryption import get_database_encryptor
|
||||
|
||||
|
||||
def main() -> int:
|
||||
plaintext = None
|
||||
if not sys.stdin.isatty():
|
||||
stdin_text = sys.stdin.read()
|
||||
if stdin_text != "":
|
||||
plaintext = stdin_text.rstrip("\r\n")
|
||||
if plaintext is None and len(sys.argv) >= 2:
|
||||
plaintext = sys.argv[1]
|
||||
if plaintext is None:
|
||||
try:
|
||||
plaintext = input("请输入要加密的文本: ")
|
||||
except EOFError:
|
||||
plaintext = ""
|
||||
if not plaintext.strip():
|
||||
print("Error: plaintext string cannot be empty.", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
token = get_database_encryptor().encrypt(plaintext)
|
||||
print(token)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -3233,7 +3233,7 @@ async def fastapi_query_all_scheme_all_records(
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_scheme_all_record(
|
||||
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
@@ -3257,7 +3257,7 @@ async def fastapi_query_all_scheme_all_records_property(
|
||||
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
all_results = influxdb_api.query_scheme_all_record(
|
||||
scheme_Type=schemetype, scheme_Name=schemename, query_date=querydate
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(all_results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
@@ -3585,7 +3585,7 @@ class BurstAnalysis(BaseModel):
|
||||
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_variable_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_valve_opening: Optional[dict[str, float]] = None
|
||||
scheme_Name: Optional[str] = None
|
||||
scheme_name: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/burst_analysis/")
|
||||
@@ -3608,7 +3608,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
|
||||
modify_fixed_pump_pattern=item["modify_fixed_pump_pattern"],
|
||||
modify_variable_pump_pattern=item["modify_variable_pump_pattern"],
|
||||
modify_valve_opening=item["modify_valve_opening"],
|
||||
scheme_Name=item["scheme_Name"],
|
||||
scheme_name=item["scheme_name"],
|
||||
)
|
||||
# os.rename(filename2, filename)
|
||||
|
||||
@@ -3616,7 +3616,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
|
||||
# 将 时间转换成日期,然后缓存这个计算结果
|
||||
# 缓存key: burst_analysis_<name>_<modify_pattern_start_time>
|
||||
global redis_client
|
||||
schemename = data.scheme_Name
|
||||
schemename = data.scheme_name
|
||||
|
||||
print(data.modify_pattern_start_time)
|
||||
|
||||
@@ -3627,7 +3627,7 @@ async def fastapi_burst_analysis(data: BurstAnalysis) -> str:
|
||||
cache_key = f"queryallschemeallrecords_burst_Analysis_{schemename}_{querydate}"
|
||||
data = redis_client.get(cache_key)
|
||||
if not data:
|
||||
results = influxdb_api.query_scheme_all_record("burst_Analysis", scheme_Name=schemename, query_date=querydate)
|
||||
results = influxdb_api.query_scheme_all_record("burst_Analysis", scheme_name=schemename, query_date=querydate)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
"""
|
||||
@@ -3712,7 +3712,7 @@ async def fastapi_contaminant_simulation(
|
||||
concentration: float,
|
||||
duration: int,
|
||||
pattern: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> str:
|
||||
filename = "c:/lock.simulation"
|
||||
filename2 = "c:/lock.simulation2"
|
||||
|
||||
@@ -68,7 +68,7 @@ def burst_analysis(
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
@@ -294,7 +294,7 @@ def flushing_analysis(
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
drainage_node_ID: str = None,
|
||||
flushing_flow: float = 0,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
管道冲洗模拟
|
||||
@@ -304,7 +304,7 @@ def flushing_analysis(
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param drainage_node_ID: 冲洗排放口所在节点ID
|
||||
:param flushing_flow: 冲洗水量,传入参数单位为m3/h
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
@@ -396,8 +396,8 @@ def flushing_analysis(
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_Type="flushing_Analysis",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="flushing_Analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 4. restore the base model
|
||||
if is_project_open(new_name):
|
||||
@@ -417,7 +417,7 @@ def contaminant_simulation(
|
||||
source: str = None,# 污染源节点ID
|
||||
concentration: float = None, # 污染源浓度,单位mg/L
|
||||
source_pattern: str = None, # 污染源时间变化模式名称
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
污染模拟
|
||||
@@ -428,7 +428,7 @@ def contaminant_simulation(
|
||||
:param concentration: 污染源位置处的浓度,单位mg/L,即默认的污染模拟setting为concentration(应改为 Set point booster)
|
||||
:param source_pattern: 污染源的时间变化模式,若不传入则默认以恒定浓度持续模拟,时间长度等于duration;
|
||||
若传入,则格式为{1.0,0.5,1.1}等系数列表pattern_step模拟等于模型的hydraulic time step
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
@@ -533,8 +533,8 @@ def contaminant_simulation(
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
scheme_Type="contaminant_Analysis",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="contaminant_Analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
|
||||
# for i in range(1,operation_step):
|
||||
@@ -630,7 +630,7 @@ def pressure_regulation(
|
||||
modify_tank_initial_level: dict[str, float] = None,
|
||||
modify_fixed_pump_pattern: dict[str, list] = None,
|
||||
modify_variable_pump_pattern: dict[str, list] = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
区域调压模拟,用来模拟未来15分钟内,开关水泵对区域压力的影响
|
||||
@@ -640,7 +640,7 @@ def pressure_regulation(
|
||||
:param modify_tank_initial_level: dict中包含多个水塔,str为水塔的id,float为修改后的initial_level
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param scheme_Name: 模拟方案名称
|
||||
:param scheme_name: 模拟方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
@@ -692,8 +692,8 @@ def pressure_regulation(
|
||||
modify_tank_initial_level=modify_tank_initial_level,
|
||||
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
|
||||
modify_variable_pump_pattern=modify_variable_pump_pattern,
|
||||
scheme_Type="pressure_regulation",
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type="pressure_regulation",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
@@ -1536,7 +1536,7 @@ if __name__ == "__main__":
|
||||
|
||||
# 示例1:burst_analysis
|
||||
# burst_analysis(name='bb', modify_pattern_start_time='2025-04-17T00:00:00+08:00',
|
||||
# burst_ID='GSD230112144241FA18292A84CB', burst_size=400, modify_total_duration=1800, scheme_Name='GSD230112144241FA18292A84CB_400')
|
||||
# burst_ID='GSD230112144241FA18292A84CB', burst_size=400, modify_total_duration=1800, scheme_name='GSD230112144241FA18292A84CB_400')
|
||||
|
||||
# 示例:create_user
|
||||
# create_user(name=project_info.name, username='tjwater dev', password='123456')
|
||||
|
||||
@@ -16,6 +16,6 @@ if __name__ == "__main__":
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
workers=2, # 这里可以设置多进程
|
||||
workers=4, # 这里可以设置多进程
|
||||
loop="asyncio",
|
||||
)
|
||||
|
||||
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
119
tests/unit/test_metadata_repository_dsn_decrypt.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
def __init__(self, record):
|
||||
self._record = record
|
||||
|
||||
def scalar_one_or_none(self):
|
||||
return self._record
|
||||
|
||||
|
||||
class _DummyEncryptor:
|
||||
def __init__(self, decrypted=None, raise_invalid_token=False):
|
||||
self._decrypted = decrypted
|
||||
self._raise_invalid_token = raise_invalid_token
|
||||
self.encrypted_values = []
|
||||
|
||||
def decrypt(self, _value):
|
||||
if self._raise_invalid_token:
|
||||
raise InvalidToken()
|
||||
return self._decrypted
|
||||
|
||||
def _build_record(dsn_encrypted: str):
|
||||
return SimpleNamespace(
|
||||
project_id=uuid4(),
|
||||
db_role="biz_data",
|
||||
db_type="postgresql",
|
||||
dsn_encrypted=dsn_encrypted,
|
||||
pool_min_size=1,
|
||||
pool_max_size=5,
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_token_with_plaintext_dsn_value_raises_clear_error(monkeypatch):
|
||||
record = _build_record("postgresql://user:p@ss@localhost:5432/db")
|
||||
session = SimpleNamespace(
|
||||
execute=None,
|
||||
commit=None,
|
||||
)
|
||||
session.execute = AsyncMock(return_value=_DummyResult(record))
|
||||
session.commit = AsyncMock()
|
||||
encryptor = _DummyEncryptor(raise_invalid_token=True)
|
||||
repo = MetadataRepository(session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
|
||||
lambda: True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.get_database_encryptor",
|
||||
lambda: encryptor,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
|
||||
):
|
||||
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
|
||||
session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
def test_invalid_token_with_non_dsn_value_raises_clear_error(monkeypatch):
|
||||
record = _build_record("gAAAAABinvalidciphertext")
|
||||
session = SimpleNamespace(
|
||||
execute=None,
|
||||
commit=None,
|
||||
)
|
||||
session.execute = AsyncMock(return_value=_DummyResult(record))
|
||||
session.commit = AsyncMock()
|
||||
repo = MetadataRepository(session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
|
||||
lambda: True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.get_database_encryptor",
|
||||
lambda: _DummyEncryptor(raise_invalid_token=True),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="DATABASE_ENCRYPTION_KEY mismatch or invalid dsn_encrypted value",
|
||||
):
|
||||
asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
|
||||
session.commit.assert_not_awaited()
|
||||
|
||||
|
||||
def test_encrypted_dsn_decrypts_without_migration(monkeypatch):
|
||||
record = _build_record("encrypted-value")
|
||||
session = SimpleNamespace(
|
||||
execute=None,
|
||||
commit=None,
|
||||
)
|
||||
session.execute = AsyncMock(return_value=_DummyResult(record))
|
||||
session.commit = AsyncMock()
|
||||
repo = MetadataRepository(session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.is_database_encryption_configured",
|
||||
lambda: True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"app.infra.repositories.metadata_repository.get_database_encryptor",
|
||||
lambda: _DummyEncryptor(decrypted="postgresql://u:p@ss@host/db"),
|
||||
)
|
||||
|
||||
routing = asyncio.run(repo.get_project_db_routing(record.project_id, "biz_data"))
|
||||
|
||||
assert routing.dsn == "postgresql://u:p%40ss@host/db"
|
||||
session.commit.assert_not_awaited()
|
||||
Reference in New Issue
Block a user