109 Commits

Author SHA1 Message Date
jiang e336ffcd46 移除存在无效数据的 cli 命令 2026-06-05 16:42:03 +08:00
jiang 52b8f07abd 更新 cli 命令,新增 network 其他元素的属性查询 2026-06-05 15:48:53 +08:00
jiang 7efaeb41e8 新增pyclipper依赖 2026-06-05 13:43:53 +08:00
jiang 9a7aad2d36 fix(cli): constrain timeseries option values 2026-06-05 13:43:32 +08:00
jiang b7872f29a9 优化 CLI 命令,增加获取所有节点和管道属性的功能 2026-06-03 17:31:49 +08:00
jiang 233960d8db 明确时间模拟需要 scheme_name 参数 2026-06-03 17:31:44 +08:00
jiang b9410b0ff3 统一前后端时间时区请求 2026-06-03 11:17:37 +08:00
jiang 4982efba5e 更新tjwater-cli network参数;更新metadb health方法 2026-06-03 10:48:01 +08:00
jiang f87dd91b2b 修复--auth-stdin读取失败的bug 2026-06-02 18:41:39 +08:00
jiang c16e6e3d0c 移除 --auth-context,改为 --auth-stdin,结构化传递解析认证信息 2026-06-02 17:17:00 +08:00
jiang 40e699e173 拆分代码;约束cli命令 2026-06-02 14:54:08 +08:00
jiang 9b8a517092 更新文件夹命名 2026-06-02 11:13:07 +08:00
jiang f274cf5122 整理 tjwater-cli 代码和文档 2026-06-02 11:11:56 +08:00
jiang 60db2a7193 优化 cli 命令设计 2026-06-01 17:05:26 +08:00
jiang b72e42521c 优化时间范围查询,添加 UTC 时间标准化处理 2026-06-01 16:46:51 +08:00
jiang c2ccb7bc4e 移除实时数据和仿真结果接口,优化代码结构 2026-05-26 18:49:25 +08:00
jiang 88be97ddeb 修正单元测试失败代码 2026-05-25 17:51:45 +08:00
jiang 2317f4d527 新增 API 测试用例,修复失效接口问题 2026-05-21 15:32:12 +08:00
jiang 751950e5b5 调整函数说明 2026-05-20 11:45:01 +08:00
jiang a1dcbd4230 更新 dockerfile,提高打包效率 2026-04-30 13:06:09 +08:00
jiang 3b712ea467 优化传感器布置算法,修复数据库更新逻辑 2026-04-17 17:21:50 +08:00
jiang bf2aaa5ff7 后端统一时区为 UTC 2026-04-14 14:46:51 +08:00
jiang 51b481d174 优化临时文件管理,增强错误日志记录 2026-04-08 11:47:46 +08:00
jiang 644babf77e 将环境设置为生产模式;更新网络名称配置 2026-04-08 10:49:01 +08:00
jiang 6b09c6b20d 删除 Dockerfile 中的临时文件复制指令 2026-04-03 14:53:55 +08:00
jiang 93cbd7e7b3 独立 copilot 服务 2026-03-27 13:52:12 +08:00
jiang 0196206ed3 创建层级化目录的 skills 2026-03-27 13:05:22 +08:00
jiang 88eec2787b 整理 api tags 2026-03-27 12:31:52 +08:00
jiang 621cd9d2f9 删除 router 中多余的tags 2026-03-26 16:09:17 +08:00
jiang 600ddd329c 添加流式 Copilot 请求处理及审计中间件优化 2026-03-24 16:01:22 +08:00
jiang c184610035 添加 Copilot 聊天流式响应接口及测试 2026-03-24 11:22:00 +08:00
jiang 21dd393aee 添加 Copilot 聊天流式响应功能及相关配置 2026-03-23 18:03:00 +08:00
jiang b0acfb21ec DSN 复用已有连接池 2026-03-19 11:16:36 +08:00
jiang 20ec7d9c8d 添加加密文本输入提示示例 2026-03-19 09:43:22 +08:00
jiang 7c44654195 实现多进程 epanet 模拟,不保留临时文件 2026-03-18 16:56:44 +08:00
jiang c5d3075ae2 添加获取项目信息接口及相关数据模型 2026-03-17 18:27:58 +08:00
jiang 2ea5ce14ba 修正 Dockerfile 中 pip 安装命令的参数,使用 uv 加快部署 2026-03-17 15:56:06 +08:00
jiang adb5dc01fb 强制使用 utf-8 存取 2026-03-16 15:59:09 +08:00
jiang fb9f3217e2 添加清理 pycache 和编译扩展的功能 2026-03-16 15:59:09 +08:00
jiang 5e8600a0a7 添加对单个python文件的编译支持 2026-03-16 15:59:09 +08:00
jiang 1dcaf5ae9f 更新引用路径 2026-03-16 15:59:09 +08:00
jiang a792838e80 移除中文注释,避免 Github Action 工作流出错 2026-03-16 15:59:09 +08:00
jiang 3cd76b9b52 优化打包流程,增加编译路径和忽略规则 2026-03-16 15:59:09 +08:00
jiang e6d00e9bc6 更新构建工作流,删除不必要的安全脚本 2026-03-16 15:59:09 +08:00
jiang 68c12cc4eb 添加构建和打包工作流 2026-03-16 15:59:09 +08:00
jiang e0c247f3b2 更新封装路径 2026-03-16 15:59:09 +08:00
jiang c3bf48499b 删除旧文件 2026-03-16 15:59:09 +08:00
jiang 102cfffefe 更改编译代码的文件名 2026-03-16 15:59:09 +08:00
jiang 1a76c89054 更新metadb引用路径 2026-03-16 15:59:09 +08:00
jiang 1673396e1a 重构时序数据库连接逻辑,移除冗余代码 2026-03-16 15:59:09 +08:00
jiang c137adedad 元数据库目录结构变更 2026-03-16 15:59:09 +08:00
jiang 5041922c84 移除未使用的区域相关函数导入 2026-03-16 15:59:09 +08:00
jiang cfe69e581b 更新API请求体,移除不必要的请求参数 2026-03-16 15:59:09 +08:00
jiang b513d05611 优化API文档,添加参数描述和示例 2026-03-16 15:56:37 +08:00
jiang 9a8d851275 删除网络元素相关的API端点空文件 2026-03-16 15:53:18 +08:00
jiang 50a1e78073 移除旧的InternalQueries类,更新管道查询逻辑 2026-03-16 15:53:18 +08:00
jiang 83a6143146 重构SCADA信息获取,移除旧的数据库接口 2026-03-16 15:53:18 +08:00
Huarch 9aa0646bc6 Merge pull request #2 from OrgTJWater/refactor/app-structure
Refactor/app structure
2026-03-12 18:18:34 +08:00
jiang d34c61a051 更新环境配置,调整数据库用户及密码 2026-03-12 18:15:36 +08:00
jiang baf899eaeb 更新环境配置以控制文档启用状态 2026-03-11 17:57:47 +08:00
jiang 72d642fcf6 删除、移动旧文档 2026-03-11 17:49:37 +08:00
jiang 4ea0b8f05b 为爆管侦测模块新增模拟方案支持及相关参数 2026-03-11 16:15:02 +08:00
jiang aa68bc73ca 固定scikit-learn和scipy版本 2026-03-11 10:51:47 +08:00
jiang bef1c74782 新增爆管侦测功能及相关API接口 2026-03-11 10:31:24 +08:00
jiang 90216a762a 新增uv安装模式 2026-03-11 09:57:29 +08:00
jiang 559d5bb8e3 app/infra/db中 router 迁移并更新,清理 infra 层的旧 router 2026-03-09 18:20:46 +08:00
jiang 7345210bdd 修复引用错误 2026-03-09 18:11:24 +08:00
jiang 0d8a7f5cb7 目录重命名:timescaledb/schemas/ → timescaledb/repositories/ 2026-03-09 18:10:14 +08:00
jiang efeca41cbd 删除旧文件 2026-03-09 17:54:07 +08:00
jiang 8c7d77e6ee 将 from app.services.tjnetwork import * 改为显式导入 2026-03-09 17:51:12 +08:00
jiang c946e1b58b 补充 __init__.py 导出;将 from app.services.tjnetwork import * 改为显式导入;删除以下仅做 @staticmethod 转发的类,保留模块级函数 2026-03-09 17:45:20 +08:00
jiang 0b72ac959a 重构 app/algorithms/api_ex 目录结构 2026-03-09 17:26:39 +08:00
jiang 48f836d667 为预留的空文件夹添加结构功能说明 2026-03-09 16:31:37 +08:00
jiang 6eec6c04de 调整 epanet 从 services 迁到 infra 2026-03-09 16:11:29 +08:00
jiang 61d540356d 删除暂不使用的 mcp 文件夹 2026-03-09 16:06:00 +08:00
jiang eb1d9cce56 调整调用的最大进程数;删除wndb的封装文件 2026-03-09 16:05:31 +08:00
jiang 78978c6931 优化 app/native/wndb/__init__.py 按域分组导入和注释 2026-03-09 15:24:08 +08:00
jiang 747b4cd229 补全 services/tjnetwork.py 的 Facade 覆盖,把绕过的 8 处直接引用都收归到 tjnetwork.py 导出 2026-03-09 14:45:30 +08:00
jiang ed1eb74cfb 将 postgresql_info.py 移出 native/,合并到 core/config.py 或 infra/db/,便于后续项目环境变量读取发生变化 2026-03-09 14:41:50 +08:00
jiang 20ab08e206 将 native/api/ 改名为 native/wndb/,避免与 Web API 层命名冲突 2026-03-09 12:13:27 +08:00
jiang 6b85cfc666 更新文档 2026-03-09 11:30:22 +08:00
jiang a56e041cfc 更新文档 2026-03-09 11:30:05 +08:00
jiang f9111ab9c1 减少爆管定位代码中引入的不确定性 2026-03-09 11:29:57 +08:00
jiang d55e23bc44 把所有 list(set(...)) 改为 sorted(set(...)),确保去重后顺序稳定 2026-03-08 21:05:57 +08:00
jiang b3d58379ef 修复find_new_center_pipe中心点代码错误的bug 2026-03-08 20:45:22 +08:00
jiang 9a4a91c328 重构爆管定位算法,增加多进程支持与可视化功能 2026-03-08 20:01:21 +08:00
jiang a7e3b6aff9 增加 wn_inp_path 参数以支持多进程处理 2026-03-07 15:34:40 +08:00
jiang 05ca940c9f 优化爆管定位算法,增加多进程支持 2026-03-07 15:31:04 +08:00
jiang 0f8d33291d 重构管道中心选择逻辑,优化数据处理方式 2026-03-07 15:23:05 +08:00
jiang 143b918b86 优化压力泄漏标准差计算方式 2026-03-07 15:21:33 +08:00
jiang 7ff28893a1 优化管道权重处理,增加非有限权重检查 2026-03-07 15:11:49 +08:00
jiang b9d9cef5ef 修复管道加权计算逻辑bug,优化邻接关系处理 2026-03-07 15:04:08 +08:00
jiang 0c6c27a0c1 重构监测逻辑,优化 SCADA 数据处理 2026-03-07 15:02:36 +08:00
jiang f5a7e5b3c9 重构爆管定位请求,移除不必要的时间参数 2026-03-07 14:25:23 +08:00
jiang 78a57f5c56 重构爆管定位逻辑,更新实时数据源处理 2026-03-07 13:54:28 +08:00
jiang 7f481ca261 新增模拟数据源支持,重构爆管定位逻辑 2026-03-07 10:50:25 +08:00
jiang bc74e94fbb 重构爆管定位相关功能,优化输入验证与API接口 2026-03-06 16:19:14 +08:00
jiang b83b895e2b 新增爆管位置检测模块及相关API接口 2026-03-06 15:27:59 +08:00
jiang 63d3458fb4 优化漏损识别器,支持多进程评估 2026-03-05 18:18:28 +08:00
jiang b8aee14c00 重构漏损识别请求,添加用户验证和输入准备 2026-03-04 17:23:01 +08:00
jiang 340808e85e 添加审计中间件排除路径、用户按用户名查询功能;完善审计资源记录 2026-03-04 16:06:41 +08:00
jiang 2464c7f612 完善agent-insturction 2026-03-04 16:04:39 +08:00
jiang 61f6975296 完善区域漏损识别 2026-03-04 15:21:31 +08:00
jiang d0abad3c65 使用pymoo实现遗传算法 2026-03-03 16:29:59 +08:00
jiang e7a3aec02f 添加native.api源码;临时处理run_simulation中iot数据库name的判断 2026-03-03 09:47:13 +08:00
jiang 1d662f973a 允许所有来源 2026-02-27 18:27:52 +08:00
jiang 5566172e26 删除env.local;新增漏损区域识别功能 2026-02-27 17:37:39 +08:00
jiang df76e40b0a 更新readme文档 2026-02-27 17:34:09 +08:00
jiang 2e479868f8 修复audit_logs记录时主键缺失的问题 2026-02-27 17:33:58 +08:00
420 changed files with 33094 additions and 19866 deletions
+8 -37
View File
@@ -1,6 +1,7 @@
# TJWater Server 环境变量配置模板
# 复制此文件为 .env 并填写实际值
NETWORK_NAME="szh"
ENVIRONMENT="production"
NETWORK_NAME="tjwater"
# ============================================
# 安全配置 (必填)
# ============================================
@@ -20,17 +21,17 @@ DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
DB_NAME="tjwater"
DB_HOST="localhost"
DB_PORT="5432"
DB_USER="postgres"
DB_USER="tjwater"
DB_PASSWORD="password"
# ============================================
# 数据库配置 (TimescaleDB)
# ============================================
TIMESCALEDB_DB_NAME="szh"
TIMESCALEDB_DB_NAME="tjwater"
TIMESCALEDB_DB_HOST="localhost"
TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
TIMESCALEDB_DB_PASSWORD="password"
# ============================================
# 元数据数据库配置 (Metadata DB)
@@ -41,39 +42,9 @@ METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="password"
# ============================================
# 项目连接缓存与连接池配置
# ============================================
PROJECT_PG_CACHE_SIZE=50
PROJECT_TS_CACHE_SIZE=50
PROJECT_PG_POOL_SIZE=5
PROJECT_PG_MAX_OVERFLOW=10
PROJECT_TS_POOL_MIN_SIZE=1
PROJECT_TS_POOL_MAX_SIZE=10
# ============================================
# InfluxDB 配置 (时序数据)
# ============================================
# INFLUXDB_URL=http://localhost:8086
# INFLUXDB_TOKEN=your-influxdb-token
# INFLUXDB_ORG=your-org
# INFLUXDB_BUCKET=tjwater
# ============================================
# JWT 配置 (可选)
# ============================================
# ACCESS_TOKEN_EXPIRE_MINUTES=30
# REFRESH_TOKEN_EXPIRE_DAYS=7
# ALGORITHM=HS256
# ============================================
# Keycloak JWT (可选)
# ============================================
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
# KEYCLOAK_ALGORITHM=RS256
# ============================================
# 其他配置
# ============================================
# PROJECT_NAME=TJWater Server
# API_V1_STR=/api/v1
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
KEYCLOAK_ALGORITHM=RS256
KEYCLOAK_AUDIENCE="account"
-23
View File
@@ -1,23 +0,0 @@
NETWORK_NAME="tjwater"
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
KEYCLOAK_ALGORITHM="RS256"
KEYCLOAK_AUDIENCE="account"
DB_NAME="tjwater"
DB_HOST="192.168.1.114"
DB_PORT="5432"
DB_USER="tjwater"
DB_PASSWORD="Tjwater@123456"
TIMESCALEDB_DB_NAME="tjwater"
TIMESCALEDB_DB_HOST="192.168.1.114"
TIMESCALEDB_DB_PORT="5433"
TIMESCALEDB_DB_USER="tjwater"
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
METADATA_DB_NAME="system_hub"
METADATA_DB_HOST="192.168.1.114"
METADATA_DB_PORT="5432"
METADATA_DB_USER="tjwater"
METADATA_DB_PASSWORD="Tjwater@123456"
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
+63 -179
View File
@@ -1,198 +1,82 @@
# TJWater Server - Copilot Instructions
# Copilot Instructions for TJWater Server
This is a FastAPI-based water network management system (供水管网智能管理系统) that provides hydraulic simulation, SCADA data integration, network element management, and risk analysis capabilities.
This repository contains the backend code for the TJWater Server, a water distribution network management system built with FastAPI.
## Running the Server
## High-Level Architecture
The application follows a layered architecture:
- **Entry Point**: `app/main.py` initializes the FastAPI application, database connections (PostgreSQL & TimescaleDB), and middleware.
- **API Layer**: `app/api/v1` contains the route handlers.
- **Service Layer**: `app/services` contains business logic and orchestration.
- **Infrastructure Layer**: `app/infra` handles database connections (`db`), audit logging (`audit`), and external integrations.
- **Domain Layer**: `app/domain` likely contains core domain models.
- **Native/Algorithms**: `app/native` and `app/algorithms` handle specialized water network calculations (possibly using EPANET/WNTR).
## Build, Test, and Run Commands
### Environment Setup
- Dependencies are listed in `requirements.txt`.
- Configuration is managed via environment variables (see `.env.example` if available, or `app/core/config.py`).
- **Important**: Ensure `.env` is configured with correct database credentials for both PostgreSQL and TimescaleDB.
If first time setting up, you may want to create a Conda environment:
```bash
# Install dependencies
pip install -r requirements.txt
# Start the server (default: http://0.0.0.0:8000 with 2 workers)
python scripts/run_server.py
# Note: On Windows, the script automatically sets WindowsSelectorEventLoopPolicy
conda create -n server python=3.12
conda activate server
pip install uv
uv pip install -r requirements.txt
conda install -c conda-forge pymetis
```
## Running Tests
### Running the Server
The preferred way to run the server locally is using the helper script which sets up the Python path correctly:
```bash
conda activate server
python scripts/run_server.py
```
Alternatively, you can run directly with uvicorn (ensure PYTHONPATH includes the root):
```bash
conda activate server
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```
### Running Tests
Use `pytest` to run tests. The `tests/conftest.py` handles path setup.
```bash
# Run all tests
pytest
# Run a specific test file with verbose output
pytest tests/unit/test_pipeline_health_analyzer.py -v
# Run a specific test file
pytest tests/unit/test_specific_file.py
# Run from conftest helper
python tests/conftest.py
# Run a specific test case
pytest tests/unit/test_specific_file.py::test_function_name
```
## Architecture Overview
### Building (Optional)
### Core Components
The project includes scripts to compile Python modules to `.pyd` files using Cython (see `scripts/build_pyd.py`). This is likely for distribution/performance but not required for standard development.
1. **Native Modules** (`app/native/`): Platform-specific compiled extensions (`.so` for Linux, `.pyd` for Windows) providing performance-critical functionality including:
- SCADA device integration
- Water distribution analysis (WDA)
- Pipe risk probability calculations
- Wrapped through `app.services.tjnetwork` interface
## Key Conventions
2. **Services Layer** (`app/services/`):
- `tjnetwork.py`: Main network API wrapper around native modules
- `simulation.py`: Hydraulic simulation orchestration (EPANET integration)
- `project_info.py`: Project configuration management
- `epanet/`: EPANET hydraulic engine integration
- **Async/Await**: The codebase heavily uses `async` and `await` for I/O operations, especially database interactions.
- **Database Management**:
- Connections are managed globally in `app.infra.db` and initialized in `lifespan` (app/main.py).
- Use `app.infra.db.dynamic_manager` for project-specific database connections (multi-tenancy/dynamic projects).
- **Pydantic**: extensively used for data validation and settings management.
- **Scripts**: The `scripts/` directory contains many utility scripts for maintenance, data processing, and server management. Check there before writing new operational scripts.
- **Water Network Modeling**: Interactions with water network models often involve `epanet` or `wntr` libraries. Be aware of domain-specific terminology (nodes, links, junctions, tanks).
3. **API Layer** (`app/api/v1/`):
- **Network Elements**: Separate endpoint modules for junctions, reservoirs, tanks, pipes, pumps, valves
- **Components**: Curves, patterns, controls, options, quality, visuals
- **Network Features**: Tags, demands, geometry, regions/DMAs
- **Core Services**: Auth, project, simulation, SCADA, data query, snapshots
## Code Style
4. **Database Infrastructure** (`app/infra/db/`):
- **PostgreSQL**: Primary relational database (users, audit logs, project metadata)
- **TimescaleDB**: Time-series extension for historical data
- **InfluxDB**: Optional time-series database for high-frequency SCADA data
- Connection pools initialized in `main.py` lifespan context
- Database instance stored in `app.state.db` for dependency injection
5. **Domain Layer** (`app/domain/`):
- `models/`: Enums and domain objects (e.g., `UserRole`)
- `schemas/`: Pydantic models for request/response validation
6. **Algorithms** (`app/algorithms/`):
- `api_ex/`: Analysis algorithms (k-means sensor placement, sensitivity analysis, pipeline health)
- `data_cleaning.py`: Data preprocessing utilities
- `simulations.py`: Simulation helpers
### Security & Authentication
- **Authentication**: JWT-based with access tokens (30 min) and refresh tokens (7 days)
- **Authorization**: Role-based access control (RBAC) with 4 roles:
- `VIEWER`: Read-only access
- `USER`: Read-write access
- `OPERATOR`: Modify data
- `ADMIN`: Full permissions
- **Audit Logging**: `AuditMiddleware` automatically logs POST/PUT/DELETE requests
- **Encryption**: Fernet symmetric encryption for sensitive data (`app.core.encryption`)
Default admin accounts:
- `admin` / `admin123`
- `tjwater` / `tjwater@123`
### Key Files
- `app/main.py`: FastAPI app initialization, lifespan (DB pools), CORS, middleware, router mounting
- `app/api/v1/router.py`: Central router aggregating all endpoint modules
- `app/core/config.py`: Settings management using `pydantic-settings`
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
## Important Conventions
### Database Connections
- Database instances are initialized in `main.py` lifespan and stored in `app.state.db`
- Access via dependency injection:
```python
from app.auth.dependencies import get_db
async def endpoint(db = Depends(get_db)):
# Use db connection
```
### Authentication in Endpoints
Use dependency injection for auth requirements:
```python
from app.auth.dependencies import get_current_active_user
from app.auth.permissions import require_role, get_current_admin
from app.domain.models.role import UserRole
# Require any authenticated user
@router.get("/data")
async def get_data(current_user = Depends(get_current_active_user)):
return data
# Require specific role (USER or higher)
@router.post("/data")
async def create_data(current_user = Depends(require_role(UserRole.USER))):
return result
# Admin-only access
@router.delete("/data/{id}")
async def delete_data(id: int, current_user = Depends(get_current_admin)):
return result
```
### API Routing Structure
- All v1 APIs are mounted under `/api/v1` prefix via `api_router`
- Legacy routes without version prefix are also mounted for backward compatibility
- Group related endpoints in separate router modules under `app/api/v1/endpoints/`
- Use descriptive tags in `router.py` for OpenAPI documentation grouping
### Native Module Integration
- Native modules are pre-compiled for specific platforms
- Always import through `app.native.api` or `app.services.tjnetwork`
- The `tjnetwork` service wraps native APIs with constants like:
- Element types: `JUNCTION`, `RESERVOIR`, `TANK`, `PIPE`, `PUMP`, `VALVE`
- Operations: `API_ADD`, `API_UPDATE`, `API_DELETE`
- `ChangeSet` for batch operations
### Project Initialization
- On startup, `main.py` automatically loads project from `project_info.name` if set
- Projects are opened via `open_project(name)` from `tjnetwork` service
### Audit Logging
Manual audit logging for critical operations:
```python
from app.core.audit import log_audit_event, AuditAction
await log_audit_event(
action=AuditAction.UPDATE,
user_id=current_user.id,
username=current_user.username,
resource_type="resource_name",
resource_id=str(resource_id),
ip_address=request.client.host,
request_data=data
)
```
### Environment Configuration
- Copy `.env.example` to `.env` before first run
- Required environment variables:
- `SECRET_KEY`: JWT signing (generate with `openssl rand -hex 32`)
- `ENCRYPTION_KEY`: Data encryption (generate with Fernet)
- Database credentials for PostgreSQL, TimescaleDB, and optionally InfluxDB
### Database Migrations
SQL migration scripts are in `migrations/`:
- `001_create_users_table.sql`: User authentication tables
- `002_create_audit_logs_table.sql`: Audit logging tables
Apply with:
```bash
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
```
## API Documentation
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
- OpenAPI schema: http://localhost:8000/openapi.json
## Additional Resources
- `SECURITY_README.md`: Comprehensive security feature documentation
- `DEPLOYMENT.md`: Integration guide for security features
- `readme.md`: Project overview and directory structure (in Chinese)
- Follow standard PEP 8 guidelines.
- No specific linter configuration was found, so default to standard Python formatting.
+128
View File
@@ -0,0 +1,128 @@
name: Build And Package
on:
push:
tags:
- "v*"
jobs:
build-package:
runs-on: ${{ matrix.os }}
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
steps:
- name: Checkout source
uses: actions/checkout@v5
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Install system build tools
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install compile dependencies
run: |
python -m pip install --upgrade pip
pip install cython setuptools wheel
- name: Run Cython compile
run: |
python scripts/compile.py
- name: Prepare package and archive
run: |
python - <<'PY'
import os
import shutil
import tarfile
import zipfile
import sys
from pathlib import Path
root = Path.cwd()
package_dir = root / "package"
dist_dir = root / "dist"
for d in [package_dir, dist_dir]:
if d.exists():
shutil.rmtree(d)
d.mkdir(parents=True, exist_ok=True)
# Define directories with compiled artifacts
compile_dirs = ["app/services", "app/native/wndb", "app/algorithms"]
# Global ignore list
ignore_names = {
".git",
".github",
"__pycache__",
".pytest_cache",
".mypy_cache",
".venv",
"venv",
"temp",
"tests",
"package",
"dist",
}
def ignore_func(directory, names):
rel_dir = os.path.relpath(directory, root).replace("\\", "/")
is_in_compile_path = any(rel_dir.startswith(d) for d in compile_dirs)
ignored = []
for name in names:
if name in ignore_names or name.endswith(".pyc"):
ignored.append(name)
# Exclude source .py files only in compiled directories
elif is_in_compile_path and name.endswith(".py"):
ignored.append(name)
return ignored
for item in root.iterdir():
if item.name in ignore_names:
continue
target = package_dir / item.name
if item.is_dir():
shutil.copytree(item, target, ignore=ignore_func)
else:
shutil.copy2(item, target)
# Safety guard: ensure no .github directory remains
github_paths = [p for p in package_dir.rglob(".github") if p.is_dir()]
for p in github_paths:
shutil.rmtree(p, ignore_errors=True)
sha = os.environ["GITHUB_SHA"]
run_os = os.environ["RUNNER_OS"].lower()
if run_os == "windows":
archive_path = dist_dir / f"tjwater-server-{run_os}-{sha}.zip"
with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
for f in package_dir.rglob("*"):
if f.is_file():
zf.write(f, f.relative_to(package_dir))
else:
archive_path = dist_dir / f"tjwater-server-{run_os}-{sha}.tar.gz"
with tarfile.open(archive_path, "w:gz") as tf:
tf.add(package_dir, arcname=".")
print(f"Archive created: {archive_path}")
PY
shell: bash
- name: Upload package artifact
uses: actions/upload-artifact@v5
with:
name: tjwater-server-package-${{ runner.os }}
path: dist/*
retention-days: 14
+2 -1
View File
@@ -5,5 +5,6 @@ build/
*.pyc
.env
*.dump
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
.vscode/
app/algorithms/health/model/my_survival_forest_model_quxi.joblib
inp/
-391
View File
@@ -1,391 +0,0 @@
# 部署和集成指南
本文档说明如何将新的安全功能集成到现有系统中。
## 📦 已完成的功能
### 1. 数据加密模块
-`app/core/encryption.py` - Fernet 对称加密实现
- ✅ 支持敏感数据加密/解密
- ✅ 密钥管理和生成工具
### 2. 用户认证系统
-`app/domain/models/role.py` - 用户角色枚举 (ADMIN/OPERATOR/USER/VIEWER)
-`app/domain/schemas/user.py` - 用户数据模型和验证
-`app/infra/repositories/user_repository.py` - 用户数据访问层
-`app/api/v1/endpoints/auth.py` - 注册/登录/刷新Token接口
-`app/auth/dependencies.py` - 认证依赖项
-`migrations/001_create_users_table.sql` - 用户表迁移脚本
### 3. 权限控制系统
-`app/auth/permissions.py` - RBAC 权限控制装饰器
-`app/api/v1/endpoints/user_management.py` - 用户管理接口示例
- ✅ 支持基于角色的访问控制
- ✅ 支持资源所有者检查
### 4. 审计日志系统
-`app/core/audit.py` - 审计日志核心功能
-`app/domain/schemas/audit.py` - 审计日志数据模型
-`app/infra/repositories/audit_repository.py` - 审计日志数据访问层
-`app/api/v1/endpoints/audit.py` - 审计日志查询接口
-`app/infra/audit/middleware.py` - 自动审计中间件
-`migrations/002_create_audit_logs_table.sql` - 审计日志表迁移脚本
### 5. 文档和测试
-`SECURITY_README.md` - 完整的使用文档
-`.env.example` - 环境变量配置模板
-`tests/test_encryption.py` - 加密功能测试
---
## 🔧 集成步骤
### 步骤 1: 环境配置
1. 复制环境变量模板:
```bash
cp .env.example .env
```
2. 生成密钥并填写 `.env`
```bash
# JWT 密钥
openssl rand -hex 32
# 加密密钥
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
```
3. 编辑 `.env` 填写所有必需的配置项。
### 步骤 2: 数据库迁移
执行数据库迁移脚本:
```bash
# 方法 1: 使用 psql 命令
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
# 方法 2: 在 psql 交互界面
psql -U postgres -d tjwater
\i migrations/001_create_users_table.sql
\i migrations/002_create_audit_logs_table.sql
```
验证表已创建:
```sql
-- 检查用户表
SELECT * FROM users;
-- 检查审计日志表
SELECT * FROM audit_logs;
```
### 步骤 3: 更新 main.py
`app/main.py` 中集成新功能:
```python
from fastapi import FastAPI
from app.core.config import settings
from app.infra.audit.middleware import AuditMiddleware
app = FastAPI(title=settings.PROJECT_NAME)
# 1. 添加审计中间件(可选)
app.add_middleware(AuditMiddleware)
# 2. 注册路由
from app.api.v1.endpoints import auth, user_management, audit
app.include_router(
auth.router,
prefix=f"{settings.API_V1_STR}/auth",
tags=["认证"]
)
app.include_router(
user_management.router,
prefix=f"{settings.API_V1_STR}/users",
tags=["用户管理"]
)
app.include_router(
audit.router,
prefix=f"{settings.API_V1_STR}/audit",
tags=["审计日志"]
)
# 3. 确保数据库在启动时初始化
@app.on_event("startup")
async def startup_event():
# 初始化数据库连接池
from app.infra.db.postgresql.database import Database
global db
db = Database()
db.init_pool()
await db.open()
@app.on_event("shutdown")
async def shutdown_event():
# 关闭数据库连接
await db.close()
```
### 步骤 4: 保护现有接口
#### 方法 1: 为路由添加全局依赖
```python
from app.auth.dependencies import get_current_active_user
# 为整个路由器添加认证
router = APIRouter(dependencies=[Depends(get_current_active_user)])
```
#### 方法 2: 为单个端点添加依赖
```python
from app.auth.permissions import require_role, get_current_admin
from app.domain.models.role import UserRole
@router.get("/data")
async def get_data(
current_user = Depends(require_role(UserRole.USER))
):
"""需要 USER 及以上角色"""
return {"data": "protected"}
@router.delete("/data/{id}")
async def delete_data(
id: int,
current_user = Depends(get_current_admin)
):
"""仅管理员可访问"""
return {"message": "deleted"}
```
### 步骤 5: 添加审计日志
#### 自动审计(推荐)
使用中间件自动记录(已在 main.py 中添加):
```python
app.add_middleware(AuditMiddleware)
```
#### 手动审计
在关键业务逻辑中手动记录:
```python
from app.core.audit import log_audit_event, AuditAction
@router.post("/important-action")
async def important_action(
data: dict,
request: Request,
current_user = Depends(get_current_active_user)
):
# 执行业务逻辑
result = do_something(data)
# 记录审计日志
await log_audit_event(
action=AuditAction.UPDATE,
user_id=current_user.id,
username=current_user.username,
resource_type="important_resource",
resource_id=str(result.id),
ip_address=request.client.host,
request_data=data
)
return result
```
### 步骤 6: 更新 auth/dependencies.py
确保 `get_db()` 函数正确获取数据库实例:
```python
async def get_db() -> Database:
"""获取数据库实例"""
# 方法 1: 从 main.py 导入
from app.main import db
return db
# 方法 2: 从 FastAPI app.state 获取
# from fastapi import Request
# def get_db_from_request(request: Request):
# return request.app.state.db
```
---
## 🧪 测试
### 1. 测试加密功能
```bash
python tests/test_encryption.py
```
### 2. 测试 API
启动服务器:
```bash
uvicorn app.main:app --reload
```
访问交互式文档:
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
### 3. 测试登录
```bash
curl -X POST "http://localhost:8000/api/v1/auth/login" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "username=admin&password=admin123"
```
### 4. 测试受保护接口
```bash
TOKEN="your-access-token"
curl -X GET "http://localhost:8000/api/v1/auth/me" \
-H "Authorization: Bearer $TOKEN"
```
---
## 🔄 迁移现有接口
### 原有硬编码认证
**旧代码** (`app/api/v1/endpoints/auth.py`):
```python
AUTH_TOKEN = "567e33c876a2"
async def verify_token(authorization: str = Header()):
token = authorization.split(" ")[1]
if token != AUTH_TOKEN:
raise HTTPException(status_code=403)
```
**新代码** (已更新):
```python
from app.auth.dependencies import get_current_active_user
@router.get("/protected")
async def protected_route(
current_user = Depends(get_current_active_user)
):
return {"user": current_user.username}
```
### 更新其他端点
搜索项目中使用旧认证的地方:
```bash
grep -r "AUTH_TOKEN" app/
grep -r "verify_token" app/
```
替换为新的依赖注入系统。
---
## 📋 检查清单
部署前检查:
- [ ] 环境变量已配置(`.env`
- [ ] 数据库迁移已执行
- [ ] 默认管理员账号可登录
- [ ] JWT Token 可正常生成和验证
- [ ] 权限控制正常工作
- [ ] 审计日志正常记录
- [ ] 加密功能测试通过
- [ ] API 文档可访问
---
## ⚠️ 注意事项
### 1. 向后兼容性
保留了简化版登录接口 `/auth/login/simple` 以兼容旧客户端:
```python
@router.post("/login/simple")
async def login_simple(username: str, password: str):
# 验证并返回 Token
...
```
### 2. 数据库连接
确保在 `app/auth/dependencies.py``get_db()` 函数能正确获取数据库实例。
### 3. 密钥安全
- ❌ 不要提交 `.env` 文件到版本控制
- ✅ 在生产环境使用环境变量或密钥管理服务
- ✅ 定期轮换 JWT 密钥
### 4. 性能考虑
- 审计中间件会增加每个请求的处理时间(约 5-10ms)
- 对高频接口可考虑异步记录审计日志
- 定期清理或归档旧的审计日志
---
## 🐛 故障排查
### 问题 1: 导入错误
```
ImportError: cannot import name 'db' from 'app.main'
```
**解决**: 确保在 `app/main.py` 中定义了全局 `db` 对象。
### 问题 2: 认证失败
```
401 Unauthorized: Could not validate credentials
```
**检查**:
1. Token 是否正确设置在 `Authorization: Bearer {token}` header
2. Token 是否过期
3. SECRET_KEY 是否配置正确
### 问题 3: 数据库连接失败
```
psycopg.OperationalError: connection failed
```
**检查**:
1. PostgreSQL 是否运行
2. `.env` 中数据库配置是否正确
3. 数据库是否存在
---
## 📞 技术支持
详细文档请参考:
- `SECURITY_README.md` - 安全功能使用指南
- `migrations/` - 数据库迁移脚本
- `app/domain/schemas/` - 数据模型定义
+5 -5
View File
@@ -1,19 +1,19 @@
FROM continuumio/miniconda3:latest
FROM condaforge/miniforge3:latest
WORKDIR /app
# 安装 Python 3.12 和 pymetis (通过 conda-forge 避免编译问题)
RUN conda install -y -c conda-forge python=3.12 pymetis && \
conda clean -afy
RUN mamba install -y python=3.12 pymetis && \
mamba clean -afy
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install uv
RUN uv pip install --system --no-cache-dir -r requirements.txt
# 将代码放入子目录 'app',将数据放入子目录 'db_inp'
# 这样临时文件默认会生成在 /app 下,而代码在 /app/app 下,实现了分离
COPY app ./app
COPY db_inp ./db_inp
COPY temp ./temp
COPY .env .
# 设置 PYTHONPATH 以便 uvicorn 找到 app 模块
-322
View File
@@ -1,322 +0,0 @@
# API 集成检查清单
## ✅ 已完成的集成工作
### 1. 路由集成 (app/api/v1/router.py)
已添加以下路由到 API Router
```python
# 新增导入
from app.api.v1.endpoints import (
...
user_management, # 用户管理
audit, # 审计日志
)
# 新增路由
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"])
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"])
```
**路由端点**
- `/api/v1/auth/` - 认证相关(register, login, me, refresh
- `/api/v1/users/` - 用户管理(CRUD操作,仅管理员)
- `/api/v1/audit/` - 审计日志查询(仅管理员)
### 2. 主应用配置 (app/main.py)
#### 2.1 导入更新
```python
from app.core.config import settings
from app.infra.audit.middleware import AuditMiddleware
```
#### 2.2 数据库初始化
```python
# 在 lifespan 中存储数据库实例到 app.state
app.state.db = pgdb
```
#### 2.3 FastAPI 配置
```python
app = FastAPI(
lifespan=lifespan,
title=settings.PROJECT_NAME,
description="TJWater Server - 供水管网智能管理系统",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
)
```
#### 2.4 审计中间件(可选)
```python
# 取消注释以启用审计日志
# app.add_middleware(AuditMiddleware)
```
### 3. 依赖项更新 (app/auth/dependencies.py)
更新 `get_db()` 函数从 Request 对象获取数据库:
```python
async def get_db(request: Request) -> Database:
"""从 app.state 获取数据库实例"""
if not hasattr(request.app.state, "db"):
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database not initialized"
)
return request.app.state.db
```
### 4. 审计日志更新
- `app/api/v1/endpoints/audit.py` - 使用正确的数据库依赖
- `app/core/audit.py` - 接受可选的 db 参数
---
## 📋 部署前检查清单
### 环境配置
- [ ] 复制 `.env.example``.env`
- [ ] 配置 `SECRET_KEY`JWT密钥)
- [ ] 配置 `ENCRYPTION_KEY`(数据加密密钥)
- [ ] 配置数据库连接信息
### 数据库迁移
- [ ] 执行用户表迁移:`psql -U postgres -d tjwater -f migrations/001_create_users_table.sql`
- [ ] 执行审计日志表迁移:`psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql`
- [ ] 验证表已创建:`\dt` 在 psql 中
### 依赖检查
- [ ] 确认已安装:`cryptography`
- [ ] 确认已安装:`python-jose[cryptography]`
- [ ] 确认已安装:`passlib[bcrypt]`
- [ ] 确认已安装:`email-validator`(用于 Pydantic email 验证)
### 代码验证
- [ ] 检查所有文件导入正常
- [ ] 运行加密功能测试:`python tests/test_encryption.py`
- [ ] 启动服务器:`uvicorn app.main:app --reload`
- [ ] 访问 API 文档:http://localhost:8000/docs
### API 测试
- [ ] 测试登录:POST `/api/v1/auth/login`
- [ ] 测试获取当前用户:GET `/api/v1/auth/me`
- [ ] 测试用户列表(需管理员):GET `/api/v1/users/`
- [ ] 测试审计日志(需管理员):GET `/api/v1/audit/logs`
---
## 🔧 快速测试命令
### 1. 生成密钥
```bash
# JWT 密钥
openssl rand -hex 32
# 加密密钥
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
```
### 2. 执行迁移
```bash
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
```
### 3. 测试加密
```bash
python tests/test_encryption.py
```
### 4. 启动服务器
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
### 5. 测试登录 API
```bash
# 使用默认管理员账号
curl -X POST "http://localhost:8000/api/v1/auth/login" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "username=admin&password=admin123"
# 或使用迁移的账号
curl -X POST "http://localhost:8000/api/v1/auth/login" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "username=tjwater&password=tjwater@123"
```
### 6. 测试受保护接口
```bash
# 保存 Token
TOKEN="<从登录响应中获取的 access_token>"
# 获取当前用户信息
curl -X GET "http://localhost:8000/api/v1/auth/me" \
-H "Authorization: Bearer $TOKEN"
# 获取用户列表(需管理员权限)
curl -X GET "http://localhost:8000/api/v1/users/" \
-H "Authorization: Bearer $TOKEN"
# 查询审计日志(需管理员权限)
curl -X GET "http://localhost:8000/api/v1/audit/logs" \
-H "Authorization: Bearer $TOKEN"
```
---
## 📚 API 端点总览
### 认证接口 (`/api/v1/auth`)
| 方法 | 端点 | 描述 | 权限 |
|------|------|------|------|
| POST | `/register` | 用户注册 | 公开 |
| POST | `/login` | OAuth2 登录 | 公开 |
| POST | `/login/simple` | 简化登录(兼容旧版) | 公开 |
| GET | `/me` | 获取当前用户信息 | 认证用户 |
| POST | `/refresh` | 刷新 Token | 认证用户 |
### 用户管理 (`/api/v1/users`)
| 方法 | 端点 | 描述 | 权限 |
|------|------|------|------|
| GET | `/` | 获取用户列表 | 管理员 |
| GET | `/{id}` | 获取用户详情 | 所有者/管理员 |
| PUT | `/{id}` | 更新用户信息 | 所有者/管理员 |
| DELETE | `/{id}` | 删除用户 | 管理员 |
| POST | `/{id}/activate` | 激活用户 | 管理员 |
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
### 审计日志 (`/api/v1/audit`)
| 方法 | 端点 | 描述 | 权限 |
|------|------|------|------|
| GET | `/logs` | 查询审计日志 | 管理员 |
| GET | `/logs/count` | 获取日志总数 | 管理员 |
| GET | `/logs/my` | 查看我的操作记录 | 认证用户 |
---
## ⚠️ 注意事项
### 1. 审计中间件
审计中间件默认是**禁用**的。如需启用,在 `app/main.py` 中取消注释:
```python
app.add_middleware(AuditMiddleware)
```
**注意**:启用后会自动记录所有 POST/PUT/DELETE 请求,可能增加数据库负载。
### 2. 向后兼容
保留了原有的简化登录接口 `/auth/login/simple`,可以直接使用查询参数:
```bash
POST /api/v1/auth/login/simple?username=admin&password=admin123
```
### 3. 数据库连接
确保数据库实例在应用启动时正确初始化并存储到 `app.state.db`
### 4. 权限控制示例
为现有接口添加权限控制:
```python
from app.auth.permissions import require_role, get_current_admin
from app.domain.models.role import UserRole
# 需要管理员权限
@router.delete("/resource/{id}")
async def delete_resource(
id: int,
current_user = Depends(get_current_admin)
):
...
# 需要操作员以上权限
@router.post("/resource")
async def create_resource(
data: dict,
current_user = Depends(require_role(UserRole.OPERATOR))
):
...
```
---
## 🚀 完整启动流程
```bash
# 1. 进入项目目录
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
# 2. 配置环境变量(如果还没有)
cp .env.example .env
# 编辑 .env 填写必要的配置
# 3. 执行数据库迁移(如果还没有)
psql -U postgres -d tjwater < migrations/001_create_users_table.sql
psql -U postgres -d tjwater < migrations/002_create_audit_logs_table.sql
# 4. 测试加密功能
python tests/test_encryption.py
# 5. 启动服务器
uvicorn app.main:app --reload
# 6. 访问 API 文档
# 浏览器打开: http://localhost:8000/docs
```
---
## 📞 故障排查
### 问题 1: 导入错误
```
ModuleNotFoundError: No module named 'jose'
```
**解决**: 安装依赖 `pip install python-jose[cryptography]`
### 问题 2: 数据库未初始化
```
503 Service Unavailable: Database not initialized
```
**解决**: 检查 `main.py` 中的 lifespan 函数是否正确设置 `app.state.db`
### 问题 3: Token 验证失败
```
401 Unauthorized: Could not validate credentials
```
**解决**:
1. 检查 SECRET_KEY 是否配置正确
2. 确认 Token 格式:`Authorization: Bearer {token}`
3. 检查 Token 是否过期
### 问题 4: 表不存在
```
relation "users" does not exist
```
**解决**: 执行数据库迁移脚本
---
## 📖 相关文档
- **使用指南**: `SECURITY_README.md`
- **部署指南**: `DEPLOYMENT.md`
- **实施总结**: `SECURITY_IMPLEMENTATION_SUMMARY.md`
- **自动设置**: `setup_security.sh`
---
**最后更新**: 2026-02-02
**状态**: ✅ API 已完全集成
-370
View File
@@ -1,370 +0,0 @@
# 安全功能实施总结
## ✅ 已完成的功能
本次实施完成了完整的安全体系,包括数据加密、身份认证、权限管理、审计日志四大模块。
---
## 📁 新增文件清单
### 核心功能模块
1. **数据加密**
- `app/core/encryption.py` - Fernet 加密实现
- `tests/test_encryption.py` - 加密功能测试
2. **用户系统**
- `app/domain/models/role.py` - 用户角色枚举
- `app/domain/schemas/user.py` - 用户数据模型
- `app/infra/repositories/user_repository.py` - 用户数据访问层
3. **认证授权**
- `app/api/v1/endpoints/auth.py` - 认证接口(已重构)
- `app/auth/dependencies.py` - 认证依赖项(已更新)
- `app/auth/permissions.py` - 权限控制装饰器
- `app/api/v1/endpoints/user_management.py` - 用户管理接口
4. **审计日志**
- `app/core/audit.py` - 审计日志核心(已完善)
- `app/domain/schemas/audit.py` - 审计日志数据模型
- `app/infra/repositories/audit_repository.py` - 审计日志数据访问层
- `app/api/v1/endpoints/audit.py` - 审计日志查询接口
- `app/infra/audit/middleware.py` - 自动审计中间件
### 数据库迁移
5. **迁移脚本**
- `migrations/001_create_users_table.sql` - 用户表
- `migrations/002_create_audit_logs_table.sql` - 审计日志表
### 配置和文档
6. **配置文件**
- `.env.example` - 环境变量模板
- `app/core/config.py` - 配置文件(已更新)
- `app/core/security.py` - 安全工具(已增强)
7. **文档**
- `SECURITY_README.md` - 完整使用指南(79KB+
- `DEPLOYMENT.md` - 部署和集成指南
- `SECURITY_IMPLEMENTATION_SUMMARY.md` - 本文件
8. **工具**
- `setup_security.sh` - 快速设置脚本
---
## 🎯 功能特性
### 1. 数据加密
- ✅ 使用 FernetAES-128)对称加密
- ✅ 支持密钥生成和管理
- ✅ 自动从环境变量读取密钥
- ✅ 完整的加密/解密 API
- ✅ 单元测试覆盖
### 2. 身份认证
- ✅ 基于 JWT 的 Token 认证
- ✅ Access Token + Refresh Token 机制
- ✅ 用户注册/登录接口
- ✅ 支持用户名或邮箱登录
- ✅ 密码使用 bcrypt 哈希存储
- ✅ Token 过期时间可配置
- ✅ 向后兼容旧接口
### 3. 权限管理(RBAC
- ✅ 4 个预定义角色:ADMIN, OPERATOR, USER, VIEWER
- ✅ 基于角色层级的权限检查
- ✅ 可复用的权限装饰器
- ✅ 资源所有者检查
- ✅ 灵活的依赖注入设计
### 4. 审计日志
- ✅ 自动记录所有关键操作
- ✅ 记录用户、时间、操作类型、资源等信息
- ✅ 敏感数据自动脱敏
- ✅ 支持按多条件查询
- ✅ 管理员专用查询接口
- ✅ 用户可查看自己的操作记录
---
## 📊 技术栈
| 组件 | 技术 | 说明 |
|------|------|------|
| 加密 | cryptography.Fernet | 对称加密 |
| 密码哈希 | bcrypt | 密码安全存储 |
| JWT | python-jose | Token 生成和验证 |
| 数据库 | PostgreSQL + psycopg | 异步数据访问 |
| Web框架 | FastAPI | 现代异步框架 |
| 数据验证 | Pydantic | 类型安全的数据模型 |
---
## 🔐 安全特性
1. **密码安全**
- bcrypt 哈希(work factor = 12
- 自动加盐
- 不可逆加密
2. **Token 安全**
- JWT 签名验证
- 短期 Access Token30分钟)
- 长期 Refresh Token7天)
- Token 类型校验
3. **数据保护**
- 敏感字段自动脱敏
- 审计日志不记录密码
- 加密密钥从环境变量读取
4. **访问控制**
- 基于角色的细粒度权限
- 资源级别的访问控制
- 自动验证用户激活状态
---
## 📈 数据库设计
### users 表
```
用户表 - 存储系统用户
- id (主键)
- username (唯一)
- email (唯一)
- hashed_password
- role (ADMIN/OPERATOR/USER/VIEWER)
- is_active
- is_superuser
- created_at
- updated_at (自动更新)
```
### audit_logs 表
```
审计日志表 - 记录所有关键操作
- id (主键)
- user_id (外键)
- username (冗余字段)
- action (操作类型)
- resource_type (资源类型)
- resource_id (资源ID)
- ip_address
- user_agent
- request_method
- request_path
- request_data (JSONB)
- response_status
- error_message
- timestamp
```
**索引优化**
- users: username, email, role, is_active
- audit_logs: user_id, username, timestamp, action, resource
---
## 🚀 快速开始
### 方法 1: 使用自动化脚本
```bash
./setup_security.sh
```
### 方法 2: 手动设置
```bash
# 1. 配置环境变量
cp .env.example .env
# 编辑 .env 填写密钥和数据库配置
# 2. 执行数据库迁移
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
# 3. 测试
python tests/test_encryption.py
# 4. 启动服务
uvicorn app.main:app --reload
```
---
## 📋 集成检查清单
### 必需步骤
- [ ] 复制 `.env.example``.env` 并配置
- [ ] 生成 JWT 密钥(SECRET_KEY
- [ ] 生成加密密钥(ENCRYPTION_KEY
- [ ] 配置数据库连接信息
- [ ] 执行用户表迁移脚本
- [ ] 执行审计日志表迁移脚本
- [ ] 验证默认管理员可登录
### 可选步骤
- [ ] 在 main.py 中添加审计中间件
- [ ] 为现有接口添加权限控制
- [ ] 注册新的路由(auth, user_management, audit
- [ ] 替换硬编码的认证逻辑
- [ ] 配置 Token 过期时间
---
## 🔄 向后兼容性
### 保留的旧接口
1. **简化登录**: `/api/v1/auth/login/simple`
- 仍可使用 `username``password` 参数
- 返回标准 Token 响应
2. **硬编码用户迁移**
- 原有 `tjwater/tjwater@123` 已迁移到数据库
- 保持相同的用户名和密码
### 渐进式迁移
可以逐步迁移现有接口:
1. 新接口直接使用新认证系统
2. 旧接口保持不变
3. 逐个替换旧接口的认证逻辑
---
## 📚 API 端点总览
### 认证接口 (`/api/v1/auth/`)
| 方法 | 路径 | 说明 | 权限 |
|------|------|------|------|
| POST | `/register` | 用户注册 | 公开 |
| POST | `/login` | OAuth2 登录 | 公开 |
| POST | `/login/simple` | 简化登录 | 公开 |
| GET | `/me` | 获取当前用户 | 认证用户 |
| POST | `/refresh` | 刷新Token | 认证用户 |
### 用户管理 (`/api/v1/users/`)
| 方法 | 路径 | 说明 | 权限 |
|------|------|------|------|
| GET | `/` | 用户列表 | 管理员 |
| GET | `/{id}` | 用户详情 | 所有者/管理员 |
| PUT | `/{id}` | 更新用户 | 所有者/管理员 |
| DELETE | `/{id}` | 删除用户 | 管理员 |
| POST | `/{id}/activate` | 激活用户 | 管理员 |
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
### 审计日志 (`/api/v1/audit/`)
| 方法 | 路径 | 说明 | 权限 |
|------|------|------|------|
| GET | `/logs` | 查询审计日志 | 管理员 |
| GET | `/logs/count` | 日志总数 | 管理员 |
| GET | `/logs/my` | 我的操作记录 | 认证用户 |
---
## 🎓 使用示例
### Python 示例
```python
import requests
# 登录
resp = requests.post("http://localhost:8000/api/v1/auth/login",
data={"username": "admin", "password": "admin123"})
token = resp.json()["access_token"]
# 访问受保护接口
headers = {"Authorization": f"Bearer {token}"}
resp = requests.get("http://localhost:8000/api/v1/auth/me", headers=headers)
print(resp.json())
```
### cURL 示例
```bash
# 登录
TOKEN=$(curl -s -X POST "http://localhost:8000/api/v1/auth/login" \
-d "username=admin&password=admin123" | jq -r .access_token)
# 查询审计日志
curl -H "Authorization: Bearer $TOKEN" \
"http://localhost:8000/api/v1/audit/logs?action=LOGIN"
```
---
## 🐛 常见问题
### Q: 如何修改默认管理员密码?
A: 登录后通过 PUT `/api/v1/users/{id}` 接口修改,或直接更新数据库。
### Q: 如何添加新用户?
A: 使用 POST `/api/v1/auth/register` 接口,或由管理员在用户管理界面创建。
### Q: 审计日志可以删除吗?
A: 不建议删除。可以归档到冷存储,保留最近 90 天的数据。
### Q: Token 过期了怎么办?
A: 使用 Refresh Token 调用 `/api/v1/auth/refresh` 接口获取新的 Access Token。
---
## 📞 技术支持
- **完整文档**: `SECURITY_README.md`
- **部署指南**: `DEPLOYMENT.md`
- **测试代码**: `tests/test_encryption.py`
- **迁移脚本**: `migrations/`
---
## 📝 待办事项(可选)
未来可以扩展的功能:
- [ ] 邮件验证
- [ ] 密码重置
- [ ] 双因素认证(2FA
- [ ] 单点登录(SSO
- [ ] Token 黑名单
- [ ] 会话管理
- [ ] IP 白名单
- [ ] 登录频率限制
- [ ] 密码复杂度策略
- [ ] 审计日志自动归档
---
## 🎉 总结
本次实施完成了企业级的安全体系,包含:
✅ 数据加密 - Fernet 对称加密
✅ 身份认证 - JWT Token + bcrypt 密码哈希
✅ 权限管理 - 基于角色的访问控制(RBAC)
✅ 审计日志 - 自动追踪所有关键操作
所有功能均遵循安全最佳实践,提供完整的文档和测试,可直接投入生产使用。
---
**实施日期**: 2026-02-02
**版本**: v1.0.0
**状态**: ✅ 已完成
-499
View File
@@ -1,499 +0,0 @@
# 安全功能使用指南
TJWater Server 安全体系实施完成,包含:数据加密、身份认证、权限管理、审计日志
## 📋 目录
1. [快速开始](#快速开始)
2. [数据加密](#数据加密)
3. [身份认证](#身份认证)
4. [权限管理](#权限管理)
5. [审计日志](#审计日志)
6. [数据库迁移](#数据库迁移)
7. [API 使用示例](#api-使用示例)
---
## 🚀 快速开始
### 1. 配置环境变量
复制 `.env.example``.env` 并配置:
```bash
cp .env.example .env
```
生成必要的密钥:
```bash
# 生成 JWT 密钥
openssl rand -hex 32
# 生成加密密钥
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
```
编辑 `.env` 文件:
```env
SECRET_KEY=your-generated-jwt-secret-key
ENCRYPTION_KEY=your-generated-encryption-key
DB_NAME=tjwater
DB_HOST=localhost
DB_PORT=5432
DB_USER=postgres
DB_PASSWORD=your-db-password
```
### 2. 执行数据库迁移
```bash
# 连接到 PostgreSQL
psql -U postgres -d tjwater
# 执行迁移脚本
\i migrations/001_create_users_table.sql
\i migrations/002_create_audit_logs_table.sql
```
或使用命令行:
```bash
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
```
### 3. 验证安装
默认创建了两个管理员账号:
- **用户名**: `admin` / **密码**: `admin123`
- **用户名**: `tjwater` / **密码**: `tjwater@123`
---
## 🔐 数据加密
### 使用加密器
```python
from app.core.encryption import get_encryptor
encryptor = get_encryptor()
# 加密敏感数据
encrypted_data = encryptor.encrypt("sensitive information")
# 解密
decrypted_data = encryptor.decrypt(encrypted_data)
```
### 生成新密钥
```python
from app.core.encryption import Encryptor
new_key = Encryptor.generate_key()
print(f"New encryption key: {new_key}")
```
---
## 👤 身份认证
### 用户角色
系统定义了 4 个角色(权限由低到高):
| 角色 | 权限说明 |
|------|---------|
| `VIEWER` | 仅查询权限 |
| `USER` | 读写权限 |
| `OPERATOR` | 操作员,可修改数据 |
| `ADMIN` | 管理员,完全权限 |
### API 接口
#### 用户注册
```http
POST /api/v1/auth/register
Content-Type: application/json
{
"username": "newuser",
"email": "user@example.com",
"password": "password123",
"role": "USER"
}
```
#### 用户登录(OAuth2 标准)
```http
POST /api/v1/auth/login
Content-Type: application/x-www-form-urlencoded
username=admin&password=admin123
```
响应:
```json
{
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "bearer",
"expires_in": 1800
}
```
#### 用户登录(简化版)
```http
POST /api/v1/auth/login/simple?username=admin&password=admin123
```
#### 获取当前用户信息
```http
GET /api/v1/auth/me
Authorization: Bearer {access_token}
```
#### 刷新 Token
```http
POST /api/v1/auth/refresh
Content-Type: application/json
{
"refresh_token": "your-refresh-token"
}
```
---
## 🔑 权限管理
### 在 API 中使用权限控制
#### 方式 1: 使用预定义依赖
```python
from fastapi import APIRouter, Depends
from app.auth.permissions import get_current_admin, get_current_operator
from app.domain.schemas.user import UserInDB
router = APIRouter()
@router.post("/admin-only")
async def admin_endpoint(
current_user: UserInDB = Depends(get_current_admin)
):
"""仅管理员可访问"""
return {"message": "Admin access granted"}
@router.post("/operator-only")
async def operator_endpoint(
current_user: UserInDB = Depends(get_current_operator)
):
"""操作员及以上可访问"""
return {"message": "Operator access granted"}
```
#### 方式 2: 使用 require_role
```python
from app.auth.permissions import require_role
from app.domain.models.role import UserRole
@router.get("/viewer-access")
async def viewer_endpoint(
current_user: UserInDB = Depends(require_role(UserRole.VIEWER))
):
"""所有认证用户可访问"""
return {"data": "visible to all"}
```
#### 方式 3: 手动检查权限
```python
from app.auth.dependencies import get_current_active_user
from app.auth.permissions import check_resource_owner
@router.put("/users/{user_id}")
async def update_user(
user_id: int,
current_user: UserInDB = Depends(get_current_active_user)
):
"""检查是否是资源拥有者或管理员"""
if not check_resource_owner(user_id, current_user):
raise HTTPException(status_code=403, detail="Permission denied")
# 执行更新操作
...
```
---
## 📝 审计日志
### 自动审计
使用中间件自动记录关键操作,在 `main.py` 中添加:
```python
from app.infra.audit.middleware import AuditMiddleware
app.add_middleware(AuditMiddleware)
```
自动记录:
- 所有 POST/PUT/DELETE 请求
- 登录/登出事件
- 关键资源访问
### 手动记录审计日志
```python
from app.core.audit import log_audit_event, AuditAction
await log_audit_event(
action=AuditAction.UPDATE,
user_id=current_user.id,
username=current_user.username,
resource_type="project",
resource_id="123",
ip_address=request.client.host,
request_data={"field": "value"},
response_status=200
)
```
### 查询审计日志
#### 获取所有审计日志(仅管理员)
```http
GET /api/v1/audit/logs?skip=0&limit=100
Authorization: Bearer {admin_token}
```
#### 按条件过滤
```http
GET /api/v1/audit/logs?user_id=1&action=LOGIN&start_time=2024-01-01T00:00:00
Authorization: Bearer {admin_token}
```
#### 获取我的操作记录
```http
GET /api/v1/audit/logs/my
Authorization: Bearer {access_token}
```
#### 获取日志总数
```http
GET /api/v1/audit/logs/count?action=LOGIN
Authorization: Bearer {admin_token}
```
---
## 💾 数据库迁移
### 用户表结构
```sql
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
hashed_password VARCHAR(255) NOT NULL,
role VARCHAR(20) DEFAULT 'USER' NOT NULL,
is_active BOOLEAN DEFAULT TRUE NOT NULL,
is_superuser BOOLEAN DEFAULT FALSE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);
```
### 审计日志表结构
```sql
CREATE TABLE audit_logs (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id),
username VARCHAR(50),
action VARCHAR(50) NOT NULL,
resource_type VARCHAR(50),
resource_id VARCHAR(100),
ip_address VARCHAR(45),
user_agent TEXT,
request_method VARCHAR(10),
request_path TEXT,
request_data JSONB,
response_status INTEGER,
error_message TEXT,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
);
```
---
## 🔧 API 使用示例
### Python 客户端示例
```python
import requests
BASE_URL = "http://localhost:8000/api/v1"
# 1. 登录
response = requests.post(
f"{BASE_URL}/auth/login",
data={"username": "admin", "password": "admin123"}
)
token = response.json()["access_token"]
# 2. 设置 Authorization Header
headers = {"Authorization": f"Bearer {token}"}
# 3. 获取当前用户信息
response = requests.get(f"{BASE_URL}/auth/me", headers=headers)
print(response.json())
# 4. 创建新用户(需要管理员权限)
response = requests.post(
f"{BASE_URL}/auth/register",
headers=headers,
json={
"username": "newuser",
"email": "new@example.com",
"password": "password123",
"role": "USER"
}
)
print(response.json())
# 5. 查询审计日志(需要管理员权限)
response = requests.get(
f"{BASE_URL}/audit/logs?action=LOGIN",
headers=headers
)
print(response.json())
```
### cURL 示例
```bash
# 登录
curl -X POST "http://localhost:8000/api/v1/auth/login" \
-H "Content-Type: application/x-www-form-urlencoded" \
-d "username=admin&password=admin123"
# 使用 Token 访问受保护接口
TOKEN="your-access-token"
curl -X GET "http://localhost:8000/api/v1/auth/me" \
-H "Authorization: Bearer $TOKEN"
# 注册新用户
curl -X POST "http://localhost:8000/api/v1/auth/register" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $TOKEN" \
-d '{
"username": "testuser",
"email": "test@example.com",
"password": "password123",
"role": "USER"
}'
```
---
## 🛡️ 安全最佳实践
1. **密钥管理**
- 绝不在代码中硬编码密钥
- 定期轮换 JWT 密钥
- 使用强随机密钥
2. **密码策略**
- 最小长度 6 个字符(建议 12+)
- 强制密码复杂度(可在注册时添加验证)
- 定期提醒用户更换密码
3. **Token 管理**
- Access Token 短期有效(默认 30 分钟)
- Refresh Token 长期有效(默认 7 天)
- 实施 Token 黑名单(可选)
4. **审计日志**
- 审计日志不可删除
- 定期归档旧日志
- 监控异常登录行为
5. **权限控制**
- 遵循最小权限原则
- 定期审查用户权限
- 记录所有权限变更
---
## 📚 相关文件
- **配置**: `app/core/config.py`
- **加密**: `app/core/encryption.py`
- **安全**: `app/core/security.py`
- **审计**: `app/core/audit.py`
- **认证**: `app/api/v1/endpoints/auth.py`
- **权限**: `app/auth/permissions.py`
- **用户管理**: `app/api/v1/endpoints/user_management.py`
- **审计日志**: `app/api/v1/endpoints/audit.py`
- **迁移脚本**: `migrations/`
---
## ❓ 常见问题
### Q: 忘记密码怎么办?
A: 目前需要管理员通过数据库重置。未来可添加邮件重置功能。
```sql
-- 重置密码为 "newpassword123"
UPDATE users
SET hashed_password = '$2b$12$...' -- 使用 bcrypt 生成哈希
WHERE username = 'targetuser';
```
### Q: 如何添加新角色?
A: 编辑 `app/domain/models/role.py` 中的 `UserRole` 枚举,并更新数据库约束。
### Q: 审计日志占用太多空间?
A: 建议定期归档旧日志到冷存储:
```sql
-- 归档 90 天前的日志
CREATE TABLE audit_logs_archive AS
SELECT * FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
DELETE FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
```
---
## 📞 技术支持
如有问题,请查看:
- 日志文件: `logs/`
- 数据库表结构: `migrations/`
- 单元测试: `tests/`
+10 -4
View File
@@ -1,10 +1,13 @@
from app.algorithms.data_cleaning import flow_data_clean, pressure_data_clean
from app.algorithms.sensors import (
from app.algorithms.cleaning import flow_data_clean, pressure_data_clean
from app.algorithms.sensor import (
pressure_sensor_placement_sensitivity,
pressure_sensor_placement_kmeans,
)
from app.algorithms.valve_isolation import valve_isolation_analysis
from app.algorithms.simulations import (
from app.algorithms.isolation.valve import valve_isolation_analysis
from app.algorithms.leakage import LeakageIdentifier
from app.algorithms.health import PipelineHealthAnalyzer
from app.algorithms.burst_location import run_burst_location
from app.algorithms.simulation.scenarios import (
convert_to_local_unit,
burst_analysis,
valve_close_analysis,
@@ -27,4 +30,7 @@ __all__ = [
"age_analysis",
"pressure_regulation",
"valve_isolation_analysis",
"LeakageIdentifier",
"PipelineHealthAnalyzer",
"run_burst_location",
]
+88
View File
@@ -0,0 +1,88 @@
import os
import pandas as pd
def fill_time_gaps(
data: pd.DataFrame,
time_col: str = "time",
freq: str = "1min",
short_gap_threshold: int = 10,
) -> pd.DataFrame:
"""
补齐缺失时间戳并填补数据缺口。
Args:
data: 包含时间列的 DataFrame
time_col: 时间列名(默认 'time'
freq: 重采样频率(默认 '1min'
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
Returns:
补齐时间后的 DataFrame(保留原时间列格式)
"""
if time_col not in data.columns:
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
# 解析时间列并设为索引
data = data.copy()
data[time_col] = pd.to_datetime(data[time_col], utc=True)
data_indexed = data.set_index(time_col)
# 生成完整时间范围
full_range = pd.date_range(
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
)
# 重索引以补齐缺失时间点,同时保留原始时间戳
combined_index = data_indexed.index.union(full_range).sort_values().unique()
data_reindexed = data_indexed.reindex(combined_index)
# 按列处理缺口
for col in data_reindexed.columns:
# 识别缺失值位置
is_missing = data_reindexed[col].isna()
# 计算连续缺失的长度
missing_groups = (is_missing != is_missing.shift()).cumsum()
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
# 短缺口:时间插值
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
if short_gap_mask.any():
data_reindexed.loc[short_gap_mask, col] = (
data_reindexed[col]
.interpolate(method="time", limit_area="inside")
.loc[short_gap_mask]
)
# 长缺口:前向填充
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
if long_gap_mask.any():
data_reindexed.loc[long_gap_mask, col] = (
data_reindexed[col].ffill().loc[long_gap_mask]
)
# 重置索引并恢复时间列(保留原格式)
data_result = data_reindexed.reset_index()
data_result.rename(columns={"index": time_col}, inplace=True)
# 保留时区信息
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
# 修正时区格式(Python的%z输出为+0000,需转为+00:00
data_result[time_col] = data_result[time_col].str.replace(
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
)
return data_result
def _cleanup_temp_files(prefix: str) -> None:
"""清理 EPANET 仿真产生的临时文件。"""
for ext in [".inp", ".rpt", ".bin", ".out"]:
temp_file = prefix + ext
if os.path.exists(temp_file):
try:
os.remove(temp_file)
except OSError:
pass
-3
View File
@@ -1,3 +0,0 @@
from .flow_data_clean import *
from .pressure_data_clean import *
from .pipeline_health_analyzer import *
@@ -1,355 +0,0 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
import os
def fill_time_gaps(
data: pd.DataFrame,
time_col: str = "time",
freq: str = "1min",
short_gap_threshold: int = 10,
) -> pd.DataFrame:
"""
补齐缺失时间戳并填补数据缺口。
Args:
data: 包含时间列的 DataFrame
time_col: 时间列名(默认 'time'
freq: 重采样频率(默认 '1min'
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
Returns:
补齐时间后的 DataFrame(保留原时间列格式)
"""
if time_col not in data.columns:
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
# 解析时间列并设为索引
data = data.copy()
data[time_col] = pd.to_datetime(data[time_col], utc=True)
data_indexed = data.set_index(time_col)
# 生成完整时间范围
full_range = pd.date_range(
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
)
# 重索引以补齐缺失时间点,同时保留原始时间戳
combined_index = data_indexed.index.union(full_range).sort_values().unique()
data_reindexed = data_indexed.reindex(combined_index)
# 按列处理缺口
for col in data_reindexed.columns:
# 识别缺失值位置
is_missing = data_reindexed[col].isna()
# 计算连续缺失的长度
missing_groups = (is_missing != is_missing.shift()).cumsum()
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
# 短缺口:时间插值
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
if short_gap_mask.any():
data_reindexed.loc[short_gap_mask, col] = (
data_reindexed[col]
.interpolate(method="time", limit_area="inside")
.loc[short_gap_mask]
)
# 长缺口:前向填充
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
if long_gap_mask.any():
data_reindexed.loc[long_gap_mask, col] = (
data_reindexed[col].ffill().loc[long_gap_mask]
)
# 重置索引并恢复时间列(保留原格式)
data_result = data_reindexed.reset_index()
data_result.rename(columns={"index": time_col}, inplace=True)
# 保留时区信息
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
# 修正时区格式(Python的%z输出为+0000,需转为+00:00
data_result[time_col] = data_result[time_col].str.replace(
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
)
return data_result
def clean_pressure_data_km(
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
) -> str:
"""
读取输入 CSV,基于 KMeans 检测异常并用滚动平均修复。输出为 <input_basename>_cleaned.xlsx(同目录)。
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'
返回输出文件的绝对路径。
Args:
input_csv_path: CSV 文件路径
show_plot: 是否显示可视化
fill_gaps: 是否先补齐时间缺口(默认 True)
"""
# 读取 CSV
input_csv_path = os.path.abspath(input_csv_path)
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
# 补齐时间缺口(如果数据包含 time 列)
if fill_gaps and "time" in data.columns:
data = fill_time_gaps(
data, time_col="time", freq="1min", short_gap_threshold=10
)
# 分离时间列和数值列
time_col_data = None
if "time" in data.columns:
time_col_data = data["time"]
data = data.drop(columns=["time"])
# 标准化
data_norm = (data - data.mean()) / data.std()
# 聚类与异常检测
k = 3
kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
clusters = kmeans.fit_predict(data_norm)
centers = kmeans.cluster_centers_
distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
threshold = distances.mean() + 3 * distances.std()
anomaly_pos = np.where(distances > threshold)[0]
anomaly_indices = data.index[anomaly_pos]
anomaly_details = {}
for pos in anomaly_pos:
row_norm = data_norm.iloc[pos]
cluster_idx = clusters[pos]
center = centers[cluster_idx]
diff = abs(row_norm - center)
main_sensor = diff.idxmax()
anomaly_details[data.index[pos]] = main_sensor
# 修复:滚动平均(窗口可调)
data_rolled = data.rolling(window=13, center=True, min_periods=1).mean()
data_repaired = data.copy()
for pos in anomaly_pos:
label = data.index[pos]
sensor = anomaly_details[label]
data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
# 可选可视化(使用位置作为 x 轴)
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
if show_plot and len(data.columns) > 0:
n = len(data)
time = np.arange(n)
plt.figure(figsize=(12, 8))
for col in data.columns:
plt.plot(time, data[col].values, marker="o", markersize=3, label=col)
for pos in anomaly_pos:
sensor = anomaly_details[data.index[pos]]
plt.plot(pos, data.iloc[pos][sensor], "ro", markersize=8)
plt.xlabel("时间点(序号)")
plt.ylabel("压力监测值")
plt.title("各传感器折线图(红色标记主要异常点)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 8))
for col in data_repaired.columns:
plt.plot(
time, data_repaired[col].values, marker="o", markersize=3, label=col
)
for pos in anomaly_pos:
sensor = anomaly_details[data.index[pos]]
plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
plt.xlabel("时间点(序号)")
plt.ylabel("修复后压力监测值")
plt.title("修复后各传感器折线图(绿色标记修复值)")
plt.legend()
plt.show()
# 保存到 Excel:两个 sheet
input_dir = os.path.dirname(os.path.abspath(input_csv_path))
input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
output_filename = f"{input_base}_cleaned.xlsx"
output_path = os.path.join(input_dir, output_filename)
# 如果原始数据包含时间列,将其添加回结果
data_for_save = data.copy()
data_repaired_for_save = data_repaired.copy()
if time_col_data is not None:
data_for_save.insert(0, "time", time_col_data)
data_repaired_for_save.insert(0, "time", time_col_data)
if os.path.exists(output_path):
os.remove(output_path) # 覆盖同名文件
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False)
data_repaired_for_save.to_excel(
writer, sheet_name="cleaned_pressusre_data", index=False
)
# 返回输出文件的绝对路径
return os.path.abspath(output_path)
def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> dict:
"""
接收一个 DataFrame 数据结构,使用KMeans聚类检测异常并用滚动平均修复。
返回清洗后的字典数据结构。
Args:
data: 输入 DataFrame(可包含 time 列)
show_plot: 是否显示可视化
"""
# 使用传入的 DataFrame
data = data.copy()
# 补齐时间缺口(如果启用且数据包含 time 列)
data_filled = fill_time_gaps(
data, time_col="time", freq="1min", short_gap_threshold=10
)
# 保存 time 列用于最后合并
time_col_series = None
if "time" in data_filled.columns:
time_col_series = data_filled["time"]
# 移除 time 列用于后续清洗
data_filled = data_filled.drop(columns=["time"])
# 标准化(使用填充后的数据)
data_norm = (data_filled - data_filled.mean()) / data_filled.std()
# 添加:处理标准化后的 NaN(例如,标准差为0的列),防止异常数据,时间段内所有数据都相同导致计算结果为 NaN
imputer = SimpleImputer(
strategy="constant", fill_value=0, keep_empty_features=True
) # 用 0 填充 NaN,包括全 NaN,并保留空特征
data_norm = pd.DataFrame(
imputer.fit_transform(data_norm),
columns=data_norm.columns,
index=data_norm.index,
)
# 聚类与异常检测
k = 3
kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
clusters = kmeans.fit_predict(data_norm)
centers = kmeans.cluster_centers_
distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
threshold = distances.mean() + 3 * distances.std()
anomaly_pos = np.where(distances > threshold)[0]
anomaly_indices = data_filled.index[anomaly_pos]
anomaly_details = {}
for pos in anomaly_pos:
row_norm = data_norm.iloc[pos]
cluster_idx = clusters[pos]
center = centers[cluster_idx]
diff = abs(row_norm - center)
main_sensor = diff.idxmax()
anomaly_details[data_filled.index[pos]] = main_sensor
# 修复:滚动平均(窗口可调)
data_rolled = data_filled.rolling(window=13, center=True, min_periods=1).mean()
data_repaired = data_filled.copy()
for pos in anomaly_pos:
label = data_filled.index[pos]
sensor = anomaly_details[label]
data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
# 可选可视化(使用位置作为 x 轴)
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
if show_plot and len(data.columns) > 0:
n = len(data)
time = np.arange(n)
n_filled = len(data_filled)
time_filled = np.arange(n_filled)
plt.figure(figsize=(12, 8))
for col in data.columns:
plt.plot(
time, data[col].values, marker="o", markersize=3, label=col, alpha=0.5
)
for col in data_filled.columns:
plt.plot(
time_filled,
data_filled[col].values,
marker="x",
markersize=3,
label=f"{col}_filled",
linestyle="--",
)
for pos in anomaly_pos:
sensor = anomaly_details[data_filled.index[pos]]
plt.plot(pos, data_filled.iloc[pos][sensor], "ro", markersize=8)
plt.xlabel("时间点(序号)")
plt.ylabel("压力监测值")
plt.title("各传感器折线图(红色标记主要异常点,虚线为0值填充后)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 8))
for col in data_repaired.columns:
plt.plot(
time_filled, data_repaired[col].values, marker="o", markersize=3, label=col
)
for pos in anomaly_pos:
sensor = anomaly_details[data_filled.index[pos]]
plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
plt.xlabel("时间点(序号)")
plt.ylabel("修复后压力监测值")
plt.title("修复后各传感器折线图(绿色标记修复值)")
plt.legend()
plt.show()
# 将 time 列添加回结果
if time_col_series is not None:
data_repaired.insert(0, "time", time_col_series)
# 返回清洗后的字典
return data_repaired
# 测试
# if __name__ == "__main__":
# # 默认使用脚本目录下的 pressure_raw_data.csv
# script_dir = os.path.dirname(os.path.abspath(__file__))
# default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
# out_path = clean_pressure_data_km(default_csv, show_plot=False)
# print("保存路径:", out_path)
# 测试 clean_pressure_data_dict_km 函数
if __name__ == "__main__":
import random
# 读取 szh_pressure_scada.csv 文件
script_dir = os.path.dirname(os.path.abspath(__file__))
csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
# 排除 Time 列,随机选择 5 列
columns_to_exclude = ["Time"]
available_columns = [col for col in data.columns if col not in columns_to_exclude]
selected_columns = random.sample(available_columns, 5)
# 将选中的列转换为字典
data_dict = {col: data[col].tolist() for col in selected_columns}
print("选中的列:", selected_columns)
print("原始数据长度:", len(data_dict[selected_columns[0]]))
# 调用函数进行清洗
cleaned_dict = clean_pressure_data_df_km(data_dict, show_plot=True)
print("清洗后的字典键:", list(cleaned_dict.keys()))
print("清洗后的数据长度:", len(cleaned_dict[selected_columns[0]]))
print("测试完成:函数运行正常")
-557
View File
@@ -1,557 +0,0 @@
# 改进灵敏度法
import networkx
import numpy as np
import pandas
import wntr
import pandas as pd
import copy
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.cluster import KMeans
from wntr.epanet.toolkit import EpanetException
from numpy.linalg import slogdet
import random
from app.services.tjnetwork import *
import app.services.project_info as project_info
# 2025/03/12
# Step1: 获取节点坐标
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
"""
获取管网模型的节点坐标
:param wn: 由wntr生成的模型
:return: 节点坐标
"""
# site: pandas.Series
# index:节点名称(wn.node_name_list
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z)
site = wn.query_node_attribute('coordinates')
# Coor: pandas.Series
# index:与site相同(节点名称)。
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3])
Coor = site.apply(lambda x: np.array(x)) # 将节点坐标转换为numpy数组
# x, y: list[float]
x = [] # 存储所有节点的 x 坐标
y = [] # 存储所有节点的 y 坐标
for i in range(0, len(Coor)):
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
# xy: dict[str, list], x、y 坐标的字典
xy = {'x': x, 'y': y}
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
return Coor_node
# 2025/03/12
# Step2: KMeans 聚类
# 将节点用kmeans根据坐标分为k组,存入字典g
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
"""
使用KMeans聚类,将节点坐标分组
:param coor: 存储所有节点的坐标数据
:param knum: 需要分成的聚类数
:return: 聚类结果字典
"""
g = {}
# estimator: sklearn.cluster.KMeans,KMeans 聚类模型
estimator = KMeans(n_clusters=knum)
estimator.fit(coor)
# label_pred: numpy.ndarrayint,每个点的类别标签
label_pred = estimator.labels_
for i in range(0, knum):
g[i] = coor[label_pred == i].index.tolist()
return g
# 2025/03/12
# Step3: wn_func类,水力计算
# wn_func 主要用于计算:
# 水力距离(hydraulic length):即节点之间的水力阻力。
# 灵敏度分析(sensitivity analysis):用于优化测压点的布置。
# 一些与水力相关的函数,包括 CtoS:求水力距离,stafun:求状态函数F
# # diff:求F对P的导数,返回灵敏度矩阵A
# # sensitivity:返回灵敏度和总灵敏度
class wn_func(object):
# Step3.1: 初始化
def __init__(self, wn: wntr.network.WaterNetworkModel):
"""
获取管网模型信息
:param wn: 由wntr生成的模型
"""
# self.results: wntr.sim.results.SimulationResults,仿真结果,包含压力、流量、水头等数据
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
self.wn = wn
# self.qpandas.DataFrame,管道流量,索引为时间步长,列为管道名称
self.q = self.results.link['flowrate']
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
ReservoirIndex = wn.reservoir_name_list
Tankindex = wn.tank_name_list
# 删除水库节点,删除与直接水库相连的虚拟管道
# self.pipes: list[str],所有管道的名称
self.pipes = wn.pipe_name_list
# self.nodes: list[str],所有节点的名称
self.nodes = wn.node_name_list
# self.coordinatespandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
self.coordinates = wn.query_node_attribute('coordinates')
# allpumps / allvalves: list[str],所有泵/阀门名称列表
allpumps = wn.pump_name_list
allvalves = wn.valve_name_list
# pumpstnode / pumpednode / valvestnode / valveednode: list[str],存储泵和阀门 起终点节点的名称
pumpstnode = []
pumpednode = []
valvestnode = []
valveednode = []
# Reservoirpipe / Reservoirednode: list[str],记录与水库相关的管道和节点
Reservoirpipe = []
Reservoirednode = []
for pump in allpumps:
pumpstnode.append(wn.links[pump].start_node.name)
pumpednode.append(wn.links[pump].end_node.name)
for valve in allvalves:
valvestnode.append(wn.links[valve].start_node.name)
valveednode.append(wn.links[valve].end_node.name)
for pipe in self.pipes:
if wn.links[pipe].start_node.name in ReservoirIndex:
Reservoirpipe.append(pipe)
Reservoirednode.append(wn.links[pipe].end_node.name)
if wn.links[pipe].start_node.name in Tankindex:
Reservoirpipe.append(pipe)
Reservoirednode.append(wn.links[pipe].end_node.name)
if wn.links[pipe].end_node.name in Tankindex:
Reservoirpipe.append(pipe)
Reservoirednode.append(wn.links[pipe].start_node.name)
# 泵的起终点、tank、reservoir
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
self.delnodes = list(
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
# 泵、起终点为tank、reservoir的管道
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
# self.L: list[float],所有管道的长度(以米为单位)
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
self.n = len(self.nodes)
self.m = len(self.pipes)
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
##
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
# Step3.2: 计算水力距离
def CtoS(self):
"""
计算水力距离矩阵
:return:
"""
# 水力距离:当行索引对应的节点为控制点时,列索引对应的节点距离控制点的(路径*水头损失)的最小值
# nodeslist[str](节点名称)
nodes = copy.deepcopy(self.nodes)
# pipeslist[str](管道名称)
pipes = self.pipes
wn = self.wn
# n / m:int(节点数 / 管道数)
n = self.n
m = self.m
s1 = [0] * m
q = self.q
L = self.L
# H1pandas.DataFrame,水头数据,索引为时间步长,列为节点名
H1 = self.results.node['head'].T
# hhlist[float],计算管道两端水头之差
hh = []
# 水头损失
for p in pipes:
h1 = self.wn.links[p].start_node.name
h1 = H1.loc[str(h1)]
h2 = self.wn.links[p].end_node.name
h2 = H1.loc[str(h2)]
hh.append(abs(h1 - h2))
hh = np.array(hh)
# headlosspandas.DataFrame,管道水头损失矩阵
headloss = pd.DataFrame(hh, index=pipes).T
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
G = nx.DiGraph()
for i in range(0, m):
pipe = pipes[i]
a = wn.links[pipe].start_node.name
b = wn.links[pipe].end_node.name
if q.loc[0, pipe] > 0:
hf.loc[a, b] = headloss.loc[0, pipe]
weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
else:
hf.loc[b, a] = headloss.loc[0, pipe]
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
for a in nodes:
if a in G.nodes:
d = nx.shortest_path_length(G, source=a, weight='weight')
for b in list(d.keys()):
hydraulicL.loc[a, b] = d[b]
hydraulicL = hydraulicL.drop(self.delnodes)
hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
# 求加权水力距离
return hydraulicL, G
# Step3.3: 计算灵敏度矩阵
# 获取关系矩阵
def get_Conn(self):
"""
计算管网连接关系矩阵
:return:
"""
m = self.wn.num_links
n = self.wn.num_nodes
p = self.wn.num_pumps
v = self.wn.num_valves
self.nonjunc_index = []
self.non_link_index = []
for r in self.wn.reservoirs():
self.nonjunc_index.append(r[0])
for t in self.wn.tanks():
self.nonjunc_index.append(t[0])
# Connnumpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
# NConnnumpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
# pipeslist[str],去除泵和阀门的管道列表
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
for pipe_name, pipe in pipes:
start = self.wn.node_name_list.index(pipe.start_node_name)
end = self.wn.node_name_list.index(pipe.end_node_name)
p_index = self.wn.link_name_list.index(pipe_name)
Conn[start, p_index] = -1
Conn[end, p_index] = 1
NConn[start, end] = 1
NConn[end, start] = 1
self.A = Conn
link_name_list = [link for link in self.wn.link_name_list if
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
self.A2 = self.A2.drop(self.delnodes)
for pipe in self.delpipes:
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
self.A2 = self.A2.drop(columns=pipe)
self.junc_list = self.A2.index
self.A2 = np.mat(self.A2) # 节点管道关系
self.A3 = NConn
def Jaco(self, hL: pandas.DataFrame):
"""
计算灵敏度矩阵(节点压力对粗糙度变化的响应)
:param hL: 水力距离矩阵
:return:
"""
# global result
# Anumpy.matrix, 节点-管道关系矩阵
A = self.A2
wn = self.wn
try:
result = wntr.sim.EpanetSimulator(wn).run_sim()
except EpanetException:
pass
finally:
h = result.link['headloss'][self.pipes].values[0]
q = result.link['flowrate'][self.pipes].values[0]
l = self.wn.query_link_attribute('length')[self.pipes]
C = self.wn.query_link_attribute('roughness')[self.pipes]
# headlossnumpy.ndarray,水头损失数组
headloss = np.array(h)
# 调整流量方向
for i in range(0, len(q)):
if q[i] < 0:
A[:, i] = -A[:, i]
# qnumpy.ndarray,流量数组
q = np.abs(q)
# 两个灵敏度矩阵
# B / Snumpy.matrix,灵敏度计算的中间矩阵
B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
S = np.mat(np.diag(q / C))
# Xnumpy.matrix, 灵敏度矩阵
X = A * B * A.T
try:
det = np.linalg.det(X)
except RuntimeError as e:
sign, logdet = slogdet(X) # 防止溢出
det = sign * np.exp(logdet)
if det != 0:
J_H_Cw = X.I * A * S
# J_H_Q = -X.I
J_q_Cw = S - B * A.T * X.I * A * S # 去掉了delnodes和delpipes
# J_q_Q = B * A.T * X.I
else: # 当X不可逆
J_H_Cw = np.linalg.pinv(X) @ A @ S
# J_H_Q = -np.linalg.pinv(X)
J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
# J_q_Q = B * A.T * np.linalg.pinv(X)
Sen_pressure = []
S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist() # 修改为绝对值
for ss in S_pressure:
Sen_pressure.append(ss[0])
# 求总灵敏度
SS_pressure = copy.deepcopy(hL)
for i in range(0, len(Sen_pressure)):
SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
SS = copy.deepcopy(hL)
for i in range(0, len(Sen_pressure)):
SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
# SS[i,j]:节点nodes[i]的灵敏度*该节点到nodes[j]的水力距离
return SS
# 2025/03/12
# Step4: 传感器布置优化
# Sensorplacement
# weight:分配权重
# sensor:传感器布置的位置
class Sensorplacement(wn_func):
"""
Sensorplacement 类继承了 wn_func 类,并且用于计算和优化传感器布置的位置。
"""
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int):
"""
:param wn: 由wntr生成的模型
:param sensornum: 传感器的数量
"""
wn_func.__init__(self, wn)
self.sensornum = sensornum
# 1.某个节点到所有节点的加权距离之和
# 2.某个节点到该组内所有节点的加权距离之和
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
"""
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
:param SS: 灵敏度矩阵,每个节点的行和列代表不同节点,矩阵元素表示节点间的灵敏度。SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
:param G: 加权图,表示管网的拓扑结构,每个节点通过管道连接。图的边的权重通常是根据水力距离或者流量等计算的
:param group: 节点分组,字典的键是分组编号,值是该组的节点名称列表
:return:
"""
# 传感器布置个数以及位置
# W = self.weight()
n = self.n - len(self.delnodes)
nodes = copy.deepcopy(self.nodes)
for node in self.delnodes:
nodes.remove(node)
# sumSSlist[float],每个节点到其他节点的灵敏度之和。SS.iloc[i, :] 返回第 i 个节点与所有其他节点的灵敏度值,sum(SS.iloc[i, :]) 计算这些灵敏度值的总和。
sumSS = []
for i in range(0, n):
sumSS.append(sum(SS.iloc[i, :]))
# 一个整数范围,表示每个节点的索引,用作sumSS_ DataFrame的索引
indices = range(0, n)
# sumSS_pandas.DataFrame,将 sumSS 转换成 DataFrame 格式,并且将节点的总灵敏度保存到 CSV 文件 sumSS_data.csv 中
sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
sumSS_.to_csv('sumSS_data.csv') # 存储节点总灵敏度
# sumSSpandas.DataFrame,sumSS 被转换为 DataFrame 类型,并且按总灵敏度(即灵敏度之和)降序排列。此时,sumSS 是按节点的灵敏度之和排序的 DataFrame
sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
sumSS = sumSS.sort_values(by=[0], ascending=[False])
# sensorindexlist[str],用于存储根据灵敏度排序选出的传感器位置的节点名称,存储根据总灵敏度排序的节点列表,用于传感器布置
sensorindex = []
# sensorindex_2list[str],用于存储每组内根据灵敏度排序选出的传感器位置的节点名称,存储每个组内根据灵敏度排序选择的传感器节点
sensorindex_2 = []
# group_Sdict[int, pandas.DataFrame],存储每个组内的灵敏度矩阵
group_S = {}
# group_sumSSdict[int, list[float]],存储每个组内节点的总灵敏度,值为每个组内节点灵敏度之和的列表
group_sumSS = {}
for i in range(0, len(group)):
for node in self.delnodes:
# 这里的group[i]是每个组的节点列表,代码首先去除已经被标记为删除的节点self.delnodes
if node in group[i]:
group[i].remove(node)
group_S[i] = SS.loc[group[i], group[i]]
# 对每个组内的节点,计算组内节点的总灵敏度(group_sumSS[i])。它将每个组内节点的灵敏度值相加,并且按灵敏度降序排序
group_sumSS[i] = []
for j in range(0, len(group[i])):
group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
pass
# 1.选sumSS最大的节点,然后把这个节点所在的那个组删掉,就可以不再从这个组选点。再重新排序选sumSS最大的;
# 2.在每组内选group_sumSS最大的节点
# 在这个循环中,首先选择灵敏度最高的节点Smaxnode并添加到sensorindex。然后根据灵敏度排序,删除已选的节点并继续选择下一个灵敏度最大的节点。这个过程用于选择传感器的位置
sensornum = self.sensornum
for i in range(0, sensornum):
# Smaxnodestr,最大灵敏度节点,sumSS.index[0] 表示灵敏度最高的节点
Smaxnode = sumSS.index[0]
sensorindex.append(Smaxnode)
sensorindex_2.append(group_sumSS[i].index[0])
for key, value in group.items():
if Smaxnode in value:
sumSS = sumSS.drop(index=group[key])
continue
sumSS = sumSS.sort_values(by=[0], ascending=[False])
return sensorindex, sensorindex_2
# 2025/03/13
def get_sensor_coord(name: str, sensor_num: int) -> dict[str, float]:
"""
获取布置测压点的坐标,初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
:param name: 数据库名称
:param sensor_num: 测压点数目
:return: 测压点坐标字典
"""
# inp_file_realstr,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
inp_file_real = f'./db_inp/{name}.db.inp'
# sensornumint,需要布置的传感器数量
# sensornum = sensor_num
# wn_realwntr.network.WaterNetworkModel,加载 EPANET 水力模型
wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
# sim_realwntr.sim.EpanetSimulator,创建一个水力仿真器对象
sim_real = wntr.sim.EpanetSimulator(wn_real)
# results_realwntr.sim.results.SimulationResults,运行仿真并返回结果
results_real = sim_real.run_sim()
# real_Clist[float],包含所有管道粗糙度的列表
real_C = wn_real.query_link_attribute('roughness').tolist()
# wn_fun1wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
wn_fun1 = wn_func(wn_real)
# nodeslist[str],管网的节点名称列表
nodes = wn_fun1.nodes
# delnodeslist[str],被删除的节点(如水库、泵、阀门连接的节点等)
delnodes = wn_fun1.delnodes
# Coor_nodepandas.DataFrame
Coor_node = getCoor(wn_real)
Coor_node = Coor_node.drop(wn_fun1.delnodes)
nodes = [node for node in wn_fun1.nodes if node not in delnodes]
# coordinatespandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
coordinates = wn_fun1.coordinates
# 随机产生监测点
# junctionnumintnodes 的长度,表示节点的数量
junctionnum = len(nodes)
# random_numberslist[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
# random_numbers = random.sample(range(junctionnum), sensor_num)
# for i in range(sensor_num):
# # print(random_numbers[i])
wn_fun1.get_Conn()
# hLpandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
# Gnetworkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
hL, G = wn_fun1.CtoS()
# SSpandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
SS = wn_fun1.Jaco(hL)
# groupdict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
group = kgroup(Coor_node, sensor_num)
# wn_funSensorplacement(继承自wn_func
# 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
wn_fun = Sensorplacement(wn_real, sensor_num)
wn_fun.__dict__.update(wn_fun1.__dict__)
# sensorindexlist[str],初始传感器布置位置的节点名称
# sensorindex_2list[str],根据分组选择的传感器位置
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
sensor_coord = {}
# 重新打开数据库
if is_project_open(name=name):
close_project(name=name)
open_project(name=name)
for node_id in sensorindex:
sensor_coord[node_id] = get_node_coord(name=name, node_id=node_id)
close_project(name=name)
# print(sensor_coord)
return sensor_coord
if __name__ == '__main__':
sensor_coord = get_sensor_coord(name=project_info.name, sensor_num=20)
print(sensor_coord)
# '''
# 初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
# '''
#
# # inp_file_realstr,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
# inp_file_real = './db_inp/bb.db.inp'
# # sensornumint,需要布置的传感器数量
# sensornum = 20
# # wn_realwntr.network.WaterNetworkModel,加载 EPANET 水力模型
# wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
# # sim_realwntr.sim.EpanetSimulator,创建一个水力仿真器对象
# sim_real = wntr.sim.EpanetSimulator(wn_real)
# # results_realwntr.sim.results.SimulationResults,运行仿真并返回结果
# results_real = sim_real.run_sim()
#
# # real_Clist[float],包含所有管道粗糙度的列表
# real_C = wn_real.query_link_attribute('roughness').tolist()
# # wn_fun1wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
# wn_fun1 = wn_func(wn_real)
# # nodeslist[str],管网的节点名称列表
# nodes = wn_fun1.nodes
# # delnodeslist[str],被删除的节点(如水库、泵、阀门连接的节点等)
# delnodes = wn_fun1.delnodes
# # Coor_nodepandas.DataFrame
# Coor_node = getCoor(wn_real)
# Coor_node = Coor_node.drop(wn_fun1.delnodes)
# nodes = [node for node in wn_fun1.nodes if node not in delnodes]
# # coordinatespandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
# coordinates = wn_fun1.coordinates
#
# # 随机产生监测点
# # junctionnumintnodes 的长度,表示节点的数量
# junctionnum = len(nodes)
# # random_numberslist[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
# random_numbers = random.sample(range(junctionnum), sensornum)
# for i in range(sensornum):
# print(random_numbers[i])
#
# wn_fun1.get_Conn()
# # hLpandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
# # Gnetworkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
# hL, G = wn_fun1.CtoS()
# # SSpandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
# SS = wn_fun1.Jaco(hL)
# # groupdict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
# group = kgroup(Coor_node, sensornum)
# # wn_funSensorplacement(继承自wn_func
# # 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
# wn_fun = Sensorplacement(wn_real, sensornum)
# wn_fun.__dict__.update(wn_fun1.__dict__)
# # sensorindexlist[str],初始传感器布置位置的节点名称
# # sensorindex_2list[str],根据分组选择的传感器位置
# sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
# print(str(sensornum), "个测压点,测压点位置:", sensorindex)
# # 分区画图
# colorlist = ['lightpink', 'coral', 'rosybrown', 'olive', 'powderblue', 'lightskyblue', 'steelblue', 'peachpuff','brown','silver','indigo','lime','gold','violet','maroon','navy','teal','magenta','cyan',
# 'burlywood', 'tan', 'slategrey', 'thistle', 'lightseagreen', 'lightgreen', 'red','blue','yellow','orange','purple','grey','green','pink','lightblue','beige','chartreuse','turquoise','lavender','fuchsia','coral']
# G = wn_real.to_graph()
# G = G.to_undirected() # 变为无向图
# pos = nx.get_node_attributes(G, 'pos')
# pass
# for i in range(0, sensornum):
# ax = plt.gca()
# ax.set_title(inp_file_real + str(sensornum))
# nodes = nx.draw_networkx_nodes(G, pos, nodelist=group[i], node_color=colorlist[i], node_size=20)
# nodes = nx.draw_networkx_nodes(G, pos,
# nodelist=sensorindex_2, node_color='black', node_size=70, node_shape='*'
# )
# edges = nx.draw_networkx_edges(G, pos)
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)
# ax.spines['bottom'].set_visible(False)
# ax.spines['left'].set_visible(False)
# plt.savefig(inp_file_real + str(sensornum) + ".png")
# plt.show()
#
# wntr.graphics.plot_network(wn_real, node_attribute=sensorindex_2, node_size=50, node_labels=False,
# title=inp_file_real + '_Projetion' + str(sensornum))
# plt.savefig(inp_file_real + '_S' + str(sensornum) + ".png")
# plt.show()
@@ -0,0 +1,3 @@
from app.algorithms.burst_detection.burst_detector import BurstDetector
__all__ = ["BurstDetector"]
@@ -0,0 +1,259 @@
from __future__ import annotations
from typing import Any
import numpy as np
import pandas as pd
from scipy.fft import fft, ifft
from sklearn.ensemble import IsolationForest
PressureDataInput = (
pd.DataFrame
| dict[str, list[Any]]
| list[dict[str, Any]]
| list[list[Any]]
| np.ndarray
)
IGNORED_OBSERVATION_COLUMNS = {"time", "timestamp", "datetime", "date"}
class BurstDetector:
"""FFT + IsolationForest based burst detection for daily aligned pressure data."""
def __init__(
self,
*,
mu: int = 100,
points_per_day: int = 1440,
iforest_params: dict[str, Any] | None = None,
) -> None:
if points_per_day <= 0:
raise ValueError("points_per_day 必须大于 0。")
if mu <= 0:
raise ValueError("mu 必须大于 0。")
self.mu = int(mu)
self.points_per_day = int(points_per_day)
self.iforest_params = {
"n_estimators": 50,
"random_state": 42,
"contamination": "auto",
}
if iforest_params:
self.iforest_params.update(iforest_params)
self.data: np.ndarray | None = None
self.sensor_names: list[str] = []
self.high_freq_features: np.ndarray | None = None
def load_data(
self,
data_source: PressureDataInput,
*,
sensor_nodes: list[str] | None = None,
) -> pd.DataFrame:
"""
标准化输入观测数据为 DataFrame。
支持的 `data_source` 格式:
- `pd.DataFrame`
每一列代表一个传感器,每一行代表一个时间点。
- `dict[str, list[Any]]`
键为传感器 ID,值为该传感器按时间顺序排列的压力序列。
例如:`{"J1": [101.2, 101.0], "J2": [99.8, 99.7]}`。
- `list[dict[str, Any]]`
每个字典代表一个时间点,键为传感器 ID,值为该时刻压力。
例如:`[{"J1": 101.2, "J2": 99.8}, {"J1": 101.0, "J2": 99.7}]`。
- `list[list[Any]]`
二维列表,格式为 `(时间点数, 传感器数)`。
例如:`[[101.2, 99.8], [101.0, 99.7]]`。
- `np.ndarray`
二维数组,形状必须为 `(时间点数, 传感器数)`。
参数:
- `sensor_nodes`:
可选的传感器列筛选列表。传入后,数据中必须包含这些列名。
返回:
- 标准化后的 `pd.DataFrame`,列为传感器,行为时间点。
"""
if isinstance(data_source, np.ndarray):
observation_df = pd.DataFrame(data_source)
elif isinstance(data_source, pd.DataFrame):
observation_df = data_source.copy()
else:
observation_df = pd.DataFrame(data_source)
return self._normalize_observation_frame(
observation_df=observation_df, sensor_nodes=sensor_nodes
)
def process(
self,
observed_pressure_data: PressureDataInput,
*,
sensor_nodes: list[str] | None = None,
) -> np.ndarray:
"""
对输入压力序列按天切片,并提取每天末时刻的高频特征。
`observed_pressure_data` 的格式与 `load_data()` 一致,统一要求:
- 数据必须表示为“行=时间点、列=传感器”。
- 总行数必须是 `points_per_day` 的整数倍。
- 至少需要 2 天数据,即总行数 `>= 2 * points_per_day`。
例如:
- 当 `points_per_day=1440` 时,15 天数据的形状通常为 `(21600, 传感器数)`。
- 若传入 `sensor_nodes=["J1", "J2"]`,则输入中必须存在 `J1/J2` 两列。
返回:
- `np.ndarray`,形状为 `(天数, 传感器数)`,
每个值表示对应传感器在当天末时刻提取出的高频分量。
"""
observation_df = self.load_data(
observed_pressure_data,
sensor_nodes=sensor_nodes,
)
matrix = observation_df.to_numpy(dtype=float)
total_points, sensor_count = matrix.shape
if sensor_count == 0:
raise ValueError("压力观测数据中未找到可用传感器列。")
if total_points < self.points_per_day * 2:
raise ValueError("至少需要 2 天的观测数据才能执行爆管侦测。")
if total_points % self.points_per_day != 0:
raise ValueError("观测数据长度必须能被每日采样点数整除,以便按天切分。")
day_count = total_points // self.points_per_day
high_freq_features = np.zeros((day_count, sensor_count), dtype=float)
for sensor_idx in range(sensor_count):
sensor_series = matrix[:, sensor_idx]
for day_idx in range(day_count):
start = day_idx * self.points_per_day
end = (day_idx + 1) * self.points_per_day
day_data = sensor_series[start:end]
mirrored_data = np.concatenate([day_data, day_data[::-1]])
transformed = fft(mirrored_data)
transformed[self.mu : len(mirrored_data) - self.mu + 1] = 0
low_freq = ifft(transformed).real
high_freq = day_data - low_freq[: self.points_per_day]
high_freq_features[day_idx, sensor_idx] = float(high_freq[-1])
self.data = matrix
self.sensor_names = [str(column) for column in observation_df.columns]
self.high_freq_features = high_freq_features
return high_freq_features
def detect(self) -> pd.DataFrame:
if self.high_freq_features is None:
raise ValueError("特征未提取。请先调用 process()。")
day_count = self.high_freq_features.shape[0]
if day_count < 2:
raise ValueError("孤立森林至少需要 2 天特征数据。")
clf = IsolationForest(
n_estimators=self.iforest_params.get("n_estimators", 50),
max_samples=day_count,
random_state=self.iforest_params.get("random_state", 42),
contamination=self.iforest_params.get("contamination", "auto"),
**{
key: value
for key, value in self.iforest_params.items()
if key not in {"n_estimators", "random_state", "contamination"}
},
)
clf.fit(self.high_freq_features)
scores = clf.decision_function(self.high_freq_features)
predictions = clf.predict(self.high_freq_features)
result_df = pd.DataFrame(
{
"Day": range(1, day_count + 1),
"Score": scores.astype(float),
"Prediction": predictions.astype(int),
}
)
result_df["IsBurst"] = result_df["Prediction"].eq(-1)
result_df.attrs["sensor_nodes"] = self.sensor_names.copy()
result_df.attrs["high_freq_features"] = self.high_freq_features.copy()
result_df.attrs["day_count"] = day_count
result_df.attrs["points_per_day"] = self.points_per_day
result_df.attrs["sample_count"] = (
int(self.data.shape[0]) if self.data is not None else 0
)
return result_df
def run_detection(
self,
observed_pressure_data: PressureDataInput,
*,
sensor_nodes: list[str] | None = None,
) -> pd.DataFrame:
"""
执行完整爆管侦测流程。
输入格式与 `process()` 相同:
- `DataFrame` / `dict[str, list[Any]]` / `list[dict[str, Any]]` / `list[list[Any]]` / `np.ndarray`
- 行表示时间点,列表示传感器
- 总行数必须能被 `points_per_day` 整除
返回结果包含列:
- `Day`: 第几天(从 1 开始)
- `Score`: IsolationForest 异常分数,越小越异常
- `Prediction`: `-1` 表示异常,`1` 表示正常
- `IsBurst`: 是否判定为异常日
"""
self.process(observed_pressure_data, sensor_nodes=sensor_nodes)
return self.detect()
@staticmethod
def _normalize_observation_frame(
*,
observation_df: pd.DataFrame,
sensor_nodes: list[str] | None,
) -> pd.DataFrame:
if observation_df.empty:
raise ValueError("压力观测数据为空。")
normalized_df = observation_df.copy()
normalized_df.columns = [str(column) for column in normalized_df.columns]
normalized_df = normalized_df.drop(
columns=[
column
for column in normalized_df.columns
if column.lower() in IGNORED_OBSERVATION_COLUMNS
or column.lower().startswith("unnamed:")
],
errors="ignore",
)
if sensor_nodes:
selected_columns = [str(node) for node in sensor_nodes]
missing_columns = [
column
for column in selected_columns
if column not in normalized_df.columns
]
if missing_columns:
preview = ", ".join(missing_columns[:10])
raise ValueError(f"观测数据缺少传感器列: {preview}")
normalized_df = normalized_df.loc[:, selected_columns]
else:
candidate_df = normalized_df.apply(pd.to_numeric, errors="coerce")
normalized_df = candidate_df.loc[:, candidate_df.notna().any(axis=0)]
if normalized_df.empty:
raise ValueError("未识别到可用的数值型压力观测列。")
normalized_df = normalized_df.apply(pd.to_numeric, errors="coerce")
invalid_columns = [
column
for column in normalized_df.columns
if normalized_df[column].isna().any()
]
if invalid_columns:
preview = ", ".join(invalid_columns[:10])
raise ValueError(f"压力观测数据包含非数值或缺失值: {preview}")
return normalized_df.reset_index(drop=True)
@@ -0,0 +1,3 @@
from .burst_location import run_burst_location
__all__ = ["run_burst_location"]
@@ -0,0 +1,342 @@
import argparse
import json
import logging
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Iterable
import pandas as pd
from app.algorithms.burst_location import leak_simulator
from .burst_locator import (
DN_search_multi_simple_add_flow_count_new,
)
from .network_model import (
_build_node_pipe_maps,
cal_node_coordinate,
construct_graph,
load_inp,
read_inf_inp,
read_inf_inp_other,
)
DEFAULT_N_WORKERS = max(1, min(cpu_count() - 1, 4))
# DEFAULT_N_WORKERS = max(1, cpu_count() - 1)
logger = logging.getLogger(__name__)
def _read_id_list_json(path):
if path is None:
return None
data = json.loads(Path(path).read_text(encoding="utf-8"))
if isinstance(data, list):
return [str(item) for item in data]
if isinstance(data, dict):
if "ids" in data and isinstance(data["ids"], list):
return [str(item) for item in data["ids"]]
raise ValueError(f"ID JSON must be list or dict with key 'ids': {path}")
raise ValueError(f"Unsupported ID JSON format: {path}")
def _read_series_csv(path):
if path is None:
return None
df = pd.read_csv(path)
if df.shape[1] < 2:
raise ValueError(f"CSV must contain at least two columns (id,value): {path}")
if {"id", "value"}.issubset(df.columns):
id_col, value_col = "id", "value"
else:
id_col, value_col = df.columns[0], df.columns[1]
series = pd.Series(
df[value_col].values, index=df[id_col].astype(str).values, dtype=float
)
return series
def _align_scada_series(
series: pd.Series, ids: Iterable[str], series_name: str
) -> pd.Series:
ids = [str(item) for item in ids]
aligned = series.copy()
aligned.index = aligned.index.map(str)
missing_ids = [item for item in ids if item not in aligned.index]
if missing_ids:
preview = ", ".join(missing_ids[:10])
raise ValueError(f"{series_name} missing IDs: {preview}")
aligned = pd.to_numeric(aligned.loc[ids], errors="coerce")
invalid_ids = aligned[aligned.isna()].index.tolist()
if invalid_ids:
preview = ", ".join(invalid_ids[:10])
raise ValueError(
f"{series_name} contains non-numeric values for IDs: {preview}"
)
return aligned
def _validate_flow_inputs(
flow_scada_ids: list[str] | None,
burst_flow: pd.Series | None,
normal_flow: pd.Series | None,
) -> tuple[bool, list[str]]:
has_any_flow = any(
value is not None for value in [flow_scada_ids, burst_flow, normal_flow]
)
has_all_flow = all(
value is not None for value in [flow_scada_ids, burst_flow, normal_flow]
)
if has_any_flow and not has_all_flow:
raise ValueError(
"flow_scada_ids, burst_flow, and normal_flow must be provided together."
)
if not has_all_flow:
return False, []
flow_ids = [str(item) for item in (flow_scada_ids or [])]
if len(flow_ids) == 0:
raise ValueError("flow_scada_ids cannot be empty when flow data is provided.")
return True, flow_ids
def _build_top_candidates(similarity_series: pd.Series) -> list[dict[str, Any]]:
top_series = similarity_series.iloc[:10]
return [
{"pipe_id": str(pipe_id), "similarity": float(score)}
for pipe_id, score in top_series.items()
]
def run_burst_location(
wn_inp_path: str,
pressure_scada_ids: list[str],
burst_pressure: pd.Series,
normal_pressure: pd.Series,
burst_leakage: float,
flow_scada_ids: list[str] | None = None,
burst_flow: pd.Series | None = None,
normal_flow: pd.Series | None = None,
min_dpressure: float = 2.0,
basic_pressure: float = 10.0,
n_workers: int = DEFAULT_N_WORKERS,
partition_on_full_graph: bool = True,
visualize_partition: bool = True,
visualize_pause_seconds: float = 0.3,
final_candidates_csv_path: (
str | None
) = "temp/burst_location/final_round_candidates.csv",
) -> dict[str, Any]:
if pressure_scada_ids is None or len(pressure_scada_ids) == 0:
raise ValueError("pressure_scada_ids cannot be empty.")
if burst_pressure is None or normal_pressure is None:
raise ValueError("burst_pressure and normal_pressure are required.")
has_all_flow, flow_ids = _validate_flow_inputs(
flow_scada_ids=flow_scada_ids,
burst_flow=burst_flow,
normal_flow=normal_flow,
)
inp_path = Path(wn_inp_path)
wn = load_inp(
inp_name=inp_path.name,
inp_location=str(inp_path.parent) + "/",
inp_time=0,
driven_mode="PDD",
require_p=float(basic_pressure),
minimum_p=0.0,
)
(
all_node,
_,
node_coordinates,
all_pipe,
_,
_,
pipe_length,
pipe_diameter,
) = read_inf_inp(wn)
candidate_pipe, _ = leak_simulator.cal_possible_pipe(
burst_leakage, all_pipe, pipe_diameter
)
_, pipe_start_node_all, pipe_end_node_all = read_inf_inp_other(wn)
node_x, node_y = cal_node_coordinate(all_node, node_coordinates)
G0 = construct_graph(wn)
node_pipe_dic, couple_node_length = _build_node_pipe_maps(
all_node,
all_pipe,
pipe_start_node_all,
pipe_end_node_all,
pipe_length,
)
all_node_series = pd.Series(range(len(all_node)), index=all_node)
pressure_ids = [str(item) for item in pressure_scada_ids]
normal_pressure_aligned = _align_scada_series(
normal_pressure, pressure_ids, "normal_pressure"
)
burst_pressure_aligned = _align_scada_series(
burst_pressure, pressure_ids, "burst_pressure"
)
pressure_normal = normal_pressure_aligned.to_frame().T
pressure_monitor = burst_pressure_aligned.to_frame().T
pressure_predict = pressure_normal.copy()
timestep_list = list(pressure_normal.index)
if has_all_flow:
normal_flow_aligned = _align_scada_series(normal_flow, flow_ids, "normal_flow")
burst_flow_aligned = _align_scada_series(burst_flow, flow_ids, "burst_flow")
flow_normal = normal_flow_aligned.to_frame().T
flow_monitor = burst_flow_aligned.to_frame().T
flow_predict = flow_normal.copy()
similarity_mode = "CDF"
max_flow = flow_normal.iloc[0, :].abs()
else:
flow_normal = pd.DataFrame(index=timestep_list)
flow_monitor = pd.DataFrame(index=timestep_list)
flow_predict = pd.DataFrame(index=timestep_list)
similarity_mode = "CAD_new_gy"
max_flow = pd.Series(dtype=float)
stage_timing: dict[str, Any] = {}
try:
(
located_pipe,
elapsed_seconds,
simulation_times,
_,
similarity_series,
exit_condition,
final_candidates_csv,
) = DN_search_multi_simple_add_flow_count_new(
wn=wn,
wn_inp_path=str(inp_path),
G0=G0,
all_node=all_node,
node_x=node_x,
node_y=node_y,
pipe_start_node_all=pipe_start_node_all,
pipe_end_node_all=pipe_end_node_all,
pipe_diameter=pipe_diameter,
couple_node_length=couple_node_length,
node_pipe_dic=node_pipe_dic,
all_node_series=all_node_series,
top_group_ratio=0.3,
top_pipe_num_max=80,
top_pipe_num_min=10,
candidate_pipe_input_initial=candidate_pipe,
similarity_mode=similarity_mode,
pressure_monitor=pressure_monitor,
pressure_predict=pressure_predict,
pressure_normal=pressure_normal,
pressure_leak_all=None,
flow_monitor=flow_monitor,
flow_predict=flow_predict,
flow_normal=flow_normal,
flow_leak_all=None,
timestep_list=timestep_list,
max_flow=max_flow,
group_basic_num=30,
Top_sensor_num=min(5, len(pressure_ids)),
if_gy=0,
pressure_threshold=float(min_dpressure),
leak_mag=float(burst_leakage),
n_workers=max(1, int(n_workers)),
stage_timing=stage_timing,
partition_on_full_graph=partition_on_full_graph,
visualize_partition=visualize_partition,
visualize_pause_seconds=visualize_pause_seconds,
final_candidates_csv_path=final_candidates_csv_path,
)
except Exception as exc:
logger.exception("Burst location algorithm execution failed.")
raise RuntimeError(f"Failed to run burst location algorithm: {exc}") from exc
return {
"located_pipe": located_pipe,
"burst_leakage": float(burst_leakage),
"elapsed_seconds": elapsed_seconds,
"simulation_times": int(simulation_times),
"top_candidates": _build_top_candidates(similarity_series),
"similarity_mode": similarity_mode,
"exit_condition": exit_condition,
"final_candidates_csv": final_candidates_csv,
"stage_timing_seconds": stage_timing,
}
def _parse_args():
parser = argparse.ArgumentParser(description="爆管定位主函数入口")
parser.add_argument("--wn-inp", required=True, help="EPANET inp 文件路径")
parser.add_argument(
"--pressure-ids-json", required=True, help="压力SCADA ID列表 JSON 文件"
)
parser.add_argument(
"--flow-ids-json", default=None, help="(可选)流量SCADA ID列表 JSON 文件"
)
parser.add_argument(
"--burst-pressure-csv", required=True, help="爆管时压力 CSVid,value"
)
parser.add_argument(
"--normal-pressure-csv", required=True, help="正常时压力 CSVid,value"
)
parser.add_argument(
"--burst-flow-csv", default=None, help="(可选)爆管时流量 CSV(id,value"
)
parser.add_argument(
"--normal-flow-csv", default=None, help="(可选)正常时流量 CSV(id,value"
)
parser.add_argument(
"--burst-leakage", type=float, required=True, help="爆管漏损流量"
)
parser.add_argument(
"--min-dpressure",
type=float,
default=2.0,
help="(可选)最小压降阈值,默认 2.0",
)
parser.add_argument(
"--basic-pressure",
type=float,
default=10.0,
help="(可选)基础服务压力,默认 10.0",
)
parser.add_argument(
"--n-workers",
type=int,
default=DEFAULT_N_WORKERS,
help="(可选)特征中心模拟进程数,默认 max(1, min(cpu_count()-1, 4))",
)
parser.add_argument(
"--final-candidates-csv-path",
default="temp/burst_location/final_round_candidates.csv",
help="(可选)最后一轮候选管道明细 CSV 输出路径",
)
return parser.parse_args()
def main():
args = _parse_args()
result = run_burst_location(
wn_inp_path=args.wn_inp,
pressure_scada_ids=_read_id_list_json(args.pressure_ids_json),
burst_pressure=_read_series_csv(args.burst_pressure_csv),
normal_pressure=_read_series_csv(args.normal_pressure_csv),
burst_leakage=args.burst_leakage,
flow_scada_ids=_read_id_list_json(args.flow_ids_json),
burst_flow=_read_series_csv(args.burst_flow_csv),
normal_flow=_read_series_csv(args.normal_flow_csv),
min_dpressure=args.min_dpressure,
basic_pressure=args.basic_pressure,
n_workers=args.n_workers,
final_candidates_csv_path=args.final_candidates_csv_path,
)
print(json.dumps(result, ensure_ascii=False))
if __name__ == "__main__":
main()
@@ -0,0 +1,772 @@
"""爆管定位主模块。"""
import copy
import math
import os
import sys
from datetime import datetime
from time import perf_counter
import networkx as nx
import numpy as np
import pandas as pd
from .leak_simulator import cal_signature_pipe_multi_pf
from .network_partitioner import (
cal_group_num,
metis_grouping_pipe_weight,
visualize_metis_partition,
)
from .similarity_calculator import (
adjust_ratio,
cal_similarity_all_multi_new_sq_improve_double_lzr,
decode_mode,
extra_judge,
update_similarity,
)
def _ensure_signatures_for_centers(
wn,
wn_inp_path,
center_list, # 本轮要用到的中心(list[str])
pressure_leak_all,
flow_leak_all, # 全量缓存(可为空 DF
timestep_list, # 你现有的时序列表
pressure_monitor,
flow_monitor, # 用来推断传感器列名
leak_mag,
n_workers=1,
):
"""
只为缺失的中心补算 SLF(调用你现有的 cal_signature_pipe_multi_pf),
并把补算结果并回缓存。返回:
pressure_leak_subset, flow_leak_subset, pressure_leak_all_new, flow_leak_all_new
其中 subset 只包含 center_list 的行(顺序与 center_list 保持一致)。
"""
center_list = _dedupe_preserve_order(center_list)
# 1) 推断传感器列名(与现有数据保持一致)
sensor_name_all = list(pressure_monitor.columns)
sensor_f_name_all = (
list(flow_monitor.columns)
if (flow_monitor is not None and hasattr(flow_monitor, "columns"))
else []
)
# 2) 取出缓存里已经有的中心(考虑 MultiIndex 的第 0 层为 pipe
def _existing_pipes(df):
if df is None or len(df) == 0:
return set()
idx = df.index
if isinstance(idx, pd.MultiIndex):
return set(idx.get_level_values(0))
else:
return set(idx)
exist_p = _existing_pipes(pressure_leak_all)
need = [p for p in center_list if p not in exist_p]
# 3) 若有缺失中心,仅为这些中心补算一次
if len(need) > 0:
p_new, _ = cal_signature_pipe_multi_pf(
wn,
leak_mag,
need,
timestep_list,
sensor_name_all,
n_workers=n_workers,
wn_inp_path=wn_inp_path,
)
# 初始化空缓存时,做一次“同构化”
if pressure_leak_all is None or len(pressure_leak_all) == 0:
pressure_leak_all = p_new
else:
pressure_leak_all = pd.concat([pressure_leak_all, p_new], axis=0)
# if (flow_leak_all is None or len(flow_leak_all) == 0) and f_new is not None:
# flow_leak_all = f_new
# elif f_new is not None:
# flow_leak_all = pd.concat([flow_leak_all, f_new], axis=0)
# 去重(如果既有缓存里不小心有重复中心)
if isinstance(pressure_leak_all.index, pd.MultiIndex):
pressure_leak_all = pressure_leak_all[
~pressure_leak_all.index.duplicated(keep="last")
]
if flow_leak_all is not None and len(flow_leak_all) > 0:
flow_leak_all = flow_leak_all[
~flow_leak_all.index.duplicated(keep="last")
]
else:
pressure_leak_all = pressure_leak_all[
~pressure_leak_all.index.duplicated(keep="last")
]
if flow_leak_all is not None and len(flow_leak_all) > 0:
flow_leak_all = flow_leak_all[
~flow_leak_all.index.duplicated(keep="last")
]
# 4) 从更新后的缓存里,取出这轮需要的中心子集(顺序与 center_list 一致)
if isinstance(pressure_leak_all.index, pd.MultiIndex):
pressure_subset = pressure_leak_all.loc[center_list]
flow_subset = (
flow_leak_all.loc[center_list]
if (flow_leak_all is not None and len(flow_leak_all) > 0)
else None
)
else:
pressure_subset = pressure_leak_all.loc[center_list, :]
flow_subset = (
flow_leak_all.loc[center_list, :]
if (flow_leak_all is not None and len(flow_leak_all) > 0)
else None
)
return pressure_subset, flow_subset, pressure_leak_all, flow_leak_all
def area_output_num_ki_improve(
candidate_center,
candidate_group,
similarity,
new_all_node,
top_group_ratio,
top_pipe_num_max,
top_pipe_num_min,
cut_ratio,
):
final_area = []
final_center = []
all_node_iter = []
if similarity.index.is_unique == False:
total_center_num = len(set(similarity.index))
else:
total_center_num = len(similarity.index)
next_group_num = min(
total_center_num, math.ceil(total_center_num / cut_ratio * top_group_ratio)
)
for i in range(next_group_num):
top_center = similarity.index[i]
top_center_index = find_list_repeat(candidate_center, top_center)
for j in range(len(top_center_index)):
final_area = final_area + candidate_group[top_center_index[j]]
all_node_iter = all_node_iter + list(new_all_node[top_center_index[j]])
final_center.append(top_center)
final_area = sorted(set(final_area))
if len(final_area) > top_pipe_num_max:
if_end = 0
elif len(final_area) > top_pipe_num_min:
if_end = 1
elif total_center_num == next_group_num:
if_end = 1
else:
if_end = 1
for i in np.arange(next_group_num, total_center_num, 1):
before_list = copy.deepcopy(final_area)
top_center = similarity.index[i]
top_center_index = candidate_center.index(top_center)
temp_group = final_area + candidate_group[top_center_index]
temp_area = sorted(set(temp_group))
if len(temp_area) < top_pipe_num_min:
final_center.append(top_center)
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
final_area = temp_area
elif len(temp_area) < top_pipe_num_max:
final_center.append(top_center)
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
final_area = temp_area
break
else:
a = len(temp_area) - top_pipe_num_max
b = top_pipe_num_min - len(before_list)
if a >= b:
final_area = before_list
else:
final_center.append(top_center)
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
final_area = temp_area
break
final_center = sorted(set(final_center))
all_node_iter = sorted(set(all_node_iter))
return final_area, final_center, all_node_iter, if_end
def find_list_repeat(candidate_center, target):
repeated_list = []
for index, nums in enumerate(candidate_center):
if nums == target:
repeated_list.append(index)
return repeated_list
def _dedupe_preserve_order(items):
seen = set()
output = []
for item in items:
if item in seen:
continue
seen.add(item)
output.append(item)
return output
def _accumulate_stage(stage_timing, stage_name, started_at):
stage_timing[stage_name] = stage_timing.get(stage_name, 0.0) + (
perf_counter() - started_at
)
def _write_last_round_candidates_csv(
csv_path,
exit_condition,
iteration_count,
similarity_mode,
candidate_details,
fallback_similarity,
):
if not csv_path:
return None
timestamp_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
base_path, ext = os.path.splitext(csv_path)
ext = ext or ".csv"
output_path = f"{base_path}_{timestamp_suffix}{ext}"
if candidate_details is not None and len(candidate_details) > 0:
export_df = candidate_details.copy()
if export_df.index.name == "pipe_id":
export_df = export_df.reset_index()
else:
export_df = pd.DataFrame(
{
"pipe_id": [str(pipe_id) for pipe_id in fallback_similarity.index],
"final_similarity": [float(value) for value in fallback_similarity.values],
}
)
export_df["exit_condition"] = exit_condition
export_df["iterations"] = int(iteration_count)
export_df["similarity_mode"] = similarity_mode
parent_dir = os.path.dirname(output_path)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
export_df.to_csv(output_path, index=False, encoding="utf-8-sig")
return output_path
def cal_DtoTop1(
G0, pipe_leak, located_pipe, pipe_start_node_all, pipe_end_node_all, pipe_length
):
if pipe_leak == located_pipe:
result_DtoTop1 = 0
result_DtoTop1_num = 0
else:
pipe_leak_start_node = pipe_start_node_all[pipe_leak]
pipe_leak_end_node = pipe_end_node_all[pipe_leak]
located_pipe_start_node = pipe_start_node_all[located_pipe]
located_pipe_end_node = pipe_end_node_all[located_pipe]
DtoTop1_series = pd.Series(dtype=object)
DtoTop1_num_series = pd.Series(dtype=object)
DtoTop1_series["ss"] = nx.shortest_path_length(
G0, pipe_leak_start_node, located_pipe_start_node, weight="weight"
)
DtoTop1_series["se"] = nx.shortest_path_length(
G0, pipe_leak_start_node, located_pipe_end_node, weight="weight"
)
DtoTop1_series["es"] = nx.shortest_path_length(
G0, pipe_leak_end_node, located_pipe_start_node, weight="weight"
)
DtoTop1_series["ee"] = nx.shortest_path_length(
G0, pipe_leak_end_node, located_pipe_end_node, weight="weight"
)
DtoTop1_num_series["ss"] = nx.shortest_path_length(
G0, pipe_leak_start_node, located_pipe_start_node
)
DtoTop1_num_series["se"] = nx.shortest_path_length(
G0, pipe_leak_start_node, located_pipe_end_node
)
DtoTop1_num_series["es"] = nx.shortest_path_length(
G0, pipe_leak_end_node, located_pipe_start_node
)
DtoTop1_num_series["ee"] = nx.shortest_path_length(
G0, pipe_leak_end_node, located_pipe_end_node
)
if DtoTop1_num_series.min() == 0:
result_DtoTop1_num = 1
result_DtoTop1 = DtoTop1_series.max() / 2
else:
result_DtoTop1_num = DtoTop1_num_series.min() + 1
DtoTop1_type = DtoTop1_series.argmin()
result_DtoTop1 = (
DtoTop1_series[DtoTop1_type]
+ (pipe_length[pipe_leak] + pipe_length[located_pipe]) / 2
)
return result_DtoTop1, result_DtoTop1_num
def cal_RR(located_pipe, similarity_sp):
if located_pipe in similarity_sp.index:
rank = similarity_sp.index.get_loc(located_pipe)
RR = rank / len(similarity_sp.index)
else:
RR = 1.1
return RR
def cal_cover(similarity, leak_pipe):
if leak_pipe in list(similarity.index):
cover = 1
else:
cover = 0
return cover
def cal_SD(located_pipe, real_pipe, pipe_x, pipe_y):
dx = pipe_x[located_pipe] - pipe_x[real_pipe]
dy = pipe_y[located_pipe] - pipe_y[real_pipe]
SD = math.sqrt(dx * dx + dy * dy)
return SD
def DN_search_multi_simple_add_flow_count_new(
wn,
wn_inp_path,
G0,
all_node,
node_x,
node_y,
pipe_start_node_all,
pipe_end_node_all,
pipe_diameter,
couple_node_length,
node_pipe_dic,
all_node_series,
top_group_ratio,
top_pipe_num_max,
top_pipe_num_min,
candidate_pipe_input_initial,
similarity_mode,
pressure_monitor,
pressure_predict,
pressure_normal,
pressure_leak_all,
flow_monitor,
flow_predict,
flow_normal,
flow_leak_all,
timestep_list,
max_flow,
group_basic_num,
Top_sensor_num,
if_gy,
pressure_threshold,
leak_mag,
n_workers=1,
stage_timing=None,
partition_on_full_graph=True,
visualize_partition=False,
visualize_pause_seconds=0.3,
final_candidates_csv_path=None,
):
if stage_timing is None:
stage_timing = {}
exit_condition = "unknown"
final_candidates_csv = None
iter_count = 0
all_node_iter = copy.deepcopy(all_node)
candidate_pipe_input = copy.deepcopy(candidate_pipe_input_initial) # 可能漏损管段
t1 = datetime.now()
if_flow, if_only_cos, if_only_flow = decode_mode(similarity_mode) # 定位方法
# threshold
if if_only_flow == 1:
dpressure = (flow_predict - flow_monitor).mean()
dpressure = dpressure.abs()
effective_sensor = list(dpressure.index)
else:
dpressure = (pressure_predict - pressure_monitor).mean()
dpressure = dpressure.abs()
dpressure = dpressure[dpressure > pressure_threshold]
effective_sensor = list(dpressure.index)
simulation_times = 0 # 模拟次数
if len(dpressure) > 0:
break_flag = 0
last_round_candidate_details = None
cos_h = 0
dis_h = 0
dis_f_h = 0
if_compalsive = 0
record_center_dataset = []
record_center_set = set()
# iter
while 1:
final_area = []
final_center = []
group_num = cal_group_num(candidate_pipe_input, group_basic_num)
partition_nodes = all_node if partition_on_full_graph else all_node_iter
# group 分组,得出候选漏损中心
stage_start = perf_counter()
(
candidate_center_list,
candidate_group_list,
new_all_node,
candidate_center_candidates,
) = (
metis_grouping_pipe_weight(
G0,
wn,
partition_nodes,
candidate_pipe_input,
group_num,
node_x,
node_y,
pipe_start_node_all,
pipe_end_node_all,
node_pipe_dic,
all_node_series,
couple_node_length,
pipe_diameter,
)
)
_accumulate_stage(stage_timing, "group_partitioning", stage_start)
if visualize_partition:
visualize_metis_partition(
G0,
candidate_center_list,
candidate_group_list,
node_x,
node_y,
pipe_start_node_all,
pipe_end_node_all,
title=(
f"METIS Partition Iteration {iter_count + 1} | "
f"candidate pipes={len(candidate_pipe_input)} "
f"groups={len(candidate_group_list)}"
),
block=False,
pause_seconds=visualize_pause_seconds,
)
simulation_times = simulation_times + len(candidate_center_list)
# pick_pressure_leak
# pressure_leak = pressure_leak_all.loc[candidate_center_list].loc[:, :]
# flow_leak = flow_leak_all.loc[candidate_center_list].loc[:, :]
# —— 新增泄漏量(保持你现在的一致,或从外部传入)——
# —— 只为缺失中心补算,然后取本轮需要的中心子集 ——
stage_start = perf_counter()
pressure_leak, flow_leak, pressure_leak_all, flow_leak_all = (
_ensure_signatures_for_centers(
wn=wn,
wn_inp_path=wn_inp_path,
center_list=candidate_center_list,
pressure_leak_all=pressure_leak_all,
flow_leak_all=flow_leak_all,
timestep_list=timestep_list,
pressure_monitor=pressure_monitor,
flow_monitor=flow_monitor,
leak_mag=leak_mag,
n_workers=n_workers,
)
)
_accumulate_stage(stage_timing, "signature_for_candidates", stage_start)
# pressure_leak_f= pressure_leak.swaplevel()
# --------------------------------------------------------
add_center = []
leak_center_dict = dict()
for i in range(len(candidate_center_list)):
primary_center = candidate_center_list[i]
houxuan_center = [
center
for center in candidate_center_candidates[i]
if center != primary_center
]
candidate_group_set = set(candidate_group_list[i])
for each_center in record_center_dataset:
if (
each_center in candidate_group_set
and each_center != primary_center
):
houxuan_center.append(each_center)
add_center = add_center + houxuan_center
leak_center_dict[primary_center] = _dedupe_preserve_order(
houxuan_center + [primary_center]
)
add_center = _dedupe_preserve_order(add_center)
for each_group_centers in candidate_center_candidates:
for each_center in each_group_centers:
if each_center not in record_center_set:
record_center_dataset.append(each_center)
record_center_set.add(each_center)
for each_center in add_center:
if each_center not in record_center_set:
record_center_dataset.append(each_center)
record_center_set.add(each_center)
# --------------------------------------------------------
# --------------------------------------------------------
# if len(add_center) > 0:
# s3 = pressure_leak_all.loc[add_center]
# pressure_leak = pd.concat([pressure_leak, s3])
# s4 = flow_leak_all.loc[add_center]
# flow_leak = pd.concat([flow_leak, s4])
# --------------------------------------------------------
# 只为 add_center 里还没算过的中心补算,并与本轮中心合并
if len(add_center) > 0:
stage_start = perf_counter()
pressure_add, flow_add, pressure_leak_all, flow_leak_all = (
_ensure_signatures_for_centers(
wn=wn,
wn_inp_path=wn_inp_path,
center_list=add_center,
pressure_leak_all=pressure_leak_all,
flow_leak_all=flow_leak_all,
timestep_list=timestep_list,
pressure_monitor=pressure_monitor,
flow_monitor=flow_monitor,
leak_mag=leak_mag, # 与上面一致
n_workers=n_workers,
)
)
_accumulate_stage(
stage_timing, "signature_for_extra_centers", stage_start
)
pressure_leak = pd.concat([pressure_leak, pressure_add], axis=0)
if (flow_leak is not None) and (flow_add is not None):
flow_leak = pd.concat([flow_leak, flow_add], axis=0)
# --------------------------------------------------------
#
if len(candidate_pipe_input) < 1.2 * top_pipe_num_max / top_group_ratio:
if_compalsive = 1
cos_h, dis_h, dis_f_h = adjust_ratio(similarity_mode, cos_h, dis_h, dis_f_h)
candidate_center_list_sup = _dedupe_preserve_order(
candidate_center_list + add_center
)
stage_start = perf_counter()
similarity, cos_h, dis_h, dis_f_h, break_flag, similarity_details = (
cal_similarity_all_multi_new_sq_improve_double_lzr(
candidate_center_list_sup,
similarity_mode,
pressure_leak,
pressure_monitor,
pressure_predict,
pressure_normal,
if_flow,
if_only_cos,
if_only_flow,
flow_leak,
flow_monitor,
flow_predict,
flow_normal,
timestep_list,
Top_sensor_num,
if_gy,
effective_sensor,
cos_h,
dis_h,
dis_f_h,
if_compalsive,
max_flow,
)
)
last_round_candidate_details = similarity_details
_accumulate_stage(stage_timing, "similarity_ranking", stage_start)
if break_flag == 1:
exit_condition = "similarity_break_flag"
break
new_similarity = update_similarity(
candidate_center_list, similarity, leak_center_dict
)
if len(candidate_pipe_input) > top_pipe_num_max / top_group_ratio:
cut_ratio, new_similarity = extra_judge(new_similarity)
else:
cut_ratio = 1
stage_start = perf_counter()
final_area_t, final_center_t, all_node_new_1, if_end = (
area_output_num_ki_improve(
candidate_center_list,
candidate_group_list,
new_similarity,
new_all_node,
top_group_ratio,
top_pipe_num_max,
top_pipe_num_min,
cut_ratio,
)
)
_accumulate_stage(stage_timing, "candidate_area_selection", stage_start)
final_area = final_area + final_area_t
final_center = final_center + final_center_t
final_area = sorted(set(final_area))
final_center = sorted(set(final_center))
if if_end == 1:
exit_condition = "candidate_area_if_end"
break
elif len(candidate_pipe_input) == len(final_area):
exit_condition = "candidate_size_no_change"
break
else:
candidate_pipe_input = final_area
if not partition_on_full_graph:
all_node_iter = all_node_new_1
iter_count += 1
sys.stdout.write(
"\r"
+ "已经完成"
+ str(iter_count)
+ "次迭代计算"
+ "候选节点"
+ str(len(final_area))
+ ""
)
# if break_flag == 0:
# final_area_pipe = copy.deepcopy(final_area)
# simulation_times = simulation_times + len(final_area)
# pressure_leak_sp = pressure_leak_all.loc[final_area_pipe].loc[:, :]
# flow_leak_sp = flow_leak_all.loc[final_area_pipe].loc[:, :]
# similarity_sp, cos_h, dis_h, dis_f_h, break_flag = cal_similarity_all_multi_new_sq_improve_double_lzr(
# final_area_pipe, similarity_mode, pressure_leak_sp,
# pressure_monitor, pressure_predict, pressure_normal, if_flow,
# if_only_cos, if_only_flow,
# flow_leak_sp, flow_monitor, flow_predict, flow_normal,
# timestep_list, Top_sensor_num, if_gy, effective_sensor, cos_h, dis_h, dis_f_h, if_compalsive, max_flow)
if break_flag == 0:
final_area_pipe = list(final_area) # 确保是 list
# 只为还没算过的管段补齐 SLF(按需计算)
stage_start = perf_counter()
pressure_leak_sp, flow_leak_sp, pressure_leak_all, flow_leak_all = (
_ensure_signatures_for_centers(
wn=wn,
wn_inp_path=wn_inp_path,
center_list=final_area_pipe, # 这次要用的“最终区域里的所有管段”
pressure_leak_all=pressure_leak_all, # 累积缓存(会被更新)
flow_leak_all=flow_leak_all,
timestep_list=timestep_list,
pressure_monitor=pressure_monitor,
flow_monitor=flow_monitor,
leak_mag=leak_mag,
n_workers=n_workers,
)
)
_accumulate_stage(stage_timing, "signature_for_final_area", stage_start)
# 如果你要精确统计模拟次数,这里可以加上“本次新补的数量”,
# 做法:让 _ensure_signatures_for_centers 额外返回 need_cnt,再 simulation_times += need_cnt
stage_start = perf_counter()
(
similarity_sp,
cos_h,
dis_h,
dis_f_h,
break_flag,
similarity_details,
) = (
cal_similarity_all_multi_new_sq_improve_double_lzr(
final_area_pipe,
similarity_mode,
pressure_leak_sp,
pressure_monitor,
pressure_predict,
pressure_normal,
if_flow,
if_only_cos,
if_only_flow,
flow_leak_sp,
flow_monitor,
flow_predict,
flow_normal,
timestep_list,
Top_sensor_num,
if_gy,
effective_sensor,
cos_h,
dis_h,
dis_f_h,
if_compalsive,
max_flow,
)
)
last_round_candidate_details = similarity_details
_accumulate_stage(stage_timing, "similarity_final", stage_start)
else:
dpressure = (pressure_predict - pressure_monitor).mean()
dpressure = dpressure.abs()
simulation_times = simulation_times + len(dpressure.index)
similarity_sp = pd.Series(dtype=float)
for each_node in dpressure.index:
pipe = node_pipe_dic[each_node][0]
similarity_sp.loc[pipe] = dpressure.loc[each_node]
similarity_sp = similarity_sp.sort_values(ascending=False, kind="mergesort")
t2 = datetime.now()
final_area_pipe = []
sys.stdout.write(
"\r"
+ "已经完成"
+ str(iter_count + 1)
+ "次迭代计算"
+ "候选节点"
+ str(len(final_area_pipe))
+ ""
)
t2 = datetime.now()
dt = (t2 - t1).seconds
final_candidates_csv = _write_last_round_candidates_csv(
csv_path=final_candidates_csv_path,
exit_condition=exit_condition,
iteration_count=iter_count + 1,
similarity_mode=similarity_mode,
candidate_details=last_round_candidate_details,
fallback_similarity=similarity_sp,
)
else:
exit_condition = "no_effective_sensor_after_threshold"
dpressure = (pressure_predict - pressure_monitor).mean()
dpressure = dpressure.abs()
similarity_sp = pd.Series(dtype=float)
for each_node in dpressure.index:
pipe = node_pipe_dic[each_node][0]
similarity_sp.loc[pipe] = dpressure.loc[each_node]
similarity_sp = similarity_sp.sort_values(ascending=False, kind="mergesort")
t2 = datetime.now()
dt = (t2 - t1).seconds
final_candidates_csv = _write_last_round_candidates_csv(
csv_path=final_candidates_csv_path,
exit_condition=exit_condition,
iteration_count=0,
similarity_mode=similarity_mode,
candidate_details=None,
fallback_similarity=similarity_sp,
)
stage_timing["iterations"] = iter_count + 1 if len(dpressure) > 0 else 0
stage_timing["total_elapsed_seconds"] = float(dt)
stage_timing["exit_condition"] = exit_condition
stage_timing["final_candidates_csv"] = final_candidates_csv
return (
similarity_sp.index[0],
dt,
simulation_times,
wn,
similarity_sp,
exit_condition,
final_candidates_csv,
)
@@ -0,0 +1,563 @@
"""漏损模拟模块。"""
import math
import multiprocessing as mp
import os
import sys
import pandas as pd
import wntr
from app.algorithms._utils import _cleanup_temp_files
_PIPE2LEAKNODE = None
_SIGNATURE_WORKER_DATA = {}
def _make_temp_prefix(tag):
temp_dir = os.path.abspath(os.path.join("temp", "burst_location"))
os.makedirs(temp_dir, exist_ok=True)
safe_tag = str(tag).replace(os.sep, "_").replace(" ", "_")
return os.path.join(temp_dir, f"{safe_tag}_{os.getpid()}")
def _snapshot_hydraulic_options(wn):
options = wn.options
return {
"demand_model": options.hydraulic.demand_model,
"duration": float(options.time.duration),
"hydraulic_timestep": float(options.time.hydraulic_timestep),
"pattern_timestep": float(options.time.pattern_timestep),
"report_timestep": float(options.time.report_timestep),
"required_pressure": float(options.hydraulic.required_pressure),
"minimum_pressure": float(options.hydraulic.minimum_pressure),
}
def _apply_hydraulic_options(wn, option_values):
options = wn.options
options.hydraulic.demand_model = option_values["demand_model"]
options.time.duration = option_values["duration"]
options.time.hydraulic_timestep = option_values["hydraulic_timestep"]
options.time.pattern_timestep = option_values["pattern_timestep"]
options.time.report_timestep = option_values["report_timestep"]
options.hydraulic.required_pressure = option_values["required_pressure"]
options.hydraulic.minimum_pressure = option_values["minimum_pressure"]
def simple_add_leak(wn, leak_mag, leak_pipe):
whole_inf = dict()
leak_pipe_self = wn.get_link(leak_pipe)
pipe_diameter = leak_pipe_self.diameter
pipe_length = leak_pipe_self.length
pipe_roughness = leak_pipe_self.roughness
pipe_minor_loss = leak_pipe_self.minor_loss
# pipe_status = leak_pipe_self.status
# pipe_check_valve = leak_pipe_self.check_valve
pipe_start_node = leak_pipe_self.start_node_name
pipe_end_node = leak_pipe_self.end_node_name
# close the pipe
# leak_pipe_self.status = 'Closed'
wn.remove_link(leak_pipe)
# add the pipe
add_pipe1 = leak_pipe + "A"
add_pipe2 = leak_pipe + "B"
add_node = leak_pipe + "_"
start_n = wn.get_node(pipe_start_node)
end_n = wn.get_node(pipe_end_node)
if start_n.node_type == "Reservoir":
end_n_elevation = end_n.elevation
start_n_elevation = end_n_elevation
elif end_n.node_type == "Reservoir":
start_n_elevation = start_n.elevation
end_n_elevation = start_n_elevation
else:
end_n_elevation = end_n.elevation
start_n_elevation = start_n.elevation
elevation_self = (start_n_elevation + end_n_elevation) / 2
coordinates_self = (
(start_n.coordinates[0] + end_n.coordinates[0]) / 2,
(start_n.coordinates[1] + end_n.coordinates[1]),
)
wn.add_junction(
add_node, base_demand=0, elevation=elevation_self, coordinates=coordinates_self
)
leak_node = wn.get_node(add_node)
wn.add_pipe(
add_pipe1,
start_node_name=pipe_start_node,
end_node_name=add_node,
length=pipe_length / 2,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
wn.add_pipe(
add_pipe2,
start_node_name=pipe_end_node,
end_node_name=add_node,
length=pipe_length / 2,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
# simulation
leak_node.add_demand(base=leak_mag, pattern_name="add_leak")
whole_inf["leak_node_name"] = add_node
whole_inf["add_pipe1"] = add_pipe1
whole_inf["add_pipe2"] = add_pipe2
whole_inf["leak_pipe"] = leak_pipe
whole_inf["pipe_start_node"] = pipe_start_node
whole_inf["pipe_end_node"] = pipe_end_node
whole_inf["pipe_length"] = pipe_length
whole_inf["pipe_diameter"] = pipe_diameter
whole_inf["pipe_roughness"] = pipe_roughness
whole_inf["pipe_minor_loss"] = pipe_diameter
return wn, whole_inf, add_pipe1
def simple_recover_wn(wn, whole_inf):
leak_node = wn.get_node(whole_inf["leak_node_name"])
del leak_node.demand_timeseries_list[-1]
# update
wn.remove_link(whole_inf["add_pipe1"])
wn.remove_link(whole_inf["add_pipe2"])
wn.remove_node(whole_inf["leak_node_name"])
# open the pipe
# leak_pipe_self.status = 'Open'
wn.add_pipe(
whole_inf["leak_pipe"],
start_node_name=whole_inf["pipe_start_node"],
end_node_name=whole_inf["pipe_end_node"],
length=whole_inf["pipe_length"],
diameter=whole_inf["pipe_diameter"],
roughness=whole_inf["pipe_roughness"],
minor_loss=whole_inf["pipe_minor_loss"],
)
return wn
def disable_all_controls_temporarily(wn):
"""返回(控制名, 控制对象)的列表,之后可用 restore_controls 还原。"""
removed = []
# WNTR 的控制都在 wn.control_name_list / wn.get_control / wn.remove_control
for cname in list(wn.control_name_list):
ctrl = wn.get_control(cname)
removed.append((cname, ctrl))
wn.remove_control(cname)
return removed
def restore_controls(wn, removed):
"""把先前禁用的控制全部加回去。"""
for cname, ctrl in removed:
wn.add_control(cname, ctrl)
def set_pipe2leaknode_mapping(mapping):
global _PIPE2LEAKNODE
_PIPE2LEAKNODE = mapping
def _get_or_create_leak_demand_ts(leak_node):
"""
返回:泄漏专用 demand 的下标 idx。
若不存在,以 category='leak' 新建一条 base=0.0 的 demand。
"""
# 先尝试找到已有的 'leak' 分类
for i, ts in enumerate(leak_node.demand_timeseries_list):
# WNTR 的 Demand object 存在 category 属性
if getattr(ts, "category", None) == "leak":
return i
# 没有则新建(base=0.0,后续临时改 base_value
leak_node.add_demand(base=0.0, pattern_name=None, category="leak")
return len(leak_node.demand_timeseries_list) - 1
def ensure_mid_node(wn, leak_pipe):
add_pipe1 = f"{leak_pipe}A"
add_pipe2 = f"{leak_pipe}B"
add_node = f"{leak_pipe}__mid"
if add_node in wn.node_name_list:
return add_node
if leak_pipe in wn.link_name_list:
leak_pipe_self = wn.get_link(leak_pipe)
pipe_diameter = leak_pipe_self.diameter
pipe_length = leak_pipe_self.length
pipe_roughness = leak_pipe_self.roughness
pipe_minor_loss = leak_pipe_self.minor_loss
pipe_start_node = leak_pipe_self.start_node_name
pipe_end_node = leak_pipe_self.end_node_name
start_n = wn.get_node(pipe_start_node)
end_n = wn.get_node(pipe_end_node)
if start_n.node_type == "Reservoir":
end_elev = end_n.elevation
start_elev = end_elev
elif end_n.node_type == "Reservoir":
start_elev = start_n.elevation
end_elev = start_elev
else:
end_elev = end_n.elevation
start_elev = start_n.elevation
elev_mid = (start_elev + end_elev) / 2.0
x_mid = (start_n.coordinates[0] + end_n.coordinates[0]) / 2.0
y_mid = (start_n.coordinates[1] + end_n.coordinates[1]) / 2.0
wn.remove_link(leak_pipe)
wn.add_junction(
add_node, base_demand=0.0, elevation=elev_mid, coordinates=(x_mid, y_mid)
)
wn.add_pipe(
add_pipe1,
start_node_name=pipe_start_node,
end_node_name=add_node,
length=pipe_length / 2.0,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
wn.add_pipe(
add_pipe2,
start_node_name=add_node,
end_node_name=pipe_end_node,
length=pipe_length / 2.0,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
return add_node
# 若 A/B 已存在但中点不在,建议确认网络一致性
raise KeyError(f"Cannot ensure mid node for pipe '{leak_pipe}'.")
def leak_simulation_pipe_dd_multi_pf(
wn, leak_mag, leak_pipe, sensor_name, file_prefix=None
):
"""
优化版:
- 不再 remove/add link/node
- 直接在预插入的中点泄漏节点上设置 base_demand = leak_mag;仿真后设回 0
"""
wn.options.hydraulic.demand_model = "DD"
# 确保中点节点存在
leak_node_name = ensure_mid_node(wn, leak_pipe)
leak_node = wn.get_node(leak_node_name)
# 拿到泄漏专用的 demand time-series 下标
leak_idx = _get_or_create_leak_demand_ts(leak_node)
ts_obj = leak_node.demand_timeseries_list[leak_idx]
# 记录原值(通常是 0.0
orig_base = ts_obj.base_value
try:
# 打开泄漏:只改 base_value,不碰 base_demand(只读)
ts_obj.base_value = float(leak_mag)
# 仿真
sim = wntr.sim.EpanetSimulator(wn)
if file_prefix is None:
results = sim.run_sim()
else:
results = sim.run_sim(file_prefix=file_prefix)
# 输出(保持列顺序)
pressure_output = results.node["pressure"].loc[:, sensor_name]
# flow_output = results.link['flowrate'].loc[:, sensor_f_name]
return wn, pressure_output
finally:
# 关闭泄漏:还原 base_value
ts_obj.base_value = orig_base
if file_prefix is not None:
_cleanup_temp_files(file_prefix)
def prepare_leak_infrastructure(wn, candidate_pipes):
"""
把 candidate_pipes 每条管段切成两段,并在中点插入一个泄漏节点(base_demand=0)。
返回一个映射:pipe_id -> leak_node_name
注意:只做一次;后续仿真通过在该节点设置 base_demand 实现“打开泄漏”,结束后恢复为 0。
"""
pipe2leaknode = {}
for leak_pipe in candidate_pipes:
if leak_pipe in pipe2leaknode:
continue
leak_pipe_self = wn.get_link(leak_pipe)
pipe_diameter = leak_pipe_self.diameter
pipe_length = leak_pipe_self.length
pipe_roughness = leak_pipe_self.roughness
pipe_minor_loss = leak_pipe_self.minor_loss
pipe_start_node = leak_pipe_self.start_node_name
pipe_end_node = leak_pipe_self.end_node_name
# 计算中点高程/坐标(与原逻辑一致)
start_n = wn.get_node(pipe_start_node)
end_n = wn.get_node(pipe_end_node)
if start_n.node_type == "Reservoir":
end_elev = end_n.elevation
start_elev = end_elev
elif end_n.node_type == "Reservoir":
start_elev = start_n.elevation
end_elev = start_elev
else:
end_elev = end_n.elevation
start_elev = start_n.elevation
elev_mid = (start_elev + end_elev) / 2.0
x_mid = (start_n.coordinates[0] + end_n.coordinates[0]) / 2.0
y_mid = (start_n.coordinates[1] + end_n.coordinates[1]) / 2.0
# 先删原管,再加中点与两段半长管(只做一次)
wn.remove_link(leak_pipe)
add_pipe1 = f"{leak_pipe}A"
add_pipe2 = f"{leak_pipe}B"
add_node = f"{leak_pipe}__mid" # 唯一命名,后面直接用它当泄漏节点
wn.add_junction(
add_node, base_demand=0.0, elevation=elev_mid, coordinates=(x_mid, y_mid)
)
wn.add_pipe(
add_pipe1,
start_node_name=pipe_start_node,
end_node_name=add_node,
length=pipe_length / 2.0,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
wn.add_pipe(
add_pipe2,
start_node_name=add_node,
end_node_name=pipe_end_node,
length=pipe_length / 2.0,
diameter=pipe_diameter,
roughness=pipe_roughness,
minor_loss=pipe_minor_loss,
)
pipe2leaknode[leak_pipe] = add_node
return pipe2leaknode
def normal_simulation_pf(
wn, drive_mode, sensor_name, sensor_f_name, inp_time, require_p, minimum_p
):
# inp_time = 0
if drive_mode == "PDD": # 需水量根据节点压力动态调整
wn.options.hydraulic.demand_model = "PDD"
wn.options.hydraulic.required_pressure = require_p
wn.options.hydraulic.minimum_pressure = minimum_p
elif drive_mode == "DD": # 需水量固定,与压力无关
wn.options.hydraulic.demand_model = "DD"
sim = wntr.sim.EpanetSimulator(wn)
results = sim.run_sim()
pressure_all = results.node["pressure"][sensor_name]
pressure = pressure_all.iloc[inp_time]
demand_all = results.node["demand"]
demand = demand_all.iloc[inp_time]
sum_demand = cal_sum_demand(demand)
flow_all = results.link["flowrate"][sensor_f_name]
flow = flow_all.iloc[inp_time]
top_sensor = pressure.idxmin()
basic_p = results.node["pressure"]
basic_p = basic_p.iloc[inp_time]
return pressure, flow, basic_p, top_sensor, sum_demand
def normal_simulation_multi_pf(
wn, drive_mode, sensor_name, sensor_f_name, require_p, minimum_p
):
# inp_time = 0
if drive_mode == "PDD":
wn.options.hydraulic.demand_model = "PDD"
wn.options.hydraulic.required_pressure = require_p
wn.options.hydraulic.minimum_pressure = minimum_p
elif drive_mode == "DD":
wn.options.hydraulic.demand_model = "DD"
sim = wntr.sim.EpanetSimulator(wn)
results = sim.run_sim()
pressure_all = results.node["pressure"][sensor_name]
pressure = pressure_all
demand_all = results.node["demand"]
demand = demand_all
flow = results.link["flowrate"][sensor_f_name]
sum_demand = pd.Series(dtype=object)
for i in range(len(demand.index)):
sum_demand[str(demand.index[i])] = cal_sum_demand(demand.iloc[i])
if type(pressure) == pd.core.series.Series:
top_sensor = pressure.idxmin()
else:
mean_pressure = pressure.mean()
top_sensor = mean_pressure.idxmin()
basic_p = results.node["pressure"]
return pressure, flow, basic_p, top_sensor, sum_demand
def simple_simulation_pf(wn, sensor_name, sensor_f_name, leak_pipe, add_pipe1):
sim = wntr.sim.EpanetSimulator(wn)
results = sim.run_sim()
pressure_all = results.node["pressure"][sensor_name]
if len(leak_pipe) > 0 and leak_pipe in sensor_f_name:
f_sensor_name = [add_pipe1 if i == leak_pipe else i for i in sensor_f_name]
flow_all = results.link["flowrate"][f_sensor_name]
flow_all.columns = sensor_f_name
else:
flow_all = results.link["flowrate"][sensor_f_name]
return pressure_all, flow_all
def cal_sum_demand(demand):
sum_demand = 0
for i in range(len(demand)):
if demand.iloc[i] > 0:
sum_demand += demand.iloc[i]
return sum_demand
def cal_signature_pipe_multi_pf(
wn,
leak_mag,
candidate_center,
timestep_list,
sensor_name,
n_workers=1,
wn_inp_path=None,
):
candidate_center_num = len(candidate_center)
pressure_leak = pd.DataFrame(
index=pd.MultiIndex.from_product([candidate_center, timestep_list]),
columns=sensor_name,
)
# flow_leak = pd.DataFrame(index=pd.MultiIndex.from_product([candidate_center, timestep_list]),
# columns=sensor_f_name)
pressure_leak = pressure_leak.sort_index()
# flow_leak = flow_leak.sort_index()
can_parallel = (
n_workers > 1
and candidate_center_num > 1
and wn_inp_path is not None
and len(str(wn_inp_path)) > 0
)
if can_parallel:
option_values = _snapshot_hydraulic_options(wn)
worker_count = min(n_workers, candidate_center_num)
start_methods = mp.get_all_start_methods()
context_name = "spawn" if "spawn" in start_methods else start_methods[0]
with mp.get_context(context_name).Pool(
processes=worker_count,
initializer=_signature_worker_init,
initargs=(
str(wn_inp_path),
float(leak_mag),
list(sensor_name),
option_values,
list(candidate_center),
),
) as pool:
for i, (center_name, pressure_array) in enumerate(
pool.imap(_signature_worker_run_center, candidate_center)
):
pressure_leak.loc[(center_name, slice(None)), :] = pressure_array
sys.stdout.write("\r" + "已经完成计算" + str(i + 1) + "个特征中心")
else:
# Pre-insert all mid-nodes so every simulation sees the same topology
for center in candidate_center:
ensure_mid_node(wn, center)
for i in range(candidate_center_num):
temp_prefix = _make_temp_prefix(f"sig_{i}")
wn, pressure_output = leak_simulation_pipe_dd_multi_pf(
wn,
leak_mag,
candidate_center[i],
sensor_name,
file_prefix=temp_prefix,
)
# leak_or_not_list.append(leak_or_not)
pressure_leak.loc[(candidate_center[i], slice(None)), :] = (
pressure_output.to_numpy()
)
# flow_leak.loc[candidate_center[i]].loc[:, :] = flow_output
sys.stdout.write("\r" + "已经完成计算" + str(i + 1) + "个特征中心")
return pressure_leak, candidate_center
def _signature_worker_init(
inp_path, leak_mag, sensor_name, option_values, candidate_centers=None
):
global _SIGNATURE_WORKER_DATA
wn = wntr.network.WaterNetworkModel(inp_path)
_apply_hydraulic_options(wn, option_values)
# Pre-insert ALL mid-nodes so every simulation runs on the same topology,
# regardless of which worker handles which task.
if candidate_centers is not None:
for center in candidate_centers:
ensure_mid_node(wn, center)
_SIGNATURE_WORKER_DATA = {
"wn": wn,
"leak_mag": leak_mag,
"sensor_name": sensor_name,
}
def _signature_worker_run_center(center_name):
data = _SIGNATURE_WORKER_DATA
temp_prefix = _make_temp_prefix(f"sig_worker_{center_name}")
_, pressure_output = leak_simulation_pipe_dd_multi_pf(
data["wn"],
data["leak_mag"],
center_name,
data["sensor_name"],
file_prefix=temp_prefix,
)
return center_name, pressure_output.to_numpy()
def pick_pipe(all_pipes, pipe_diameter, limited_diameter):
candidate_pipe = []
for each_pipe in all_pipes:
if pipe_diameter[each_pipe] >= limited_diameter:
candidate_pipe.append(each_pipe)
return candidate_pipe
def cal_possible_pipe(leak_flow, all_pipe, pipe_diameter):
basic_pressure = 10 # 基础压力
discharge_coeff = 0.6 # 经验系数
break_area_ratio = 1 # 爆管面积比 0.5 1.25
break_area = leak_flow / (
discharge_coeff * math.sqrt(2 * basic_pressure * 9.81)
) # 爆管面积 m3/h
"""break_area_diameter = math.sqrt(4 * break_area / math.pi)
min_diameter = (math.ceil(1000 * break_area_diameter / break_area_ratio)) / 1000"""
break_area_diameter = math.sqrt(
4 * break_area / math.pi / break_area_ratio
) # 爆管直径
min_diameter = (math.ceil(1000 * break_area_diameter)) / 1000 # 向上取整
new_all_pipe = pick_pipe(all_pipe, pipe_diameter, min_diameter)
return new_all_pipe, min_diameter
def extract_links(data, link_types, direction):
return [
link
for res_data in data.values()
for link_type in link_types
for link in res_data[link_type][direction]
]
@@ -0,0 +1,137 @@
"""管网模型读取与图构建模块。"""
import copy
import numpy as np
import networkx as nx
import pandas as pd
import wntr
def load_inp(inp_name, inp_location, inp_time, driven_mode, require_p, minimum_p):
inp_file = inp_location + inp_name
wn = wntr.network.WaterNetworkModel(inp_file)
if driven_mode == "PDD":
wn.options.hydraulic.demand_model = "PDD"
wn.options.hydraulic.required_pressure = require_p
wn.options.hydraulic.minimum_pressure = minimum_p
else:
wn.options.hydraulic.demand_model = "DD"
return wn
def read_inf_inp(wn):
all_node = wn.node_name_list
node_elevation = wn.query_node_attribute("elevation")
node_coordinates = wn.query_node_attribute("coordinates")
all_pipe = wn.pipe_name_list
# 改_wz__________________________________
n_pipe = []
for p in all_pipe:
pipe = wn.get_link(p)
if pipe.initial_status == 0: # 状态为'Closed'
n_pipe.append(p)
candidate_pipe_init = sorted(set(all_pipe) - set(n_pipe))
pipe_start_node = wn.query_link_attribute(
"start_node_name", link_type=wntr.network.model.Pipe
)
pipe_end_node = wn.query_link_attribute(
"end_node_name", link_type=wntr.network.model.Pipe
)
pipe_length = wn.query_link_attribute("length")
pipe_diameter = wn.query_link_attribute("diameter")
return (
all_node,
node_elevation,
node_coordinates,
candidate_pipe_init,
pipe_start_node,
pipe_end_node,
pipe_length,
pipe_diameter,
)
def read_inf_inp_other(wn):
all_link = wn.link_name_list
pipe_start_node_all = wn.query_link_attribute("start_node_name")
pipe_end_node_all = wn.query_link_attribute("end_node_name")
return all_link, pipe_start_node_all, pipe_end_node_all
def construct_graph(wn):
length = wn.query_link_attribute("length")
G = wn.get_graph(wn, link_weight=length)
# 转为无向图
G0 = G.to_undirected()
# A0 = np.array(nx.adjacency_graph(G0).todense())
return G0 # , A0
def cal_pipe_coordinate(all_pipe, pipe_start_node, pipe_end_node, node_coordinates):
pipe_num = len(all_pipe)
pipe_coordinates = np.zeros([pipe_num, 2])
pipe_x = copy.deepcopy(pipe_start_node)
pipe_y = copy.deepcopy(pipe_start_node)
for i in range(pipe_num):
temp_pipe = all_pipe[i]
pipe_x[temp_pipe] = (
node_coordinates[pipe_start_node[temp_pipe]][0]
+ node_coordinates[pipe_end_node[temp_pipe]][0]
) / 2
pipe_y[temp_pipe] = (
node_coordinates[pipe_start_node[temp_pipe]][1]
+ node_coordinates[pipe_end_node[temp_pipe]][1]
) / 2
return pipe_x, pipe_y
def cal_node_coordinate(all_node, node_coordinates):
node_x = copy.deepcopy(node_coordinates)
node_y = copy.deepcopy(node_coordinates)
for i in range(len(node_x)):
temp_node = all_node[i]
node_x[temp_node] = node_coordinates[temp_node][0]
node_y[temp_node] = node_coordinates[temp_node][1]
return node_x, node_y
def produce_pattern_value(wn, all_node):
wn_o = copy.deepcopy(wn)
# 改_wz_____________________________
# sample_node = wn_o.get_node(all_node[0])
# num_categories = len(sample_node.demand_timeseries_list)
num_categories = 1
columns = [f"D{i}" for i in range(num_categories)]
basic_demand_pd = pd.DataFrame(index=all_node, columns=columns)
for each in all_node:
node = wn_o.get_node(each)
for i in range(num_categories):
basic_demand_pd.loc[each, columns[i]] = node.demand_timeseries_list[
i
].base_value
return basic_demand_pd
def _build_node_pipe_maps(
all_nodes, candidate_pipes, pipe_start_node, pipe_end_node, pipe_length
):
node_pipe_dic = {node: [] for node in all_nodes}
couple_node_length = {}
for pipe in candidate_pipes:
start_node = pipe_start_node[pipe]
end_node = pipe_end_node[pipe]
if start_node in node_pipe_dic:
node_pipe_dic[start_node].append(pipe)
if end_node in node_pipe_dic:
node_pipe_dic[end_node].append(pipe)
length = float(pipe_length[pipe])
couple_node_length[f"{start_node},{end_node}"] = length
couple_node_length[f"{end_node},{start_node}"] = length
return node_pipe_dic, couple_node_length
@@ -0,0 +1,456 @@
"""管网分区模块。"""
import math
import matplotlib.pyplot as plt
import networkx as nx
import networkx as networkx
import numpy as np
import pandas as pd
import pymetis
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.csgraph import connected_components
def _to_metis_edge_weight(edge_weight):
weight = float(edge_weight)
if not math.isfinite(weight):
raise ValueError(f"Invalid non-finite METIS edge weight: {edge_weight}")
# pymetis expects integer edge weights.
return max(1, int(round(weight)))
def _dedupe_preserve_order(items):
seen = set()
output = []
for item in items:
if item in seen:
continue
seen.add(item)
output.append(item)
return output
def pick_center_pipe(node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node):
candidate_pipe_list = list(candidate_pipe)
start_nodes = pipe_start_node[candidate_pipe_list]
end_nodes = pipe_end_node[candidate_pipe_list]
x_vals = (node_x[start_nodes].to_numpy() + node_x[end_nodes].to_numpy()) / 2.0
y_vals = (node_y[start_nodes].to_numpy() + node_y[end_nodes].to_numpy()) / 2.0
mean_x = float(np.mean(x_vals))
mean_y = float(np.mean(y_vals))
distance = np.abs(x_vals - mean_x) + np.abs(y_vals - mean_y)
center_idx = int(np.argmin(distance))
return candidate_pipe_list[center_idx]
def pick_max_diameter_pipe(candidate_pipe, pipe_diameter):
candidate_pipe_list = list(candidate_pipe)
diameters = pd.to_numeric(
pipe_diameter.reindex(candidate_pipe_list), errors="coerce"
).dropna()
if len(diameters) != len(candidate_pipe_list):
missing = sorted(set(candidate_pipe_list) - set(diameters.index))
preview = ", ".join(map(str, missing[:10]))
raise ValueError(f"Missing or invalid diameter for pipes: {preview}")
max_diameter = float(diameters.max())
max_diameter_pipes = sorted(
[pipe for pipe, diameter in diameters.items() if float(diameter) == max_diameter],
key=str,
)
return max_diameter_pipes[0]
def pick_dual_center_pipes(
node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node, pipe_diameter
):
geometric_center = pick_center_pipe(
node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node
)
diameter_center = pick_max_diameter_pipe(candidate_pipe, pipe_diameter)
return _dedupe_preserve_order([geometric_center, diameter_center])
def find_new_center_pipe(
node_x,
node_y,
candidate_pipe,
pipe_start_node,
pipe_end_node,
pipe_diameter,
record_center,
):
new_candidate_pipe = sorted(set(candidate_pipe) - set(record_center))
if new_candidate_pipe == []:
new_candidate_pipe = candidate_pipe
center_t = pick_center_pipe(
node_x,
node_y,
new_candidate_pipe,
pipe_start_node,
pipe_end_node,
)
return center_t
def cal_area_node_linked_pipe(nodeset, node_pipe_dic):
pipeset = []
for temp_node in nodeset:
pipeset.extend(node_pipe_dic[temp_node])
return pipeset
def metis_grouping_pipe_weight(
G0,
wn,
all_node_iter,
candidate_pipe_input,
group_num,
node_x,
node_y,
pipe_start_node_all,
pipe_end_node_all,
node_pipe_dic,
all_node_series,
couple_node_length,
pipe_diameter,
):
all_node_iter_series_new = all_node_series[all_node_iter]
all_node_iter_series_new = all_node_iter_series_new.sort_values(ascending=True)
all_node_iter_new = list(all_node_iter_series_new.index)
G1 = G0.subgraph(all_node_iter_new)
delimiter = " "
adjacency_list = []
node_dict = {}
c_new = 0
for each_node in all_node_iter_new:
node_dict[each_node] = c_new
c_new = c_new + 1
correspond_dic = {}
count_node = 0
w = []
for node_name in all_node_iter_new:
neighbors = G1[node_name]
w_temp = []
n_t = [node_dict[node_name]]
for neighbor_name in sorted(neighbors.keys()):
edge_data = neighbors[neighbor_name]
edge_key = f"{node_name},{neighbor_name}"
reverse_edge_key = f"{neighbor_name},{node_name}"
if edge_key in couple_node_length:
edge_weight = couple_node_length[edge_key]
elif reverse_edge_key in couple_node_length:
edge_weight = couple_node_length[reverse_edge_key]
elif edge_data.get("weight") is not None:
edge_weight = float(edge_data["weight"])
else:
# Ignore graph edges that are outside candidate pipes and have no usable
# partition weight (e.g. some non-pipe links in mixed network graphs).
continue
w_temp.append(_to_metis_edge_weight(edge_weight))
n_t.append(node_dict[neighbor_name])
w.append(w_temp)
correspond_dic[n_t[0]] = count_node
count_node = count_node + 1
# del n_t[0]
adjacency_list.append(n_t)
adjacency_list_new = [[] * 1 for i in range(len(adjacency_list))]
w_new = [[] * 1 for i in range(len(adjacency_list))]
for i in range(len(adjacency_list)):
adjacency_list_new[int(adjacency_list[i][0])] = adjacency_list[i]
w_new[int(adjacency_list[i][0])] = w[i]
for i in range(len(adjacency_list)):
del adjacency_list_new[i][0]
xadj = [0]
w_f = []
final_adjacency_list = []
for i in range(len(adjacency_list_new)):
final_adjacency_list = final_adjacency_list + adjacency_list_new[i]
xadj.append(len(final_adjacency_list))
w_f = w_f + w_new[i]
# (edgecuts, parts) = pymetis.part_graph(nparts=group_num, adjacency=adjacency_list_new)
metis_options = pymetis.Options()
metis_options.seed = 42
(edgecuts, parts) = pymetis.part_graph(
nparts=group_num,
adjncy=final_adjacency_list,
xadj=xadj,
eweights=w_f,
options=metis_options,
)
# (edgecuts, parts) = pymetis.part_graph(nparts=group_num, adjacency=adjacency_list_new)
candidate_group_list = [[] * 1 for i in range(group_num)]
for i in range(len(all_node_iter_new)):
candidate_group_list[parts[i]].append(all_node_iter_new[i])
"""parts_new = np.zeros(len(candidate_node_input), dtype=int)
for i in range(len(candidate_group_list)):
temp_group = candidate_group_list[i]
for each_node in temp_group:
parts_new[node_dict[each_node]] = i
parts_new = list(parts_new)"""
new_center = []
new_group = []
new_center_candidates = []
new_all_node = []
candidate_pipe_set = set(candidate_pipe_input)
all_grouped_pipe = []
for i in range(group_num):
# 构建子图
G_sub = G0.subgraph(candidate_group_list[i])
# 计算联通子图
sub_graphs = networkx.connected_components(G_sub)
if networkx.number_connected_components(G_sub) == 1:
# 求交集
nodeset = G_sub.nodes()
pipeset_set = set(cal_area_node_linked_pipe(nodeset, node_pipe_dic))
candidate_pipe = sorted(pipeset_set.intersection(candidate_pipe_set))
# 判断集合是否保留
if len(candidate_pipe) > 0:
# 保留 计算中心
center_t = pick_center_pipe(
node_x,
node_y,
candidate_pipe,
pipe_start_node_all,
pipe_end_node_all,
)
center_candidates_t = pick_dual_center_pipes(
node_x,
node_y,
candidate_pipe,
pipe_start_node_all,
pipe_end_node_all,
pipe_diameter,
)
# 更新
new_center.append(center_t)
new_center_candidates.append(center_candidates_t)
new_group.append(candidate_pipe)
new_all_node.append(nodeset)
all_grouped_pipe = all_grouped_pipe + candidate_pipe
else:
for c in sorted(sub_graphs, key=lambda c: min(c)):
G_temp = G0.subgraph(c)
nodeset = G_temp.nodes()
pipeset = cal_area_node_linked_pipe(nodeset, node_pipe_dic)
pipeset_set = set(pipeset)
# 求交集
candidate_pipe = sorted(pipeset_set.intersection(candidate_pipe_set))
# print(len(candidate_node))
# 判断集合是否保留
if len(candidate_pipe) > 0:
# 保留 计算中心
center_t = pick_center_pipe(
node_x,
node_y,
candidate_pipe,
pipe_start_node_all,
pipe_end_node_all,
)
center_candidates_t = pick_dual_center_pipes(
node_x,
node_y,
candidate_pipe,
pipe_start_node_all,
pipe_end_node_all,
pipe_diameter,
)
# 更新
new_center.append(center_t)
new_center_candidates.append(center_candidates_t)
new_group.append(candidate_pipe)
new_all_node.append(nodeset)
all_grouped_pipe = all_grouped_pipe + candidate_pipe
record_center = []
c_g = 0
for each_group in new_group:
if len(each_group) < 3:
record_center.append(new_center[c_g])
c_g += 1
c_g = 0
for each_group in new_group:
if len(each_group) >= 3:
if new_center[c_g] in record_center:
new_center[c_g] = find_new_center_pipe(
node_x,
node_y,
each_group,
pipe_start_node_all,
pipe_end_node_all,
pipe_diameter,
record_center,
)
new_center_candidates[c_g] = _dedupe_preserve_order(
[new_center[c_g]] + list(new_center_candidates[c_g])
)
record_center.append(new_center[c_g])
c_g += 1
# visualize_metis_partition(
# G0, new_center, new_group,
# node_x, node_y,
# pipe_start_node_all, pipe_end_node_all
# )
return new_center, new_group, new_all_node, new_center_candidates
def visualize_metis_partition(
G,
center_pipes,
pipe_groups,
node_x,
node_y,
pipe_start_node_all,
pipe_end_node_all,
title: str | None = None,
block: bool = True,
pause_seconds: float | None = None,
):
"""
可视化METIS分区结果(单图模式)
参数:
G: 原始管网图(nx.Graph)
center_pipes: 中心管道列表(list)
pipe_groups: 分组管道列表(list of lists)
node_x: 节点X坐标字典(dict)
node_y: 节点Y坐标字典(dict)
pipe_start_node_all: 管道起点字典(dict)
pipe_end_node_all: 管道终点字典(dict)
"""
fig = plt.figure("metis_partition_convergence", figsize=(22.51, 12.48))
fig.clf()
ax = fig.add_subplot(111)
if not block:
plt.ion()
# 生成颜色映射(自动扩展颜色数量)
colors = plt.cm.tab20(np.linspace(0, 1, len(pipe_groups)))
# --- 绘制背景管网(灰色半透明) ---
for edge in G.edges():
start_node, end_node = edge
ax.plot(
[node_x[start_node], node_x[end_node]],
[node_y[start_node], node_y[end_node]],
color="lightgray",
linewidth=0.5,
alpha=0.3,
zorder=1, # 确保背景在底层
)
# --- 绘制各分区管道(彩色)---
legend_handles = [] # 用于图例的句柄
for i, (group, center) in enumerate(zip(pipe_groups, center_pipes)):
color = colors[i % len(colors)] # 循环使用颜色
# 绘制分组管道
for pipe in group:
start = pipe_start_node_all[pipe]
end = pipe_end_node_all[pipe]
line = ax.plot(
[node_x[start], node_x[end]],
[node_y[start], node_y[end]],
color=color,
linewidth=2.5,
alpha=0.8,
zorder=2,
)
# 只为每个分组的第一个管道添加图例句柄
if pipe == group[0]:
legend_handles.append(line[0])
# 高亮中心管道(红色虚线)
if center in pipe_start_node_all and center in pipe_end_node_all:
start = pipe_start_node_all[center]
end = pipe_end_node_all[center]
ax.plot(
[node_x[start], node_x[end]],
[node_y[start], node_y[end]],
color="red",
linewidth=4,
linestyle="--",
dash_capstyle="round",
zorder=3, # 确保中心管道在最顶层
)
# --- 添加图例和标注 ---
# 分组图例
if legend_handles:
group_labels = [f"Group {i + 1}" for i in range(len(pipe_groups))]
ax.legend(
legend_handles,
group_labels,
loc="upper right",
title="Partitions",
fontsize=8,
title_fontsize=10,
)
# 中心管道标注(可选)
for i, center in enumerate(center_pipes):
if center in pipe_start_node_all:
x = (
node_x[pipe_start_node_all[center]] + node_x[pipe_end_node_all[center]]
) / 2
y = (
node_y[pipe_start_node_all[center]] + node_y[pipe_end_node_all[center]]
) / 2
ax.text(
x,
y,
f"C{i + 1}",
color="red",
fontsize=10,
ha="center",
va="center",
bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
)
# --- 图形美化 ---
ax.set_title(title or "Water Network Partitioning Overview", fontsize=14, pad=20)
ax.set_xlabel("X Coordinate", fontsize=10)
ax.set_ylabel("Y Coordinate", fontsize=10)
ax.grid(True, alpha=0.2, linestyle=":")
fig.tight_layout()
# 显示图形并强制刷新,避免迭代显示滞后一轮。
plt.show(block=block)
if not block:
fig.canvas.draw_idle()
fig.canvas.flush_events()
pause_value = 0.001 if pause_seconds is None else max(0.0, float(pause_seconds))
plt.pause(max(0.001, pause_value))
elif pause_seconds is not None:
plt.pause(max(0.0, float(pause_seconds)))
return fig
def generate_adjlist_with_all_edges(G, delimiter):
for s, nbrs in G.adjacency():
line = str(s) + delimiter
for t, data in nbrs.items():
line += str(t) + delimiter
yield line[: -len(delimiter)]
def cal_group_num(candidate_node_input, cal_group_num):
candidate_node_num = len(candidate_node_input)
if candidate_node_num > 100:
group_num_input = cal_group_num # 30
else:
group_num_input = 10
return group_num_input
@@ -0,0 +1,198 @@
"""噪声生成模块。"""
import copy
import random
import numpy as np
import pandas as pd
from .leak_simulator import simple_add_leak, simple_recover_wn, simple_simulation_pf
def add_noise_pd(data, noise_type, noise_para):
output_data = copy.deepcopy(data)
if type(output_data) == pd.core.frame.Series:
if noise_type == "uni":
for x in output_data.index:
noise = (np.random.random() - 0.5) * 2
output_data[x] = output_data[x] + noise * noise_para
elif noise_type == "gauss":
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
output_data = output_data + noise
elif type(output_data) == pd.core.frame.DataFrame:
if noise_type == "uni":
noise = (np.random.random(size=output_data.shape) - 0.5) * 2
output_data = output_data + noise * noise_para
elif noise_type == "gauss":
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
output_data = output_data + noise
return output_data
def add_noise_number(data, noise_type, noise_para):
output_data = copy.deepcopy(data)
if noise_type == "uni":
noise = (np.random.random() - 0.5) * 2
output_data = output_data + noise * noise_para
elif noise_type == "gauss":
noise = random.gauss(0, noise_para)
output_data = output_data + noise
return output_data
def add_noise_number_flow(data, noise_para_mean, noise_para_std1, noise_para_std2):
output_data = copy.deepcopy(data)
noise_flag1 = np.random.random() - 0.5
if noise_flag1 < 0:
noise = noise_para_mean - abs(np.random.normal(loc=0, scale=noise_para_std1))
else:
noise = noise_para_mean + abs(np.random.normal(loc=0, scale=noise_para_std2))
noise_flag2 = np.random.random() - 0.5
if noise_flag2 < 0:
noise_f = noise * (-1)
else:
noise_f = noise
output_data = output_data + noise_f
return output_data
def produce_noise_number(noise_type, noise_para):
if noise_type == "uni":
noise = (np.random.random() - 0.5) * 2
noise = noise * noise_para
elif noise_type == "gauss":
noise = random.gauss(0, noise_para)
else:
noise = 0
return noise
def add_noise_percentage_pd(data, noise_type, noise_para):
output_data = copy.deepcopy(data)
if type(output_data) == pd.core.frame.Series:
if noise_type == "uni":
for x in output_data.index:
noise = (np.random.random() - 0.5) * 2
output_data[x] = output_data[x] * (1 + noise * noise_para / 100)
elif noise_type == "gauss":
for x in output_data.index:
noise = np.random.gauss(0, noise_para)
output_data[x] = output_data[x] * (1 + noise / 100)
# std_noise = noise.std()
elif type(output_data) == pd.core.frame.DataFrame:
if noise_type == "uni":
noise = (np.random.random(size=output_data.shape) - 0.5) * 2
output_data = output_data * (1 + noise * noise_para / 100)
elif noise_type == "gauss":
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
output_data = output_data * (1 + noise / 100)
# std_noise = noise.std().mean()
return output_data
def add_noise_in_wn_pf(
wn,
pipe_c_noise,
timestep_list,
pipe_coefficient,
sensor_name,
sensor_f_name,
all_node,
basic_demand_pd,
noise_type,
noise_para,
leak_pipe,
leak_flow,
):
wn.options.time.duration = 0
pipe_roughness_change = add_noise_pd(pipe_coefficient, noise_type, pipe_c_noise)
wn = change_para_of_wn(wn, pipe_roughness_change)
record_pressure = pd.DataFrame(index=timestep_list, columns=sensor_name)
record_flow = pd.DataFrame(index=timestep_list, columns=sensor_f_name)
record_noise_all = pd.DataFrame(
index=pd.MultiIndex.from_product([timestep_list, all_node]),
columns=basic_demand_pd.columns,
)
record_noise_all = record_noise_all.sort_index()
# normal 获取添加噪声后的监测点数据
for i in range(len(timestep_list)):
wn, record_noise = change_node_demand(
wn, basic_demand_pd, all_node, noise_type, noise_para
)
record_noise_all.loc[timestep_list[i]].loc[:, :] = record_noise
pressure_temp, flow_temp = simple_simulation_pf(
wn, sensor_name, sensor_f_name, [], []
)
record_pressure.iloc[i, :] = pressure_temp
record_flow.iloc[i, :] = flow_temp
# leak_simulation 获取添加漏损后的监测点数据
record_pressure_leak = pd.DataFrame(index=timestep_list, columns=sensor_name)
record_flow_leak = pd.DataFrame(index=timestep_list, columns=sensor_f_name)
# 改_wz_________________________________________
# add leak
wn, whole_inf, add_pipe1 = simple_add_leak(wn, leak_flow, leak_pipe)
# simulation
for i in range(len(timestep_list)):
record_noise = record_noise_all.loc[timestep_list[i]]
wn = change_node_demand_leak(wn, record_noise, all_node)
pressure_temp, flow_temp = simple_simulation_pf(
wn, sensor_name, sensor_f_name, leak_pipe, add_pipe1
)
record_pressure_leak.iloc[i, :] = pressure_temp
record_flow_leak.iloc[i, :] = flow_temp
# delete leak
wn = simple_recover_wn(wn, whole_inf)
return wn, record_pressure, record_flow, record_pressure_leak, record_flow_leak
def change_node_demand(wn, basic_demand_pd, all_node, noise_type, noise_para):
# 改_wz_____________________________________
record_noise = pd.DataFrame(index=all_node, columns=basic_demand_pd.columns)
for each_node in all_node:
node = wn.get_node(each_node)
num_columns = len(basic_demand_pd.columns)
# 处理前N-1列(如果有)
for i in range(num_columns - 1):
# 获取原始值并添加噪声
record_noise.loc[each_node].iloc[i] = (
1 + produce_noise_number(noise_type, noise_para)
) * basic_demand_pd.loc[each_node].iloc[i]
node.demand_timeseries_list[i].base_value = record_noise.loc[
each_node
].iloc[i]
# 处理最后一列(当列数>=1时)
if num_columns >= 1:
last_col = basic_demand_pd.columns[-1]
original_last = basic_demand_pd.loc[each_node, last_col]
record_noise.loc[each_node, last_col] = original_last
node.demand_timeseries_list[-1].base_value = original_last
return wn, record_noise
def change_node_demand_leak(wn, record_noise, all_node):
sample_node = wn.get_node(all_node[0])
# num_categories = len(sample_node.demand_timeseries_list)
num_categories = 1
for each in all_node:
node = wn.get_node(each)
for i in range(num_categories):
node.demand_timeseries_list[i].base_value = record_noise.loc[each].iloc[i]
return wn
def change_para_of_wn(wn, pipe_roughness_change):
for pipe_name, pipe in wn.pipes():
pipe.roughness = pipe_roughness_change[pipe_name]
return wn
@@ -0,0 +1,858 @@
"""相似性计算模块。"""
import math
import numpy as np
import pandas as pd
def cal_similarity_simple_return_dd(
similarity_mode,
monitor_p,
predict_p,
normal_p,
leak_p,
monitor_p_all,
predict_p_all,
normal_p_all,
leak_p_all,
important_sensor,
mean_dpressure,
dpressure_std,
dpressure_std_all,
if_gy=0,
cos_or_flow=1,
):
# cos_or_flow 用于 CAF
dpressure_s = normal_p - leak_p
dpressure = predict_p - monitor_p
act_dpressure = pd.Series(dtype=object)
for i in range(len(leak_p.index)):
if dpressure_std.iloc[i] > -200: # 0.001:
if if_gy == 1:
act_dpressure[leak_p.index[i]] = (
leak_p.iloc[i] - monitor_p.iloc[i]
) / dpressure_std.iloc[i]
else:
act_dpressure[leak_p.index[i]] = leak_p.iloc[i] - monitor_p.iloc[i]
if similarity_mode == "COS" or (similarity_mode == "CAF" and cos_or_flow == 1):
"""if leak_p.min()<0:
none_flag = 1
similarity_cos = 0
similarity_dis = 0
else:"""
none_flag = 0
sensor_for_cos = sorted(
set(dpressure_s.index).intersection(set(act_dpressure.index))
)
"""if len(dpressure_s) ==0 or len(dpressure) ==0:
jj=9
else:"""
try:
s1 = np.dot(
np.transpose(dpressure_s.loc[sensor_for_cos]),
dpressure.loc[sensor_for_cos],
)
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
dpressure.loc[sensor_for_cos]
)
if s2 == 0:
s2 = s2 + 0.0001
similarity_cos = s1 / s2
similarity_dis = 0
except Exception as e:
print(dpressure_s)
print(sensor_for_cos)
print(act_dpressure)
print(dpressure_std)
print(dpressure)
elif similarity_mode == "DIS" or (similarity_mode == "CAF" and cos_or_flow == 2):
"""if leak_p.min()<0:
none_flag = 1
else:"""
none_flag = 0
important_sensor = sorted(
set(important_sensor).intersection(set(act_dpressure.index))
)
part_dpressure = dpressure_s[important_sensor] - dpressure[important_sensor]
similarity_pre_DIS = np.linalg.norm(part_dpressure)
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
similarity_dis = similarity_pre_DIS
similarity_cos = 0
elif similarity_mode == "CAD_new":
act_dpressure = leak_p - monitor_p
"""if leak_p.min() < 0:
none_flag = 1
similarity_cos = 0
similarity_dis =0
else:"""
none_flag = 0
# cos
s1 = np.dot(np.transpose(dpressure_s), dpressure)
s2 = np.linalg.norm(dpressure_s) * np.linalg.norm(dpressure)
if s2 == 0:
s2 = s2 + 0.0001
similarity_cos = s1 / s2
# DIS
part_dpressure = act_dpressure.loc[important_sensor]
similarity_pre_DIS = np.linalg.norm(part_dpressure)
similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
similarity_dis = similarity_pre_DIS
elif similarity_mode == "CAD_new_gy" or similarity_mode == "CDF":
# cos
sensor_for_cos = sorted(
set(dpressure_s.index).intersection(set(act_dpressure.index))
)
if len(sensor_for_cos) == 0 and len(dpressure_s) == 0:
similarity_cos = 0
elif len(sensor_for_cos) == 0 and len(dpressure_s) > 0:
sensor_for_cos = list(dpressure_s.index)
none_flag = 0
s1 = np.dot(
np.transpose(dpressure_s.loc[sensor_for_cos]),
dpressure.loc[sensor_for_cos],
)
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
dpressure.loc[sensor_for_cos]
)
if s2 == 0:
s2 = s2 + 0.0001
similarity_cos = s1 / s2
else:
none_flag = 0
s1 = np.dot(
np.transpose(dpressure_s.loc[sensor_for_cos]),
dpressure.loc[sensor_for_cos],
)
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
dpressure.loc[sensor_for_cos]
)
if s2 == 0:
s2 = s2 + 0.0001
similarity_cos = s1 / s2
# DIS
important_sensor_new = sorted(
set(important_sensor).intersection(set(act_dpressure.index))
)
if len(important_sensor_new) == 0:
important_sensor_new = important_sensor
act_dpressure = pd.Series(dtype=object)
for i in range(len(leak_p_all.index)):
# if dpressure_std.iloc [i] > -200: # 0.001:
if if_gy == 1:
act_dpressure[leak_p_all.index[i]] = (
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
) / dpressure_std_all.iloc[i]
else:
act_dpressure[leak_p_all.index[i]] = (
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
)
# part_dpressure = act_dpressure.loc[important_sensor_new]
part_dpressure = (
dpressure.loc[important_sensor_new] - dpressure_s.loc[important_sensor_new]
)
similarity_pre_DIS = np.linalg.norm(part_dpressure) ## chang test
# part_dpressure = dpressure_s.loc[important_sensor]-dpressure.loc[important_sensor]
# similarity_pre_DIS = np.linalg.norm(part_dpressure)
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
similarity_dis = similarity_pre_DIS
elif similarity_mode == "OF":
# cos
similarity_cos = 0
none_flag = 0
# DIS
important_sensor_new = sorted(
set(important_sensor).intersection(set(act_dpressure.index))
)
if len(important_sensor_new) == 0:
important_sensor_new = important_sensor
act_dpressure = pd.Series(dtype=object)
for i in range(len(leak_p_all.index)):
# if dpressure_std.iloc [i] > -200: # 0.001:
if if_gy == 1:
act_dpressure[leak_p_all.index[i]] = (
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
) / dpressure_std_all.iloc[i]
else:
act_dpressure[leak_p_all.index[i]] = (
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
)
# part_dpressure = act_dpressure.loc[important_sensor_new]
part_dpressure = (
dpressure.loc[important_sensor_new] - dpressure_s.loc[important_sensor_new]
)
similarity_pre_DIS = np.linalg.norm(part_dpressure) ## chang test
# part_dpressure = dpressure_s.loc[important_sensor]-dpressure.loc[important_sensor]
# similarity_pre_DIS = np.linalg.norm(part_dpressure)
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
similarity_dis = similarity_pre_DIS
return similarity_cos, similarity_dis, none_flag
def adjust(
similarity_cos,
similarity_dis,
record_success_candidate,
record_success_no_candidate,
):
if len(record_success_no_candidate) > 0:
for each in record_success_no_candidate:
similarity_cos[each] = similarity_cos[record_success_candidate].min() * 0.9
similarity_dis[each] = similarity_dis[record_success_candidate].max() * 1.1
return similarity_cos, similarity_dis
def cal_sq_all_multi(
similarity_cos,
similarity_dis,
similarity_f,
candidate_pipe,
timestep_list_spc,
if_flow,
if_only_cos,
if_only_flow,
cos_h_input,
dis_h_input,
dis_f_h_input,
if_compalsive,
cos_sensor_num,
flow_sensor_num,
):
"""融合多种相似性并输出按时刻与候选管段组织的综合相似度。
该函数会根据模式开关(是否仅流量、是否仅 COS、是否包含流量)对
`similarity_cos`、`similarity_dis`、`similarity_f` 做标准化,并计算
权重 `sq_cos/sq_dis/sq_f` 后进行加权融合。
Args:
similarity_cos: 压力余弦相似性(DataFrame/Series,通常为时刻 x 候选管段)。
similarity_dis: 压力距离相似性(DataFrame/Series,通常为时刻 x 候选管段)。
similarity_f: 流量距离相似性(DataFrame/Series,通常为时刻 x 候选管段)。
candidate_pipe: 候选管段列表,用于输出列索引。
timestep_list_spc: 时刻列表,用于输出行索引。
if_flow: 是否启用流量相似性(1 启用,0 禁用)。
if_only_cos: 相似性模式标识(0: COS+DIS;1: COS;其他值按分支定义处理)。
if_only_flow: 是否仅使用流量相似性(1 是,0 否)。
cos_h_input: 外部给定的 COS 权重(强制权重模式下使用)。
dis_h_input: 外部给定的 DIS 权重(强制权重模式下使用)。
dis_f_h_input: 外部给定的流量权重(强制权重模式下使用)。
if_compalsive: 是否使用外部强制权重(1 使用输入权重,0 自动计算权重)。
cos_sensor_num: 压力传感器数量,用于权重调整。
flow_sensor_num: 流量传感器数量,用于权重调整。
Returns:
tuple[pd.DataFrame | pd.Series, float, float, float]:
- output_similarity_pd: 综合相似性结果。
- sq_cos: 最终 COS 权重。
- sq_dis: 最终 DIS 权重。
- sq_f: 最终流量权重。
"""
if if_only_flow == 1:
similarity_f, h_f = cal_sq_single_array(
similarity_f.values.reshape((-1, 1)), if_direct=2
)
sq_cos = 0
sq_dis = 0
sq_f = 1
similarity_all = similarity_f * sq_f
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
output_similarity_pd = pd.DataFrame(
output_similarity, index=timestep_list_spc, columns=candidate_pipe
)
else:
if if_only_cos == 0:
if if_flow == 1:
# standerdize
similarity_cos, h_cos = cal_sq_single_array(
similarity_cos.values.reshape((-1, 1)), if_direct=1
)
similarity_dis, h_dis = cal_sq_single_array(
similarity_dis.values.reshape((-1, 1)), if_direct=2
)
similarity_f, h_f = cal_sq_single_array(
similarity_f.values.reshape((-1, 1)), if_direct=2
)
if if_compalsive == 1:
sq_cos = cos_h_input
sq_dis = dis_h_input
sq_f = dis_f_h_input
else:
"""sq_cos = h_cos/(h_cos +h_dis +h_f )
sq_dis = h_dis/(h_cos +h_dis +h_f )
sq_f = h_f/(h_cos +h_dis +h_f )"""
sq_cos, sq_dis, sq_f = add_weight_for_SQ(
h_cos, h_dis, h_f, cos_sensor_num, flow_sensor_num
)
"""if cos_sensor_num == 2 and sq_cos>0.2:
sq_cos = 0.2
sq_dis = 0.8*h_dis / (h_dis + h_f)
sq_f = 0.8*h_f / (h_dis + h_f)
if cos_sensor_num == 1 and sq_dis > 0.3:
sq_cos = 0.1
sq_dis = 0.3
sq_f = 0.6"""
sq_cos, sq_dis, sq_f = adjust_ratio("CDF", sq_cos, sq_dis, sq_f)
if cos_sensor_num <= 1:
sq_cos = 0
# similarity
similarity_all = (
similarity_cos * sq_cos
+ similarity_dis * sq_dis
+ similarity_f * sq_f
)
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
output_similarity_pd = pd.DataFrame(
output_similarity, index=timestep_list_spc, columns=candidate_pipe
)
else:
# standerdize
similarity_cos, h_cos = cal_sq_single_array(
similarity_cos.values.reshape((-1, 1)), if_direct=1
)
similarity_dis, h_dis = cal_sq_single_array(
similarity_dis.values.reshape((-1, 1)), if_direct=2
)
if if_compalsive == 1:
sq_cos = cos_h_input
sq_dis = dis_h_input
else:
sq_cos = h_cos / (h_cos + h_dis)
sq_dis = h_dis / (h_cos + h_dis)
if cos_sensor_num == 2 and sq_cos > 0.5:
sq_cos = 0.5
sq_dis = 0.5
sq_cos, sq_dis, sq_f = adjust_ratio("CAD_new_gy", sq_cos, sq_dis, 0)
sq_f = 0
# similarity
similarity_all = similarity_cos * sq_cos + similarity_dis * sq_dis
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
output_similarity_pd = pd.DataFrame(
output_similarity, index=timestep_list_spc, columns=candidate_pipe
)
elif if_only_cos == 1:
if if_flow == 1:
# standerdize
similarity_cos, h_cos = cal_sq_single_array(
similarity_cos.values.reshape((-1, 1)), if_direct=1
)
similarity_f, h_f = cal_sq_single_array(
similarity_f.values.reshape((-1, 1)), if_direct=2
)
if if_compalsive == 1:
sq_cos = cos_h_input
sq_f = dis_f_h_input
else:
sq_cos = h_cos / (h_cos + h_f)
sq_f = h_f / (h_cos + h_f)
sq_cos, sq_dis, sq_f = adjust_ratio("CAF", sq_cos, 0, sq_f)
sq_dis = 0
# similarity
similarity_all = similarity_cos * sq_cos + similarity_f * sq_f
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
output_similarity_pd = pd.DataFrame(
output_similarity, index=timestep_list_spc, columns=candidate_pipe
)
else:
sq_cos = cos_h_input
sq_dis = dis_h_input
sq_f = dis_f_h_input
output_similarity_pd = similarity_cos
else:
sq_cos = cos_h_input
sq_dis = dis_h_input
sq_f = dis_f_h_input
output_similarity_pd = 1 / (similarity_dis + 1)
return output_similarity_pd, sq_cos, sq_dis, sq_f
def add_weight_for_SQ(h_cos, h_dis, h_f, sensor_cos_num, sensor_f_num):
h_f_new = h_f * sensor_f_num
if sensor_cos_num <= 1:
h_cos_new = 0
h_dis_new = h_dis * sensor_cos_num
else:
h_cos_new = h_cos * sensor_cos_num # / 2
h_dis_new = h_dis * sensor_cos_num # / 2
cos_sq = h_cos_new / (h_cos_new + h_dis_new + h_f_new)
dis_sq = h_dis_new / (h_cos_new + h_dis_new + h_f_new)
f_sq = h_f_new / (h_cos_new + h_dis_new + h_f_new)
if sensor_cos_num == 2 and cos_sq > 0.2:
cos_sq = 0.2
dis_sq = 0.8 * h_dis_new / (h_dis_new + h_f_new)
f_sq = 0.8 * h_f_new / (h_dis_new + h_f_new)
"""if sensor_cos_num == 1:
if dis_sq / f_sq > sensor_cos_num/sensor_f_num:
dis_sq = sensor_cos_num/sensor_f_num
f_sq=1-dis_sq"""
# if h_dis_new/h_f_new > sensor_cos_num/sensor_f_num
return cos_sq, dis_sq, f_sq
def cal_sq_single_array(similarity_pre, if_direct):
if similarity_pre.max() - similarity_pre.min() == 0:
similarity_pre = np.ones(similarity_pre.shape)
else:
if if_direct == 1:
similarity_pre = (
0.998
* (similarity_pre - similarity_pre.min())
/ (similarity_pre.max() - similarity_pre.min())
+ 0.002
)
else:
similarity_pre = (
0.998
* (similarity_pre.max() - similarity_pre)
/ (similarity_pre.max() - similarity_pre.min())
+ 0.002
)
# calculate pij
similarity_p = similarity_pre / similarity_pre.sum()
# cal xinxishang
similarity_lnp = np.zeros((len(similarity_pre), 1))
for j in range(len(similarity_p)):
similarity_lnp[j] = -similarity_p[j] * math.log(similarity_p[j], math.e)
h = 1 - 1 / math.log(len(similarity_pre), math.e) * similarity_lnp.sum()
return similarity_pre, h
def cal_similarity_all_multi_new_sq_improve_double_lzr(
candidate_pipe,
similarity_mode,
pressure_leak,
monitor_p,
predict_p,
normal_p,
if_flow,
if_only_cos,
if_only_flow,
flow_leak,
monitor_f,
predict_f,
normal_f,
timestep_list,
Top_sensor_num,
if_gy,
effective_sensor,
cos_h,
dis_h,
dis_f_h,
if_compalsive,
max_flow,
):
similarity = pd.Series(dtype=float, index=candidate_pipe)
similarity_detail: pd.DataFrame | None = None
important_p_sensor = cal_top_sensors(monitor_p, predict_p, Top_sensor_num)
# important_f_sensor, basic_f = cal_top_f_sensor(normal_f)
important_f_sensor = monitor_f.columns
if (
len(important_p_sensor) > 0 or len(important_f_sensor) > 0
): # if len(important_p_sensor) > 0
break_flag = 0
pressure_leak_new = pressure_leak.swaplevel()
# flow_leak_new = flow_leak.swaplevel()
if isinstance(flow_leak, pd.DataFrame) and len(flow_leak) > 0:
flow_leak_new = flow_leak.swaplevel()
else:
flow_leak_new = None
total_similarity_cos = pd.DataFrame(index=timestep_list, columns=candidate_pipe)
total_similarity_dis = pd.DataFrame(index=timestep_list, columns=candidate_pipe)
total_similarity_dis_f = pd.DataFrame(
index=timestep_list, columns=candidate_pipe
)
for timestep in timestep_list:
# cal p_cos, p_dis, f_dis
if if_only_flow != 1:
pressure_leak_temp = pressure_leak_new.loc[timestep].loc[
:, effective_sensor
]
monitor_p_temp = monitor_p.loc[timestep, effective_sensor]
predict_p_temp = predict_p.loc[timestep, effective_sensor]
normal_p_temp = normal_p.loc[timestep, effective_sensor]
(
total_similarity_cos.loc[timestep, :],
total_similarity_dis.loc[timestep, :],
) = cal_similarity_all_cos_dis(
candidate_pipe,
pressure_leak_temp,
similarity_mode,
monitor_p_temp,
predict_p_temp,
normal_p_temp,
pressure_leak_new.loc[timestep].loc[:, monitor_p.columns],
monitor_p.loc[timestep, :],
predict_p.loc[timestep, :],
normal_p.loc[timestep, :],
important_p_sensor,
if_gy,
cos_or_flow=1,
)
if if_flow == 1:
if len(timestep_list) == 1:
leak_f_temp = flow_leak_new.loc[timestep].loc[:, important_f_sensor]
monitor_f_temp = monitor_f.loc[timestep, important_f_sensor]
predict_f_temp = predict_f.loc[timestep, important_f_sensor]
normal_f_temp = normal_f.loc[timestep, important_f_sensor]
basic_normal_f_temp = abs(max_flow.loc[important_f_sensor])
leak_f_temp = leak_f_temp / basic_normal_f_temp
monitor_f_temp = monitor_f_temp / basic_normal_f_temp
predict_f_temp = predict_f_temp / basic_normal_f_temp
normal_f_temp = normal_f_temp / basic_normal_f_temp
else:
basic_f = abs(max_flow.loc[important_f_sensor])
leak_f_temp = (
flow_leak_new.loc[timestep].loc[:, important_f_sensor] / basic_f
)
monitor_f_temp = (
monitor_f.loc[timestep, important_f_sensor] / basic_f
)
predict_f_temp = (
predict_f.loc[timestep, important_f_sensor] / basic_f
)
normal_f_temp = normal_f.loc[timestep, important_f_sensor] / basic_f
_, total_similarity_dis_f.loc[timestep, :] = cal_similarity_all_cos_dis(
candidate_pipe,
leak_f_temp,
similarity_mode,
monitor_f_temp,
predict_f_temp,
normal_f_temp,
flow_leak_new.loc[timestep].loc[:, monitor_f.columns],
monitor_f.loc[timestep, :],
predict_f.loc[timestep, :],
normal_f.loc[timestep, :],
important_f_sensor,
if_gy,
cos_or_flow=2,
)
else:
total_similarity_dis_f = []
similarity_all, cos_h, dis_h, dis_f_h = cal_sq_all_multi(
total_similarity_cos,
total_similarity_dis,
total_similarity_dis_f,
candidate_pipe,
timestep_list,
if_flow,
if_only_cos,
if_only_flow,
cos_h,
dis_h,
dis_f_h,
if_compalsive,
len(important_p_sensor),
len(important_f_sensor),
)
if len(timestep_list) == 1:
similarity = similarity_all.iloc[0]
elif len(timestep_list) > 3:
for each_candidate in candidate_pipe:
similarity[each_candidate] = remove_3_sigma(
similarity_all.loc[:, each_candidate]
)
else:
for each_candidate in candidate_pipe:
similarity[each_candidate] = similarity_all.loc[
:, each_candidate
].mean()
similarity = similarity.sort_values(ascending=False, kind="mergesort")
detail_index = [str(pipe) for pipe in candidate_pipe]
similarity_detail = pd.DataFrame(index=detail_index)
similarity_detail.index.name = "pipe_id"
if isinstance(total_similarity_cos, pd.DataFrame) and len(total_similarity_cos) > 0:
pressure_cos_mean = (
total_similarity_cos.mean(axis=0)
.reindex(candidate_pipe)
.to_numpy(dtype=float)
)
else:
pressure_cos_mean = np.full(len(candidate_pipe), np.nan)
if isinstance(total_similarity_dis, pd.DataFrame) and len(total_similarity_dis) > 0:
pressure_dis_mean = (
total_similarity_dis.mean(axis=0)
.reindex(candidate_pipe)
.to_numpy(dtype=float)
)
else:
pressure_dis_mean = np.full(len(candidate_pipe), np.nan)
if isinstance(total_similarity_dis_f, pd.DataFrame) and len(total_similarity_dis_f) > 0:
flow_dis_mean = (
total_similarity_dis_f.mean(axis=0)
.reindex(candidate_pipe)
.to_numpy(dtype=float)
)
else:
flow_dis_mean = np.full(len(candidate_pipe), np.nan)
similarity_detail["pressure_cos_mean"] = pressure_cos_mean
similarity_detail["pressure_dis_mean"] = pressure_dis_mean
similarity_detail["flow_dis_mean"] = flow_dis_mean
similarity_detail["weight_cos"] = float(cos_h)
similarity_detail["weight_dis"] = float(dis_h)
similarity_detail["weight_flow"] = float(dis_f_h)
similarity_detail["final_similarity"] = (
similarity.reindex(candidate_pipe).to_numpy(dtype=float)
)
similarity_detail["similarity_rank"] = (
similarity_detail["final_similarity"].rank(method="dense", ascending=False)
).astype(int)
similarity_detail["pressure_sensor_count"] = int(len(important_p_sensor))
similarity_detail["flow_sensor_count"] = int(len(important_f_sensor))
similarity_detail = similarity_detail.sort_values(
by="final_similarity", ascending=False, kind="mergesort"
)
else:
break_flag = 1
similarity = 0
cos_h = 0
dis_h = 0
dis_f_h = 0
return similarity, cos_h, dis_h, dis_f_h, break_flag, similarity_detail
def cal_similarity_all_cos_dis(
candidate_pipe,
pressure_leak,
similarity_mode,
monitor_p,
predict_p,
normal_p,
pressure_leak_all,
monitor_p_all,
predict_p_all,
normal_p_all,
important_sensor,
if_gy,
cos_or_flow,
):
similarity_cos = pd.Series(dtype=float, index=candidate_pipe)
similarity_dis = pd.Series(dtype=float, index=candidate_pipe)
dpressure = normal_p - pressure_leak
# 无用 ----------------------------------------------
mean_dpressure = dpressure.mean()
monitor_new = pd.DataFrame(index=["monitor"], columns=monitor_p.index)
monitor_new.iloc[0] = monitor_p
add_m_leak_pressure = [pressure_leak, monitor_p]
add_m_leak_pressure = pd.concat(add_m_leak_pressure)
pressure_leak_std = add_m_leak_pressure.std(axis=0, ddof=1)
pressure_leak_std = pd.Series(pressure_leak_std, index=pressure_leak.columns)
add_m_leak_pressure_all = [pressure_leak_all, monitor_p_all]
add_m_leak_pressure_all = pd.concat(add_m_leak_pressure_all)
pressure_leak_std_all = add_m_leak_pressure_all.std(axis=0, ddof=1)
pressure_leak_std_all = pd.Series(
pressure_leak_std_all, index=pressure_leak.columns
)
# 无用 ----------------------------------------------
monitor_p_temp = monitor_p
predict_p_temp = predict_p
normal_p_temp = normal_p
monitor_p_temp_all = monitor_p_all
predict_p_temp_all = predict_p_all
normal_p_temp_all = normal_p_all
record_success_candidate = []
record_success_no_candidate = []
for i in range(len(candidate_pipe)):
leak_p = pressure_leak.iloc[i, :]
leak_p_all = pressure_leak_all.iloc[i, :]
similarity_cos.iloc[i], similarity_dis.iloc[i], none_flag = (
cal_similarity_simple_return_dd(
similarity_mode,
monitor_p_temp,
predict_p_temp,
normal_p_temp,
leak_p,
monitor_p_temp_all,
predict_p_temp_all,
normal_p_temp_all,
leak_p_all,
important_sensor,
mean_dpressure,
pressure_leak_std,
pressure_leak_std_all,
if_gy,
cos_or_flow,
)
)
if none_flag == 0:
record_success_candidate.append(candidate_pipe[i])
else:
record_success_no_candidate.append(candidate_pipe[i])
similarity_cos, similarity_dis = adjust(
similarity_cos,
similarity_dis,
record_success_candidate,
record_success_no_candidate,
)
return similarity_cos, similarity_dis
def cal_top_f_sensor(normal_f):
if type(normal_f) == pd.core.frame.DataFrame:
mean_f = normal_f.mean()
else:
mean_f = normal_f
output_sensor = []
output_normal_f = pd.Series(dtype=object)
for i in range(len(mean_f.index)):
if abs(mean_f.iloc[i]) > 0.01 / 3600:
output_sensor.append(mean_f.index[i])
output_normal_f[mean_f.index[i]] = mean_f.iloc[i]
return output_sensor, output_normal_f
def cal_top_sensors(monitor_p, predict_p, Top_sensor_num):
dpressure = abs(predict_p - monitor_p)
if type(dpressure) == pd.core.frame.DataFrame:
dpressure = dpressure.mean()
dpressure_rank = dpressure.sort_values(ascending=False, kind="mergesort")
return list(dpressure_rank.index[:Top_sensor_num])
def remove_3_sigma(similarity_t):
all_sample = len(similarity_t.index)
apart_sample = math.ceil(all_sample * 0.6)
similarity = similarity_t.astype("float")
mean_t = similarity.mean()
std_t = similarity.std()
new_similarity = similarity[
(similarity <= mean_t + 3 * std_t) & (similarity >= mean_t - 3 * std_t)
]
mean_t_new = new_similarity.mean()
return mean_t_new
def update_similarity(leak_candidate_center, similarity, leak_center_dict):
similarity_new = pd.Series(dtype=float)
for each_center in leak_candidate_center:
houxuan_center = leak_center_dict[each_center]
if len(houxuan_center) > 1:
temp_similarity = similarity[houxuan_center]
similarity_new[each_center] = temp_similarity.max()
else:
if type(similarity[each_center]) == pd.core.series.Series:
similarity_new[each_center] = similarity[each_center].mean()
else:
similarity_new[each_center] = similarity[each_center]
similarity_new = similarity_new.sort_values(ascending=False, kind="mergesort")
return similarity_new
def extra_judge(
similarity, min_candidates_to_prune: int = 200, std_relax_factor: float = 0.5
):
if len(similarity.index) == 0:
return 1.0, similarity
if len(similarity.index) < int(min_candidates_to_prune):
return 1.0, similarity
mean_similarity = float(similarity.mean())
std_similarity = float(similarity.std())
if not math.isfinite(std_similarity):
std_similarity = 0.0
threshold = mean_similarity - float(std_relax_factor) * std_similarity
out_put_similarity = similarity[similarity >= threshold - 1e-10]
if len(out_put_similarity.index) == 0:
out_put_similarity = similarity.iloc[:1]
cut_ratio = len(out_put_similarity.index) / len(similarity.index)
return cut_ratio, out_put_similarity
def adjust_ratio(similarity_mode, cos_h, dis_h, dis_f_h, low_limit=0.1):
if similarity_mode == "CAF":
if cos_h < low_limit:
cos_h = low_limit
dis_f_h = 1 - cos_h
elif dis_f_h < low_limit:
dis_f_h = low_limit
cos_h = 1 - dis_f_h
elif similarity_mode == "CAD_new_gy":
if dis_h < low_limit:
dis_h = low_limit
cos_h = 1 - dis_h
elif cos_h < low_limit:
cos_h = low_limit
dis_h = 1 - cos_h
elif similarity_mode == "CDF":
normal_index = [0, 1, 2]
h_list = [cos_h, dis_h, dis_f_h]
if cos_h < low_limit:
h_list[0] = low_limit
normal_index.remove(0)
if dis_h < low_limit:
h_list[1] = low_limit
normal_index.remove(1)
if dis_f_h < low_limit:
h_list[2] = low_limit
normal_index.remove(2)
if len(normal_index) == 1:
h_list[normal_index[0]] = h_list[normal_index[0]] - (sum(h_list) - 1)
elif len(normal_index) == 2:
sum_list = sum(h_list)
multiper = 1 - (sum_list - 1) / (
h_list[normal_index[0]] + h_list[normal_index[1]]
)
h_list[normal_index[0]] = h_list[normal_index[0]] * multiper
h_list[normal_index[1]] = h_list[normal_index[1]] * multiper
cos_h, dis_h, dis_f_h = h_list[0], h_list[1], h_list[2]
return cos_h, dis_h, dis_f_h
# 返回相似性计算的模式(不同权重),是否计算流量相似性,是否只计算cos相似性,是否只计算流量相似性。
def decode_mode(similarity_mode):
if similarity_mode == "COS":
if_flow = 0
if_only_cos = 1
if_only_flow = 0
elif similarity_mode == "CAD_new_gy":
if_flow = 0
if_only_cos = 0
if_only_flow = 0
elif similarity_mode == "CDF":
if_flow = 1
if_only_cos = 0
if_only_flow = 0
elif similarity_mode == "CAF":
if_flow = 1
if_only_cos = 1
if_only_flow = 0
elif similarity_mode == "DIS":
if_flow = 1
if_only_cos = 2
if_only_flow = 0
elif similarity_mode == "OF":
if_flow = 1
if_only_cos = 0
if_only_flow = 1
return if_flow, if_only_cos, if_only_flow
@@ -1,7 +1,7 @@
import os
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
from app.algorithms.cleaning import flow as _flow_module
from app.algorithms.cleaning import pressure as _pressure_module
############################################################
@@ -19,14 +19,15 @@ def flow_data_clean(input_csv_file: str) -> str:
"""
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
script_dir = os.path.dirname(os.path.abspath(__file__))
# 使用 algorithms 根目录保持与原 data_cleaning.py 一致的行为
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
input_csv_path = os.path.join(script_dir, input_csv_file)
# 检查文件是否存在
if not os.path.exists(input_csv_path):
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
out_xlsx_path = flow_data_clean.clean_flow_data_kf(input_csv_path)
# 调用 clean_flow_data_kf 函数进行数据清洗
out_xlsx_path = _flow_module.clean_flow_data_kf(input_csv_path)
print("清洗后的数据已保存到:", out_xlsx_path)
@@ -46,12 +47,13 @@ def pressure_data_clean(input_csv_file: str) -> str:
"""
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
script_dir = os.path.dirname(os.path.abspath(__file__))
# 使用 algorithms 根目录保持与原 data_cleaning.py 一致的行为
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
input_csv_path = os.path.join(script_dir, input_csv_file)
# 检查文件是否存在
if not os.path.exists(input_csv_path):
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
out_xlsx_path = pressure_data_clean.clean_pressure_data_km(input_csv_path)
# 调用 clean_pressure_data_km 函数进行数据清洗
out_xlsx_path = _pressure_module.clean_pressure_data_km(input_csv_path)
print("清洗后的数据已保存到:", out_xlsx_path)
@@ -1,83 +1,10 @@
# ...existing code...
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pykalman import KalmanFilter
import os
def fill_time_gaps(
data: pd.DataFrame,
time_col: str = "time",
freq: str = "1min",
short_gap_threshold: int = 10,
) -> pd.DataFrame:
"""
补齐缺失时间戳并填补数据缺口
Args:
data: 包含时间列的 DataFrame
time_col: 时间列名默认 'time'
freq: 重采样频率默认 '1min'
short_gap_threshold: 短缺口阈值分钟<=此值用线性插值>此值用前向填充
Returns:
补齐时间后的 DataFrame保留原时间列格式
"""
if time_col not in data.columns:
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
# 解析时间列并设为索引
data = data.copy()
data[time_col] = pd.to_datetime(data[time_col], utc=True)
data_indexed = data.set_index(time_col)
# 生成完整时间范围
full_range = pd.date_range(
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
)
# 重索引以补齐缺失时间点,同时保留原始时间戳
combined_index = data_indexed.index.union(full_range).sort_values().unique()
data_reindexed = data_indexed.reindex(combined_index)
# 按列处理缺口
for col in data_reindexed.columns:
# 识别缺失值位置
is_missing = data_reindexed[col].isna()
# 计算连续缺失的长度
missing_groups = (is_missing != is_missing.shift()).cumsum()
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
# 短缺口:时间插值
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
if short_gap_mask.any():
data_reindexed.loc[short_gap_mask, col] = (
data_reindexed[col]
.interpolate(method="time", limit_area="inside")
.loc[short_gap_mask]
)
# 长缺口:前向填充
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
if long_gap_mask.any():
data_reindexed.loc[long_gap_mask, col] = (
data_reindexed[col].ffill().loc[long_gap_mask]
)
# 重置索引并恢复时间列(保留原格式)
data_result = data_reindexed.reset_index()
data_result.rename(columns={"index": time_col}, inplace=True)
# 保留时区信息
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
# 修正时区格式(Python的%z输出为+0000,需转为+00:00
data_result[time_col] = data_result[time_col].str.replace(
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
)
return data_result
from app.algorithms._utils import fill_time_gaps
def clean_flow_data_kf(
+579
View File
@@ -0,0 +1,579 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
ID_LIKE_COLUMNS = {
"id",
"device_id",
"node_id",
"sensor_id",
"monitor_id",
"junction_id",
}
def _normalize_time_frame(data: pd.DataFrame) -> pd.DataFrame:
"""返回按时间排序的副本,并尽量将 time 列解析为时间类型。"""
data = data.copy()
if "time" in data.columns:
data["time"] = pd.to_datetime(data["time"], errors="coerce")
data = data.sort_values(["time"]).reset_index(drop=True)
return data
def _select_pressure_columns(data: pd.DataFrame) -> tuple[list[str], list[str]]:
"""区分需要清洗的数值列与需要原样保留的列。"""
value_cols: list[str] = []
keep_cols: list[str] = []
for col in data.columns:
if col == "time":
continue
col_key = col.lower()
if col_key in ID_LIKE_COLUMNS or col_key.endswith("_id"):
keep_cols.append(col)
continue
numeric = pd.to_numeric(data[col], errors="coerce")
if numeric.notna().sum() == 0 or numeric.nunique(dropna=True) <= 1:
keep_cols.append(col)
else:
value_cols.append(col)
return value_cols, keep_cols
def _robust_scale(values: pd.Series) -> float:
"""基于 MAD 计算稳健尺度。"""
series = pd.to_numeric(values, errors="coerce").dropna()
if series.empty:
return 1.0
median = series.median()
mad = (series - median).abs().median()
if pd.notna(mad) and mad > 0:
return float(1.4826 * mad)
iqr = series.quantile(0.75) - series.quantile(0.25)
if pd.notna(iqr) and iqr > 0:
return float(iqr / 1.349)
std = series.std()
if pd.notna(std) and std > 0:
return float(std)
return 1.0
def _shrink_toward_baseline(observed: float, baseline: float, scale: float) -> float:
"""把观测值向基线值收缩,scale 越小,修复越强。"""
if pd.isna(observed):
return baseline
if pd.isna(baseline):
return observed
diff = observed - baseline
weight = scale / (abs(diff) + scale)
return float(baseline + diff * weight)
def _infer_time_frequency(time_values: pd.Series | pd.Index) -> pd.Timedelta:
"""从时间序列中推断采样频率,失败时默认 15 分钟。"""
parsed = pd.to_datetime(pd.Series(time_values), errors="coerce").dropna().sort_values()
if len(parsed) < 2:
return pd.Timedelta(minutes=15)
diffs = parsed.diff().dropna()
diffs = diffs[diffs > pd.Timedelta(0)]
if diffs.empty:
return pd.Timedelta(minutes=15)
mode = diffs.mode()
return mode.iloc[0] if not mode.empty else diffs.median()
def _build_local_pressure_baseline(series: pd.Series) -> pd.Series:
"""基于局部插值与中值滤波构造平滑基线。"""
baseline = _safe_time_interpolate(series)
baseline = baseline.rolling(window=5, center=True, min_periods=1).median()
baseline = _safe_time_interpolate(baseline)
return baseline.ffill().bfill()
def _build_seasonal_pressure_baseline(series: pd.Series) -> pd.Series:
"""按一天内的同一时刻构造季节性基线,适合日周期压力数据。"""
if not isinstance(series.index, pd.DatetimeIndex):
return pd.Series(np.nan, index=series.index, dtype=float)
slot_labels = pd.Series(series.index.strftime("%H:%M:%S"), index=series.index)
return series.groupby(slot_labels).transform("median")
def _detect_pressure_spikes(series: pd.Series, local_baseline: pd.Series) -> pd.Series:
"""识别单点异常上升/下降尖峰,避免过度修正正常波动。"""
residual = series - local_baseline
neighbor_center = (series.shift(1) + series.shift(-1)) / 2
curvature = series - neighbor_center
residual_scale = max(_robust_scale(residual), 1e-6)
curvature_scale = max(_robust_scale(curvature), 1e-6)
direction_flip = ((series - series.shift(1)) * (series.shift(-1) - series) < 0).fillna(False)
return (
residual.abs() > 3.5 * residual_scale
) & (
curvature.abs() > 3.0 * curvature_scale
) & direction_flip
def _fill_pressure_gaps(
original: pd.Series,
repaired: pd.Series,
local_baseline: pd.Series,
seasonal_baseline: pd.Series,
) -> pd.Series:
"""短缺口用局部插值,长缺口优先使用同一时刻的季节性轨迹。"""
missing_mask = original.isna()
if not missing_mask.any():
return repaired
gap_groups = (missing_mask != missing_mask.shift(fill_value=False)).cumsum()
gap_lengths = missing_mask.groupby(gap_groups).transform("sum").where(missing_mask, 0)
filled = repaired.copy()
short_gap_mask = missing_mask & (gap_lengths < 4)
long_gap_mask = missing_mask & ~short_gap_mask
filled[short_gap_mask] = local_baseline[short_gap_mask]
long_gap_fill = seasonal_baseline.where(seasonal_baseline.notna(), local_baseline)
filled[long_gap_mask] = long_gap_fill[long_gap_mask]
return filled
def _clean_pressure_series(series: pd.Series) -> pd.Series:
"""清洗单个压力时间序列。"""
series = pd.to_numeric(series, errors="coerce").astype(float)
local_baseline = _build_local_pressure_baseline(series)
spike_mask = _detect_pressure_spikes(series, local_baseline)
repaired = series.copy()
repaired[spike_mask] = local_baseline[spike_mask]
seasonal_baseline = _build_seasonal_pressure_baseline(repaired)
repaired = _fill_pressure_gaps(series, repaired, local_baseline, seasonal_baseline)
if repaired.isna().any():
repaired = repaired.where(repaired.notna(), local_baseline)
return repaired.ffill().bfill()
def _format_time_column(data: pd.DataFrame) -> pd.DataFrame:
"""统一输出时间格式,方便下游直接按 ISO 字符串解析。"""
if "time" not in data.columns:
return data
formatted = data.copy()
time_values = pd.to_datetime(formatted["time"], errors="coerce")
if time_values.isna().all():
return formatted
if time_values.dt.tz is not None:
time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
time_strings = time_strings.str.replace(
r"([+-]\d{2})(\d{2})$",
r"\1:\2",
regex=True,
)
else:
time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S")
formatted["time"] = time_strings.where(time_values.notna(), formatted["time"])
return formatted
def _expand_snapshot_time_grid(data: pd.DataFrame, freq: pd.Timedelta) -> pd.DataFrame:
"""仅补齐时间轴,不提前填充值,避免长缺口丢失原始形状特征。"""
expanded = data.copy()
expanded["time"] = pd.to_datetime(expanded["time"], errors="coerce")
expanded = expanded.dropna(subset=["time"]).sort_values("time")
if expanded.empty:
return data
indexed = expanded.set_index("time")
full_index = pd.date_range(indexed.index.min(), indexed.index.max(), freq=freq)
indexed = indexed.reindex(full_index)
indexed.index.name = "time"
return indexed.reset_index()
def _safe_datetime_index(values: pd.Series | pd.Index | list[object]) -> pd.DatetimeIndex | None:
"""尽量把时间值标准化为 DatetimeIndex;失败则返回 None。"""
parsed = pd.to_datetime(values, errors="coerce")
try:
datetime_index = pd.DatetimeIndex(parsed)
except (TypeError, ValueError):
return None
if datetime_index.isna().all():
return None
return datetime_index
def _safe_time_interpolate(series: pd.Series) -> pd.Series:
"""仅在索引确实是 DatetimeIndex 时使用 time interpolation。"""
if isinstance(series.index, pd.DatetimeIndex):
return series.interpolate(method="time", limit_direction="both")
return series.interpolate(limit_direction="both")
def _detect_long_form_identifier(data: pd.DataFrame, value_cols: list[str], keep_cols: list[str]) -> str | None:
"""识别 time/id/value 长表结构。"""
if "time" not in data.columns or len(value_cols) != 1:
return None
identifier_candidates = [
col
for col in keep_cols
if col.lower() in ID_LIKE_COLUMNS or col.lower().endswith("_id")
]
if len(identifier_candidates) != 1:
return None
if not data["time"].duplicated().any():
return None
return identifier_candidates[0]
def _clean_long_form_pressure(
data: pd.DataFrame,
value_col: str,
identifier_col: str,
keep_cols: list[str],
fill_gaps: bool,
) -> pd.DataFrame:
"""按测点拆分 long-form 压力数据,再逐列清洗后恢复原结构。"""
data = _normalize_time_frame(data)
wide_df = (
data[[identifier_col, "time", value_col]]
.pivot(index="time", columns=identifier_col, values=value_col)
.reset_index()
)
sensor_cols = [col for col in wide_df.columns if col != "time"]
cleaned_wide = _clean_snapshot_pressure(wide_df, sensor_cols, keep_cols=[], fill_gaps=fill_gaps)
cleaned_long = cleaned_wide.melt(
id_vars="time",
var_name=identifier_col,
value_name=value_col,
)
passthrough_cols = [col for col in keep_cols if col != identifier_col]
if passthrough_cols:
metadata = data[[identifier_col] + passthrough_cols].drop_duplicates(subset=[identifier_col])
cleaned_long = cleaned_long.merge(metadata, on=identifier_col, how="left")
try:
cleaned_long[identifier_col] = cleaned_long[identifier_col].astype(data[identifier_col].dtype)
except (TypeError, ValueError):
pass
cleaned_long = cleaned_long.sort_values(["time", identifier_col]).reset_index(drop=True)
ordered_cols = ["time", identifier_col] + passthrough_cols + [value_col]
cleaned_long = cleaned_long[[col for col in ordered_cols if col in cleaned_long.columns]]
return cleaned_long
def _build_time_slot_frame(
data: pd.DataFrame, value_col: str, expected_slots: int
) -> pd.DataFrame:
"""把重复时间点整理成 time x slot 的矩阵。"""
grouped = data.groupby("time", sort=True)
times = list(grouped.groups.keys())
slot_frame = pd.DataFrame(index=pd.Index(times, name="time"), columns=range(expected_slots), dtype=float)
for time_value, group in grouped:
values = pd.to_numeric(group[value_col], errors="coerce").tolist()
for slot_idx, value in enumerate(values[:expected_slots]):
slot_frame.loc[time_value, slot_idx] = value
return slot_frame
def _slot_baseline(slot_frame: pd.DataFrame) -> pd.DataFrame:
"""对每个槽位做时间插值和平滑,得到基线轨迹。"""
baseline = pd.DataFrame(index=slot_frame.index, columns=slot_frame.columns, dtype=float)
for col in slot_frame.columns:
series = slot_frame[col].astype(float)
series = _safe_time_interpolate(series)
series = series.rolling(window=5, center=True, min_periods=1).median()
series = _safe_time_interpolate(series).ffill().bfill()
baseline[col] = series
return baseline
def _choose_insertion_position(
observed: list[float], baseline_row: pd.Series, expected_slots: int
) -> int:
"""为少一个观测值的时间组选择最合理的插入位置。"""
missing_count = expected_slots - len(observed)
if missing_count <= 0:
return 0
best_pos = 0
best_cost = float("inf")
for insert_pos in range(expected_slots):
cost = 0.0
obs_idx = 0
for slot_idx in range(expected_slots):
if slot_idx == insert_pos:
continue
obs_value = observed[obs_idx]
base_value = float(baseline_row.iloc[slot_idx])
if pd.notna(obs_value) and pd.notna(base_value):
cost += abs(obs_value - base_value)
obs_idx += 1
if cost < best_cost:
best_cost = cost
best_pos = insert_pos
return best_pos
def _clean_repeated_timestamp_pressure(
data: pd.DataFrame, value_col: str, keep_cols: list[str]
) -> pd.DataFrame:
"""针对同一时间点重复采样的压力数据进行修复。"""
data = _normalize_time_frame(data)
grouped_sizes = data.groupby("time").size()
if grouped_sizes.empty:
return data
expected_slots = int(grouped_sizes.mode().iloc[0]) if not grouped_sizes.mode().empty else int(grouped_sizes.max())
expected_slots = max(expected_slots, int(grouped_sizes.max()))
slot_frame = _build_time_slot_frame(data, value_col, expected_slots)
baseline_frame = _slot_baseline(slot_frame)
residuals = slot_frame - baseline_frame
slot_scales = {
col: max(_robust_scale(residuals[col]), 1e-6) for col in residuals.columns
}
cleaned_rows: list[dict[str, object]] = []
grouped = data.groupby("time", sort=True)
for time_value, group in grouped:
observed_values = pd.to_numeric(group[value_col], errors="coerce").tolist()
baseline_row = baseline_frame.loc[time_value]
insert_pos = _choose_insertion_position(observed_values, baseline_row, expected_slots)
cleaned_values: list[float] = []
obs_idx = 0
for slot_idx in range(expected_slots):
if slot_idx == insert_pos and len(observed_values) < expected_slots:
cleaned_values.append(float(baseline_row.iloc[slot_idx]))
continue
if obs_idx >= len(observed_values):
cleaned_values.append(float(baseline_row.iloc[slot_idx]))
continue
observed = observed_values[obs_idx]
baseline = float(baseline_row.iloc[slot_idx])
cleaned_values.append(
_shrink_toward_baseline(observed, baseline, slot_scales.get(slot_idx, 1.0))
)
obs_idx += 1
# 其余字段原样保留;常量列(如 id)直接复制第一条记录即可
template_row = group.iloc[0].to_dict()
for slot_idx, cleaned_value in enumerate(cleaned_values):
row = dict(template_row)
row["time"] = time_value
row[value_col] = cleaned_value
cleaned_rows.append(row)
cleaned_df = pd.DataFrame(cleaned_rows)
cleaned_df = cleaned_df.sort_values(["time"]).reset_index(drop=True)
ordered_cols = ["time"] + keep_cols + [value_col]
ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns]
remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols]
cleaned_df = cleaned_df[ordered_cols + remaining_cols]
return _format_time_column(cleaned_df)
def _clean_snapshot_pressure(
data: pd.DataFrame, value_cols: list[str], keep_cols: list[str], fill_gaps: bool
) -> pd.DataFrame:
"""针对单条时间序列或多列快照数据进行稳健修复。"""
data = _normalize_time_frame(data)
if fill_gaps and "time" in data.columns:
freq = _infer_time_frequency(data["time"])
data = _expand_snapshot_time_grid(data, freq)
data["time"] = pd.to_datetime(data["time"], errors="coerce")
data = data.sort_values(["time"]).reset_index(drop=True)
cleaned_df = data.copy()
time_index = (
_safe_datetime_index(cleaned_df["time"])
if "time" in cleaned_df.columns
else None
)
if time_index is None:
time_index = pd.RangeIndex(start=0, stop=len(cleaned_df))
for col in value_cols:
series = pd.Series(
pd.to_numeric(cleaned_df[col], errors="coerce").to_numpy(),
index=time_index,
dtype=float,
)
cleaned_df[col] = _clean_pressure_series(series).to_numpy()
ordered_cols = ["time"] + keep_cols + value_cols
ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns]
remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols]
cleaned_df = cleaned_df[ordered_cols + remaining_cols]
return _format_time_column(cleaned_df)
def clean_pressure_data_km(
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
) -> str:
"""
读取输入 CSV,基于时间结构进行稳健修复。输出为 <input_basename>_cleaned.xlsx(同目录)。
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'
返回输出文件的绝对路径。
Args:
input_csv_path: CSV 文件路径
show_plot: 是否显示可视化
fill_gaps: 是否先补齐时间缺口(默认 True)
"""
# 读取 CSV
input_csv_path = os.path.abspath(input_csv_path)
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
data = _normalize_time_frame(data)
value_cols, keep_cols = _select_pressure_columns(data)
has_repeated_time = "time" in data.columns and data["time"].duplicated().any()
identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols)
if identifier_col is not None:
data_repaired = _clean_long_form_pressure(
data,
value_cols[0],
identifier_col,
keep_cols,
fill_gaps,
)
elif has_repeated_time and len(value_cols) == 1:
data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols)
else:
data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps)
# 可选可视化(只展示首个数值列)
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
if show_plot and value_cols:
plot_col = value_cols[0]
if "time" in data_repaired.columns:
x = pd.to_datetime(data_repaired["time"], errors="coerce")
else:
x = np.arange(len(data_repaired))
plt.figure(figsize=(12, 6))
plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned")
plt.xlabel("时间" if "time" in data_repaired.columns else "序号")
plt.ylabel("压力监测值")
plt.title(f"{plot_col} 清洗结果")
plt.legend()
plt.show()
# 保存到 Excel:两个 sheet
input_dir = os.path.dirname(os.path.abspath(input_csv_path))
input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
output_filename = f"{input_base}_cleaned.xlsx"
output_path = os.path.join(input_dir, output_filename)
# 如果原始数据包含时间列,将其添加回结果
data_for_save = data.copy()
data_repaired_for_save = data_repaired.copy()
if os.path.exists(output_path):
os.remove(output_path) # 覆盖同名文件
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False)
data_repaired_for_save.to_excel(
writer, sheet_name="cleaned_pressusre_data", index=False
)
# 返回输出文件的绝对路径
return os.path.abspath(output_path)
def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
"""
接收一个 DataFrame 数据结构,使用时间感知的稳健修复方法清洗压力数据。
返回清洗后的 DataFrame。
Args:
data: 输入 DataFrame(可包含 time 列)
show_plot: 是否显示可视化
"""
# 使用传入的 DataFrame
data = data.copy()
data = _normalize_time_frame(data)
value_cols, keep_cols = _select_pressure_columns(data)
has_repeated_time = "time" in data.columns and data["time"].duplicated().any()
identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols)
if identifier_col is not None:
data_repaired = _clean_long_form_pressure(
data,
value_cols[0],
identifier_col,
keep_cols,
fill_gaps=True,
)
elif has_repeated_time and len(value_cols) == 1:
data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols)
else:
data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps=True)
if show_plot and value_cols:
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
plot_col = value_cols[0]
x = pd.to_datetime(data_repaired["time"], errors="coerce") if "time" in data_repaired.columns else np.arange(len(data_repaired))
plt.figure(figsize=(12, 6))
plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned")
plt.xlabel("时间" if "time" in data_repaired.columns else "序号")
plt.ylabel("压力监测值")
plt.title(f"{plot_col} 清洗结果")
plt.legend()
plt.show()
return data_repaired
# 测试
# if __name__ == "__main__":
# # 默认使用脚本目录下的 pressure_raw_data.csv
# script_dir = os.path.dirname(os.path.abspath(__file__))
# default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
# out_path = clean_pressure_data_km(default_csv, show_plot=False)
# print("保存路径:", out_path)
# 测试 clean_pressure_data_dict_km 函数
if __name__ == "__main__":
import random
# 读取 szh_pressure_scada.csv 文件
script_dir = os.path.dirname(os.path.abspath(__file__))
csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
# 排除 Time 列,随机选择 5 列
columns_to_exclude = ["Time"]
available_columns = [col for col in data.columns if col not in columns_to_exclude]
selected_columns = random.sample(available_columns, 5)
# 将选中的列转换为字典
data_dict = {col: data[col].tolist() for col in selected_columns}
print("选中的列:", selected_columns)
print("原始数据长度:", len(data_dict[selected_columns[0]]))
# 调用函数进行清洗
cleaned_dict = clean_pressure_data_df_km(data_dict, show_plot=True)
print("清洗后的字典键:", list(cleaned_dict.keys()))
print("清洗后的数据长度:", len(cleaned_dict[selected_columns[0]]))
print("测试完成:函数运行正常")
+3
View File
@@ -0,0 +1,3 @@
from app.algorithms.health.analyzer import PipelineHealthAnalyzer
__all__ = ["PipelineHealthAnalyzer"]
+3
View File
@@ -0,0 +1,3 @@
from app.algorithms.isolation.valve import valve_isolation_analysis
__all__ = ["valve_isolation_analysis"]
+3
View File
@@ -0,0 +1,3 @@
from app.algorithms.leakage.identifier import LeakageIdentifier
__all__ = ["LeakageIdentifier"]
+653
View File
@@ -0,0 +1,653 @@
import wntr
import numpy as np
import pandas as pd
import os
import time
import argparse
from multiprocessing import Pool, cpu_count
from typing import Any, List, Dict, Union
from pymoo.core.problem import Problem
from pymoo.core.callback import Callback
from pymoo.algorithms.soo.nonconvex.ga import GA
from pymoo.operators.crossover.sbx import SBX
from pymoo.operators.mutation.pm import PM
from pymoo.optimize import minimize as pymoo_minimize
from pymoo.termination.default import DefaultSingleObjectiveTermination
from app.algorithms._utils import _cleanup_temp_files
_worker_data: dict[str, Any] = {}
DEFAULT_N_WORKERS = max(1, min(cpu_count() - 1, 4))
def _worker_init(
inp_path: str,
sensor_nodes: list[str],
area_ids: list[str],
nodes_by_area: dict[str, list[str]],
obs_matrix: np.ndarray,
q_sum: float,
duration_sec: float,
timestep_sec: float,
) -> None:
global _worker_data
wn = wntr.network.WaterNetworkModel(inp_path)
wn.options.hydraulic.demand_model = "DD"
wn.options.time.duration = duration_sec
wn.options.time.hydraulic_timestep = timestep_sec
wn.options.time.pattern_timestep = timestep_sec
wn.options.time.report_timestep = timestep_sec
demand_objs_by_area = {}
allocatable_counts = {}
for area_id in area_ids:
demand_objs = []
for node_name in nodes_by_area.get(area_id, []):
if node_name not in wn.node_name_list:
continue
node = wn.get_node(node_name)
if (
hasattr(node, "demand_timeseries_list")
and len(node.demand_timeseries_list) > 0
):
demand_objs.append(node.demand_timeseries_list[0])
demand_objs_by_area[area_id] = demand_objs
allocatable_counts[area_id] = len(demand_objs)
_worker_data = {
"wn": wn,
"sensor_nodes": sensor_nodes,
"area_ids": area_ids,
"nodes_by_area": nodes_by_area,
"demand_objs_by_area": demand_objs_by_area,
"allocatable_counts": allocatable_counts,
"obs_matrix": obs_matrix,
"q_sum": q_sum,
}
def _worker_evaluate(raw_ratios: np.ndarray) -> float:
d = _worker_data
effective_ratio_map = LeakageIdentifier._effective_area_ratios(
raw_ratios,
d["area_ids"],
d["nodes_by_area"],
allocatable_counts=d["allocatable_counts"],
)
modifications = []
for area_id in d["area_ids"]:
ratio = effective_ratio_map.get(area_id, 0.0)
if ratio <= 0:
continue
demand_objs = d["demand_objs_by_area"].get(area_id, [])
if not demand_objs:
continue
per_node_leak = d["q_sum"] * ratio / len(demand_objs)
for demand_obj in demand_objs:
original_val = demand_obj.base_value
demand_obj.base_value = original_val + per_node_leak
modifications.append((demand_obj, original_val))
temp_dir = os.path.abspath(os.path.join("temp", "leakage"))
os.makedirs(temp_dir, exist_ok=True)
prefix = os.path.join(temp_dir, f"temp_{os.getpid()}")
try:
sim = wntr.sim.EpanetSimulator(d["wn"])
results = sim.run_sim(file_prefix=prefix)
sim_pressure = results.node["pressure"].loc[:, d["sensor_nodes"]]
n_steps = min(sim_pressure.shape[0], d["obs_matrix"].shape[0])
sim_vals = sim_pressure.values[:n_steps, :]
obs_vals = d["obs_matrix"][:n_steps, :]
diff = sim_vals - obs_vals
row_max = np.max(np.abs(diff), axis=1, keepdims=True)
row_max[row_max == 0] = 1.0
normalized_diff = diff / row_max
return float(np.linalg.norm(normalized_diff))
except Exception:
return 1e9
finally:
for demand_obj, original_val in modifications:
demand_obj.base_value = original_val
_cleanup_temp_files(prefix)
class LeakageIdentifier:
FLOW_UNIT_TO_M3S = {
"m3/s": 1.0,
"m3/h": 1.0 / 3600.0,
"L/s": 1.0 / 1000.0,
"L/min": 1.0 / 60000.0,
}
@classmethod
def _flow_to_m3s(cls, value: float, unit: str) -> float:
if unit not in cls.FLOW_UNIT_TO_M3S:
raise ValueError(f"不支持的流量单位: {unit}")
return float(value) * cls.FLOW_UNIT_TO_M3S[unit]
@classmethod
def _flow_from_m3s(cls, value_m3s: float, unit: str) -> float:
if unit not in cls.FLOW_UNIT_TO_M3S:
raise ValueError(f"不支持的流量单位: {unit}")
return float(value_m3s) / cls.FLOW_UNIT_TO_M3S[unit]
@staticmethod
def _effective_area_ratios(
raw_ratios: Union["np.ndarray", Dict[str, float]],
area_ids: List[str],
nodes_by_area: Dict[str, List[str]],
allocatable_counts: Union[Dict[str, int], None] = None,
) -> Dict[str, float]:
"""将输入比例转换为有效区域比例,确保有效区域比例和为 1。"""
area_count = len(area_ids)
if area_count == 0:
return {}
if isinstance(raw_ratios, dict):
ratios = np.array(
[float(raw_ratios.get(area_id, 0.0)) for area_id in area_ids],
dtype=float,
)
else:
arr = np.asarray(raw_ratios, dtype=float).reshape(-1)
ratios = np.zeros(area_count, dtype=float)
fill_len = min(area_count, arr.shape[0])
if fill_len > 0:
ratios[:fill_len] = arr[:fill_len]
# 仅保留非负比例,负值按 0 处理
ratios = np.clip(ratios, a_min=0.0, a_max=None)
# 仅在有效区域(存在可分配节点)内归一化
if allocatable_counts is not None:
valid_mask = np.array(
[int(allocatable_counts.get(area_id, 0)) > 0 for area_id in area_ids],
dtype=bool,
)
else:
valid_mask = np.array(
[len(nodes_by_area.get(area_id, [])) > 0 for area_id in area_ids],
dtype=bool,
)
if not np.any(valid_mask):
raise ValueError("没有可分配漏损的有效分区,无法满足漏损总量约束。")
effective = np.zeros(area_count, dtype=float)
valid_sum = float(np.sum(ratios[valid_mask]))
if valid_sum > 0:
effective[valid_mask] = ratios[valid_mask] / valid_sum
else:
# 若输入全为 0,则在有效区域内均分,保证总和仍为 1
valid_count = int(np.sum(valid_mask))
effective[valid_mask] = 1.0 / valid_count
return {area_id: float(effective[idx]) for idx, area_id in enumerate(area_ids)}
@staticmethod
def _normalize_area_map_df(df: pd.DataFrame) -> pd.DataFrame:
"""标准化区域映射列名为 ID 和 Area。"""
if "ID" in df.columns and "Area" in df.columns:
return df
if "ID" in df.columns and "now" in df.columns:
df = df.rename(columns={"now": "Area"})
return df
df = df.copy()
df.columns = ["ID", "Area"] + list(df.columns[2:])
return df
def __init__(
self,
inp_path: str,
sensor_nodes: List[str],
area_map: Union[str, Dict[str, str]],
start_time: float = 0,
duration: float = 24,
timestep: float = 5,
q_sum: float = 0.2,
):
"""
初始化漏损识别器。
参数:
inp_path: EPANET .inp 文件路径。
sensor_nodes: 用作压力传感器的节点 ID 列表。
area_map: 节点到区域的映射。可以是 CSV 文件路径(列:ID, Area),也可以是字典 {NodeID: AreaID}。
start_time: 模拟开始时间(小时)。
duration: 模拟持续时间(小时)。
timestep: 模拟时间步长(分钟)。
q_sum: 假设的总漏损流量 (m3/s)。
"""
self.inp_path = inp_path
self.sensor_nodes = sensor_nodes
self.start_time = start_time
self.duration = duration
self.timestep = timestep
self.q_sum = q_sum
# 加载管网模型(仅一次)
self.wn = wntr.network.WaterNetworkModel(self.inp_path)
# 优化 WNTR 设置以提高速度
self.wn.options.hydraulic.demand_model = "DD"
self.wn.options.time.duration = float(self.duration) * 3600
self.wn.options.time.hydraulic_timestep = float(self.timestep) * 60
self.wn.options.time.pattern_timestep = float(self.timestep) * 60
self.wn.options.time.report_timestep = float(self.timestep) * 60
# 加载区域映射
if isinstance(area_map, str):
self.area_map_df = self._load_area_map(area_map)
elif isinstance(area_map, dict):
self.area_map_df = self._normalize_area_map_df(
pd.DataFrame(list(area_map.items()), columns=["ID", "Area"])
)
else:
raise ValueError("area_map 必须是 CSV 文件路径或字典。")
self.area_ids = sorted(self.area_map_df["Area"].unique())
self.num_areas = len(self.area_ids)
# 按区域对节点进行预分类,以便更快查找
self.nodes_by_area = {
area: self.area_map_df[self.area_map_df["Area"] == area]["ID"].tolist()
for area in self.area_ids
}
def _load_area_map(self, path: str) -> pd.DataFrame:
"""加载并验证节点-区域映射文件。"""
df = pd.read_csv(path, dtype={"ID": str, "Area": str})
return self._normalize_area_map_df(df)
def run_identification(
self,
observed_pressure_data: Union[
str, pd.DataFrame, Dict[str, List[Any]], List[Dict[str, Any]]
],
output_dir: str = "Results",
pop_size: int = 50,
max_gen: int = 100,
output_flow_unit: str = "m3/s",
save_result: bool = True,
ftol: float = 1e-3,
ftol_period: int = 15,
n_workers: int = DEFAULT_N_WORKERS,
):
"""
运行遗传算法以识别漏损分布。
参数:
observed_pressure_data: 包含 SCADA 压力数据的 CSV 文件路径或 DataFrame/字典列表数据。
output_dir: 结果保存目录。
pop_size: GA 的种群大小。
max_gen: GA 的最大代数。
output_flow_unit: 输出漏损流量的单位。
save_result: 是否保存识别结果到本地 CSV。
ftol: 目标值收敛容差(连续 ftol_period 代改善 < ftol 则停止)。
ftol_period: 收敛检测的窗口代数。
n_workers: 并行工作进程数(1=串行,>1=并行评估)。
"""
if save_result:
os.makedirs(output_dir, exist_ok=True)
# 加载观测数据
if isinstance(observed_pressure_data, str):
obs_df = pd.read_csv(observed_pressure_data)
observed_name = os.path.basename(observed_pressure_data)
elif isinstance(observed_pressure_data, pd.DataFrame):
obs_df = observed_pressure_data.copy()
observed_name = "observed_pressure.csv"
else:
obs_df = pd.DataFrame(observed_pressure_data)
observed_name = "observed_pressure.csv"
# 准备 pymoo 问题实例
problem = LeakageProblem(
self.wn,
self.nodes_by_area,
self.area_ids,
self.sensor_nodes,
obs_df,
q_sum=self.q_sum,
n_workers=n_workers,
inp_path=os.path.abspath(self.inp_path),
)
# 配置 pymoo GA 算法
n_var = self.num_areas
algorithm = GA(
pop_size=pop_size,
crossover=SBX(prob=0.9, eta=15),
mutation=PM(prob=1.0 / max(1, n_var), eta=20),
eliminate_duplicates=False,
)
# 终止条件:收敛检测 + 最大代数
termination = DefaultSingleObjectiveTermination(
ftol=ftol,
period=ftol_period,
n_max_gen=max_gen,
)
# 回调:记录每代信息
callback = _ProgressCallback()
t0 = time.time()
try:
res = pymoo_minimize(
problem,
algorithm,
termination,
seed=42,
verbose=True,
callback=callback,
)
finally:
problem.close()
elapsed = time.time() - t0
# 提取最优解
best_ind = res.X # 最优个体(漏损比例原始值)
best_obj = float(res.F[0])
# 输出终止信息
print(f"\n优化完成。耗时: {elapsed:.1f}s")
print(f"总代数: {res.algorithm.n_gen}, 总评估次数: {problem._eval_count}")
print(f"最佳目标值: {best_obj:.6f}")
# 保存到文件
effective_ratio_map = self._effective_area_ratios(
best_ind,
self.area_ids,
self.nodes_by_area,
allocatable_counts=problem.allocatable_counts,
)
normalized_ratios = [
effective_ratio_map.get(area_id, 0.0) for area_id in self.area_ids
]
leakage_flow_m3s = [ratio * self.q_sum for ratio in normalized_ratios]
leakage_flow_output = [
self._flow_from_m3s(value_m3s, output_flow_unit)
for value_m3s in leakage_flow_m3s
]
result_df = pd.DataFrame(
{
"Area": self.area_ids,
"LeakageRatioRaw": best_ind,
"LeakageRatio": normalized_ratios,
"LeakageFlow_m3_per_s": leakage_flow_m3s,
f"LeakageFlow_{output_flow_unit.replace('/', '_per_')}": leakage_flow_output,
}
)
result_path = None
if save_result:
result_path = os.path.join(
output_dir, f"identified_leakage_{observed_name}"
)
result_df.to_csv(result_path, index=False)
print(f"结果已保存至 {result_path}")
result_df.attrs["result_path"] = result_path
return result_df
class _ProgressCallback(Callback):
"""每代回调:记录进度。"""
def __init__(self):
super().__init__()
self.gen_times = []
self._t_last = None
def notify(self, algorithm):
now = time.time()
if self._t_last is not None:
self.gen_times.append(now - self._t_last)
self._t_last = now
class LeakageProblem(Problem):
"""pymoo 批量评估问题定义。
搜索空间:n 维 [0, 1] 实数 -> 通过 _effective_area_ratios 归一化到单纯形。
目标:模拟压力与观测压力之间的归一化误差范数。
无显式约束(sum=1 由归一化自动保证)。
"""
def __init__(
self,
wn,
nodes_by_area,
area_ids,
sensor_nodes,
observed_data,
q_sum: float = 0.2,
n_workers: int = DEFAULT_N_WORKERS,
inp_path: str | None = None,
):
n_var = len(area_ids)
super().__init__(
n_var=n_var,
n_obj=1,
n_ieq_constr=0,
xl=np.zeros(n_var),
xu=np.ones(n_var),
)
self.wn = wn
self.nodes_by_area = nodes_by_area
self.area_ids = area_ids
self.sensor_nodes = sensor_nodes
self.q_sum = q_sum
self.n_workers = max(1, int(n_workers))
self.inp_path = inp_path
# 预处理观测数据以匹配模拟格式
try:
missing_sensors = [
s for s in self.sensor_nodes if s not in observed_data.columns
]
if not missing_sensors:
self.obs_matrix = observed_data[self.sensor_nodes].values
else:
self.obs_matrix = observed_data.values[:, : len(self.sensor_nodes)]
except Exception:
self.obs_matrix = observed_data.values[:, : len(self.sensor_nodes)]
duration_sec = float(self.wn.options.time.duration)
step_sec = float(self.wn.options.time.hydraulic_timestep)
if step_sec > 0:
max_steps = int(duration_sec / step_sec) + 1
self.obs_matrix = self.obs_matrix[:max_steps, :]
# 预先缓存每个区域的需水对象,减少每次适应度计算的节点查找
self.demand_objs_by_area = {}
for area_id in self.area_ids:
demand_objs = []
for node_name in self.nodes_by_area.get(area_id, []):
if node_name not in self.wn.node_name_list:
continue
node = self.wn.get_node(node_name)
if (
hasattr(node, "demand_timeseries_list")
and len(node.demand_timeseries_list) > 0
):
demand_objs.append(node.demand_timeseries_list[0])
self.demand_objs_by_area[area_id] = demand_objs
self.allocatable_counts = {
area_id: len(self.demand_objs_by_area.get(area_id, []))
for area_id in self.area_ids
}
if not any(count > 0 for count in self.allocatable_counts.values()):
raise ValueError("没有可分配漏损的有效分区,无法满足漏损总量约束。")
# 评估计数器(诊断用)
self._eval_count = 0
self._pool = None
if self.n_workers > 1:
if not self.inp_path:
raise ValueError("并行评估需要提供 inp_path。")
duration_sec = float(self.wn.options.time.duration)
timestep_sec = float(self.wn.options.time.hydraulic_timestep)
self._pool = Pool(
processes=self.n_workers,
initializer=_worker_init,
initargs=(
self.inp_path,
list(self.sensor_nodes),
list(self.area_ids),
{k: list(v) for k, v in self.nodes_by_area.items()},
self.obs_matrix.copy(),
self.q_sum,
duration_sec,
timestep_sec,
),
)
def _evaluate(self, X, out, *args, **kwargs):
"""批量评估种群。
X: 形状 (pop_size, n_var) 的决策变量矩阵。
"""
n_pop = X.shape[0]
self._eval_count += n_pop
if self._pool is not None:
results = self._pool.map(_worker_evaluate, [X[i] for i in range(n_pop)])
out["F"] = np.array(results, dtype=float).reshape(-1, 1)
return
F = np.zeros((n_pop, 1))
for i in range(n_pop):
F[i, 0] = self._evaluate_single(X[i])
out["F"] = F
def _evaluate_single(self, x):
"""评估单个个体,返回归一化误差范数。"""
leak_ratios = x
# 将漏损分布归一化
effective_ratio_map = LeakageIdentifier._effective_area_ratios(
leak_ratios,
self.area_ids,
self.nodes_by_area,
allocatable_counts=self.allocatable_counts,
)
# 跟踪修改以便稍后恢复
modifications = []
for j, area_id in enumerate(self.area_ids):
ratio = effective_ratio_map.get(area_id, 0.0)
if ratio <= 0:
continue
demand_objs = self.demand_objs_by_area.get(area_id, [])
if not demand_objs:
continue
per_node_leak = self.q_sum * ratio / len(demand_objs)
for demand_obj in demand_objs:
original_val = demand_obj.base_value
demand_obj.base_value = original_val + per_node_leak
modifications.append((demand_obj, original_val))
# 结果保存在根目录的temp/leakage文件夹中
temp_dir = os.path.abspath(os.path.join("temp", "leakage"))
os.makedirs(temp_dir, exist_ok=True)
prefix = os.path.join(temp_dir, f"temp_{os.getpid()}")
try:
sim = wntr.sim.EpanetSimulator(self.wn)
results = sim.run_sim(file_prefix=prefix)
sim_pressure = results.node["pressure"].loc[:, self.sensor_nodes]
n_steps = min(sim_pressure.shape[0], self.obs_matrix.shape[0])
sim_vals = sim_pressure.values[:n_steps, :]
obs_vals = self.obs_matrix[:n_steps, :]
diff = sim_vals - obs_vals
# 按行最大值归一化
row_max = np.max(np.abs(diff), axis=1, keepdims=True)
row_max[row_max == 0] = 1.0 # 防止除以零
normalized_diff = diff / row_max
# 目标:归一化差值矩阵的 2-范数
return float(np.linalg.norm(normalized_diff))
except Exception:
return 1e9
finally:
for demand_obj, original_val in modifications:
demand_obj.base_value = original_val
_cleanup_temp_files(prefix)
def close(self) -> None:
if self._pool is not None:
self._pool.close()
self._pool.join()
self._pool = None
def main() -> int:
parser = argparse.ArgumentParser(description="漏损区域识别")
parser.add_argument("--inp", required=True, help=".inp 文件路径")
parser.add_argument("--map", help="节点-区域映射 CSV 路径")
parser.add_argument("--scada", help="SCADA 压力 CSV 路径 (观测数据)")
parser.add_argument("--sensors", help="传感器节点 ID 列表 (逗号分隔)")
parser.add_argument("--output", default="Results", help="输出目录")
parser.add_argument("--pop_size", type=int, default=50, help="种群大小")
parser.add_argument("--max_gen", type=int, default=100, help="最大代数")
parser.add_argument("--duration", type=float, default=24, help="模拟时长(小时)")
parser.add_argument("--q_sum", type=float, default=0.241, help="总漏损流量")
parser.add_argument(
"--q_sum_unit",
default="m3/s",
choices=list(LeakageIdentifier.FLOW_UNIT_TO_M3S.keys()),
help="q_sum 输入单位(建议与现场习惯一致,内部统一换算为 m3/s)",
)
args = parser.parse_args()
if not args.map or not args.scada or not args.sensors:
parser.error("--map、--scada、--sensors 为必填")
q_sum_m3s = LeakageIdentifier._flow_to_m3s(args.q_sum, args.q_sum_unit)
sensors = [sensor.strip() for sensor in args.sensors.split(",") if sensor.strip()]
identifier = LeakageIdentifier(
args.inp, sensors, args.map, duration=args.duration, q_sum=q_sum_m3s
)
identifier.run_identification(
args.scada,
args.output,
pop_size=args.pop_size,
max_gen=args.max_gen,
output_flow_unit=args.q_sum_unit,
)
return 0
if __name__ == "__main__":
raise SystemExit(main())
@@ -1,8 +1,8 @@
import psycopg
import app.algorithms.api_ex.kmeans_sensor as kmeans_sensor
import app.algorithms.api_ex.sensitivity as sensitivity
from app.native.api.postgresql_info import get_pgconn_string
from app.algorithms.sensor import kmeans as kmeans_sensor
from app.algorithms.sensor import sensitivity
from app.core.config import get_pgconn_string
from app.services.tjnetwork import dump_inp
@@ -6,104 +6,66 @@ import sklearn.cluster
import os
class QD_KMeans(object):
def __init__(self, wn, num_monitors):
# self.inp = inp
self.cluster_num = num_monitors # 聚类中心个数,也即测压点个数
self.wn=wn
self.cluster_num = num_monitors # 聚类中心个数,也即测压点个数
self.wn = wn
self.monitor_nodes = []
self.coords = []
self.junction_nodes = {} # Added missing initialization
def get_junctions_coordinates(self):
for junction_name in self.wn.junction_name_list:
for junction_name in self.wn.junction_name_list:
junction = self.wn.get_node(junction_name)
self.junction_nodes[junction_name] = junction.coordinates
self.coords.append(junction.coordinates )
self.coords.append(junction.coordinates)
# print(f"Total junctions: {self.junction_coordinates}")
# print(f"Total junctions: {self.junction_coordinates}")
def select_monitoring_points(self):
if not self.coords: # Add check if coordinates are collected
self.get_junctions_coordinates()
coords = np.array(self.coords)
coords_normalized = (coords - coords.min(axis=0)) / (coords.max(axis=0) - coords.min(axis=0))
kmeans = sklearn.cluster.KMeans(n_clusters= self.cluster_num, random_state=42)
kmeans.fit(coords_normalized)
coords_normalized = (coords - coords.min(axis=0)) / (
coords.max(axis=0) - coords.min(axis=0)
)
kmeans = sklearn.cluster.KMeans(n_clusters=self.cluster_num, random_state=42)
kmeans.fit(coords_normalized)
for center in kmeans.cluster_centers_:
distances = np.sum((coords_normalized - center) ** 2, axis=1)
nearest_node = self.wn.junction_name_list[np.argmin(distances)]
self.monitor_nodes.append(nearest_node)
self.monitor_nodes.append(nearest_node)
return self.monitor_nodes
def visualize_network(self):
"""Visualize network with monitoring points"""
ax=wntr.graphics.plot_network(self.wn,
node_attribute=self.monitor_nodes,
node_size=30,
title='Optimal sensor')
plt.show()
ax = wntr.graphics.plot_network(
self.wn,
node_attribute=self.monitor_nodes,
node_size=30,
title="Optimal sensor",
)
plt.show()
def kmeans_sensor_placement(name: str, sensor_num: int, min_diameter: int) -> list:
inp_name = f'./db_inp/{name}.db.inp'
wn= wntr.network.WaterNetworkModel(inp_name)
wn_cluster=QD_KMeans(wn, sensor_num)
inp_name = f"./db_inp/{name}.db.inp"
wn = wntr.network.WaterNetworkModel(inp_name)
wn_cluster = QD_KMeans(wn, sensor_num)
# Select monitoring pointse
sensor_ids= wn_cluster.select_monitoring_points()
sensor_ids = wn_cluster.select_monitoring_points()
# wn_cluster.visualize_network()
return sensor_ids
if __name__ == "__main__":
#sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
sensorindex = kmeans_sensor_placement(name='szh', sensor_num=50, min_diameter=300)
# sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
sensorindex = kmeans_sensor_placement(name="szh", sensor_num=50, min_diameter=300)
print(sensorindex)
@@ -11,7 +11,6 @@ from sklearn.cluster import KMeans
from wntr.epanet.toolkit import EpanetException
from numpy.linalg import slogdet
import random
from app.services.tjnetwork import *
from matplotlib.lines import Line2D
from sklearn.cluster import SpectralClustering
import libpysal as ps
@@ -21,6 +20,7 @@ import geopandas as gpd
from sklearn.metrics import pairwise_distances
import app.services.project_info as project_info
# 2025/03/12
# Step1: 获取节点坐标
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
@@ -32,7 +32,7 @@ def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
# site: pandas.Series
# index:节点名称(wn.node_name_list
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z)
site = wn.query_node_attribute('coordinates')
site = wn.query_node_attribute("coordinates")
# Coor: pandas.Series
# index:与site相同(节点名称)。
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3])
@@ -44,9 +44,9 @@ def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
# xy: dict[str, list], x、y 坐标的字典
xy = {'x': x, 'y': y}
xy = {"x": x, "y": y}
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=["x", "y"])
return Coor_node
@@ -88,23 +88,23 @@ def skater_partition(G, n_clusters):
字典形式的聚类结果键为区域编号值为该区域内的节点列表
"""
# 1. 获取所有节点坐标,假设每个节点都有 'pos' 属性
pos = nx.get_node_attributes(G, 'pos')
pos = nx.get_node_attributes(G, "pos")
nodes = list(G.nodes())
# 构造坐标数组:每行为 [x, y]
coords = np.array([pos[node] for node in nodes])
# 2. 构造 GeoDataFrame:创建 DataFrame 并生成 geometry 列
df = pd.DataFrame(coords, columns=['x', 'y'], index=nodes)
df = pd.DataFrame(coords, columns=["x", "y"], index=nodes)
# 利用 shapely 的 Point 构造空间位置
df['geometry'] = df.apply(lambda row: Point(row['x'], row['y']), axis=1)
gdf = gpd.GeoDataFrame(df, geometry='geometry')
df["geometry"] = df.apply(lambda row: Point(row["x"], row["y"]), axis=1)
gdf = gpd.GeoDataFrame(df, geometry="geometry")
# 3. 构造空间权重矩阵,使用 4 近邻方法(k=4,可根据实际情况调整)
w = ps.weights.KNN.from_array(coords, k=4)
w.transform = 'R'
w.transform = "R"
# 4. 调用 SKATER:新版本 API 要求传入 gdf, w 以及 attrs_name(这里使用 'x' 和 'y' 作为属性)
skater = Skater(gdf, w, attrs_name=['x', 'y'], n_clusters=n_clusters)
skater = Skater(gdf, w, attrs_name=["x", "y"], n_clusters=n_clusters)
skater.solve()
# 5. 获取聚类标签,构造成字典格式
@@ -134,24 +134,24 @@ def spectral_partition(G, n_clusters):
键为聚类标签值为该聚类对应的节点列表
"""
# 1. 获取节点空间坐标,注意保证每个节点都有 'pos' 属性
pos_dict = nx.get_node_attributes(G, 'pos')
pos_dict = nx.get_node_attributes(G, "pos")
nodes = list(G.nodes())
coords = np.array([pos_dict[node] for node in nodes])
# 2. 计算节点之间的欧氏距离矩阵
D = pairwise_distances(coords, metric='euclidean')
D = pairwise_distances(coords, metric="euclidean")
# 3. 计算 sigma 值:这里取所有距离的均值,当然也可以根据实际情况调整
sigma = np.mean(D)
# 4. 构造相似度矩阵:使用高斯核函数
# A(i, j) = exp( -d(i,j)^2 / (2*sigma^2) )
A = np.exp(- (D ** 2) / (2 * sigma ** 2))
A = np.exp(-(D**2) / (2 * sigma**2))
# 5. 使用谱聚类进行图分区
clustering = SpectralClustering(n_clusters=n_clusters,
affinity='precomputed',
random_state=0)
clustering = SpectralClustering(
n_clusters=n_clusters, affinity="precomputed", random_state=0
)
labels = clustering.fit_predict(A)
# 6. 构造字典形式的分区结果
@@ -161,6 +161,7 @@ def spectral_partition(G, n_clusters):
return groups
# 2025/03/12
# Step3: wn_func类,水力计算
# wn_func 主要用于计算:
@@ -182,7 +183,7 @@ class wn_func(object):
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
self.wn = wn
# self.qpandas.DataFrame,管道流量,索引为时间步长,列为管道名称
self.q = self.results.link['flowrate']
self.q = self.results.link["flowrate"]
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
ReservoirIndex = wn.reservoir_name_list
Tankindex = wn.tank_name_list
@@ -192,7 +193,7 @@ class wn_func(object):
# self.nodes: list[str],所有节点的名称
self.nodes = wn.node_name_list
# self.coordinatespandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
self.coordinates = wn.query_node_attribute('coordinates')
self.coordinates = wn.query_node_attribute("coordinates")
# allpumps / allvalves: list[str],所有泵/阀门名称列表
allpumps = wn.pump_name_list
allvalves = wn.valve_name_list
@@ -223,17 +224,27 @@ class wn_func(object):
# 泵的起终点、tank、reservoir
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
self.delnodes = list(
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
set(ReservoirIndex).union(
Tankindex,
pumpstnode,
pumpednode,
valvestnode,
valveednode,
Reservoirednode,
)
)
# 泵、起终点为tank、reservoir的管道
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
self.delpipes = list(
set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe)
)
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
# self.L: list[float],所有管道的长度(以米为单位)
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
self.L = wn.query_link_attribute("length")[self.pipes].tolist()
self.n = len(self.nodes)
self.m = len(self.pipes)
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
self.unit_headloss = self.results.link["headloss"].iloc[0, :].tolist()
##
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
@@ -246,7 +257,9 @@ class wn_func(object):
end_node = wn.links[pipe].end_node.name
self.less_than_min_diameter_junction_list.extend([start_node, end_node])
# 去重
self.less_than_min_diameter_junction_list = list(set(self.less_than_min_diameter_junction_list))
self.less_than_min_diameter_junction_list = list(
set(self.less_than_min_diameter_junction_list)
)
# Step3.2: 计算水力距离
def CtoS(self):
@@ -267,7 +280,7 @@ class wn_func(object):
q = self.q
L = self.L
# H1pandas.DataFrame,水头数据,索引为时间步长,列为节点名
H1 = self.results.node['head'].T
H1 = self.results.node["head"].T
# hhlist[float],计算管道两端水头之差
hh = []
# 水头损失
@@ -281,8 +294,18 @@ class wn_func(object):
# headlosspandas.DataFrame,管道水头损失矩阵
headloss = pd.DataFrame(hh, index=pipes).T
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
hf = pd.DataFrame(
np.array([0] * (n**2)).reshape(n, n),
index=nodes,
columns=nodes,
dtype=float,
)
weightL = pd.DataFrame(
np.array([0] * (n**2)).reshape(n, n),
index=nodes,
columns=nodes,
dtype=float,
)
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
G = nx.DiGraph()
for i in range(0, m):
@@ -299,11 +322,16 @@ class wn_func(object):
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
hydraulicL = pd.DataFrame(
np.array([0] * (n**2)).reshape(n, n),
index=nodes,
columns=nodes,
dtype=float,
)
for a in nodes:
if a in G.nodes:
d = nx.shortest_path_length(G, source=a, weight='weight')
d = nx.shortest_path_length(G, source=a, weight="weight")
for b in list(d.keys()):
hydraulicL.loc[a, b] = d[b]
@@ -332,11 +360,17 @@ class wn_func(object):
for t in self.wn.tanks():
self.nonjunc_index.append(t[0])
# Connnumpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
Conn = np.mat(
np.zeros([n, m - p - v])
) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
# NConnnumpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
# pipeslist[str],去除泵和阀门的管道列表
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
pipes = [
pipe
for pipe in self.wn.pipes()
if pipe not in self.wn.pumps() and pipe not in self.wn.valves()
]
for pipe_name, pipe in pipes:
start = self.wn.node_name_list.index(pipe.start_node_name)
end = self.wn.node_name_list.index(pipe.end_node_name)
@@ -346,12 +380,21 @@ class wn_func(object):
NConn[start, end] = 1
NConn[end, start] = 1
self.A = Conn
link_name_list = [link for link in self.wn.link_name_list if
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
link_name_list = [
link
for link in self.wn.link_name_list
if link not in self.wn.pump_name_list
and link not in self.wn.valve_name_list
]
self.A2 = pd.DataFrame(
self.A, index=self.wn.node_name_list, columns=link_name_list
)
self.A2 = self.A2.drop(self.delnodes)
for pipe in self.delpipes:
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
if (
pipe not in self.wn.pump_name_list
and pipe not in self.wn.valve_name_list
):
self.A2 = self.A2.drop(columns=pipe)
self.junc_list = self.A2.index
self.A2 = np.mat(self.A2) # 节点管道关系
@@ -373,10 +416,10 @@ class wn_func(object):
except EpanetException:
pass
finally:
h = result.link['headloss'][self.pipes].values[0]
q = result.link['flowrate'][self.pipes].values[0]
l = self.wn.query_link_attribute('length')[self.pipes]
C = self.wn.query_link_attribute('roughness')[self.pipes]
h = result.link["headloss"][self.pipes].values[0]
q = result.link["flowrate"][self.pipes].values[0]
l = self.wn.query_link_attribute("length")[self.pipes]
C = self.wn.query_link_attribute("roughness")[self.pipes]
# headlossnumpy.ndarray,水头损失数组
headloss = np.array(h)
# 调整流量方向
@@ -394,7 +437,7 @@ class wn_func(object):
try:
det = np.linalg.det(X)
except RuntimeError as e:
sign, logdet = slogdet(X) # 防止溢出
sign, logdet = slogdet(X) # 防止溢出
det = sign * np.exp(logdet)
if det != 0:
J_H_Cw = X.I * A * S
@@ -431,7 +474,10 @@ class Sensorplacement(wn_func):
"""
Sensorplacement 类继承了 wn_func 并且用于计算和优化传感器布置的位置
"""
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int):
def __init__(
self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int
):
"""
:param wn: 由wntr生成的模型
@@ -443,7 +489,9 @@ class Sensorplacement(wn_func):
# 1.某个节点到所有节点的加权距离之和
# 2.某个节点到该组内所有节点的加权距离之和
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
def sensor(
self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]
):
"""
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
:param SS: 灵敏度矩阵每个节点的行和列代表不同节点矩阵元素表示节点间的灵敏度SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
@@ -528,7 +576,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
:return: 测压点节点ID
"""
# inp_file_realstr,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
inp_file_real = f'./db_inp/{name}.db.inp'
inp_file_real = f"./db_inp/{name}.db.inp"
# sensornumint,需要布置的传感器数量
# sensornum = sensor_num
# wn_realwntr.network.WaterNetworkModel,加载 EPANET 水力模型
@@ -539,7 +587,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
results_real = sim_real.run_sim()
# real_Clist[float],包含所有管道粗糙度的列表
real_C = wn_real.query_link_attribute('roughness').tolist()
real_C = wn_real.query_link_attribute("roughness").tolist()
# wn_fun1wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
wn_fun1 = wn_func(wn_real, min_diameter=min_diameter)
# nodeslist[str],管网的节点名称列表
@@ -599,7 +647,6 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
# 重新打开数据库
# if is_project_open(name=name):
# close_project(name=name)
@@ -638,7 +685,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
return sensorindex
if __name__ == '__main__':
if __name__ == "__main__":
sensorindex = get_ID(name=project_info.name, sensor_num=20, min_diameter=300)
print(sensorindex)
+19
View File
@@ -0,0 +1,19 @@
from app.algorithms.simulation.scenarios import (
convert_to_local_unit,
burst_analysis,
valve_close_analysis,
flushing_analysis,
contaminant_simulation,
age_analysis,
pressure_regulation,
)
__all__ = [
"convert_to_local_unit",
"burst_analysis",
"valve_close_analysis",
"flushing_analysis",
"contaminant_simulation",
"age_analysis",
"pressure_regulation",
]
@@ -1,6 +1,26 @@
import numpy as np
from app.services.tjnetwork import *
from api.s36_wda_cal import *
from app.services.tjnetwork import (
ChangeSet,
close_project,
copy_project,
delete_project,
get_pattern,
get_patterns,
get_pump,
get_reservoir,
get_status,
get_tank,
get_time,
have_project,
is_project_open,
open_project,
read_all,
run_project,
set_pattern,
set_status,
set_tank,
set_time,
)
# from get_real_status import *
from datetime import datetime,timedelta
from math import modf
@@ -5,14 +5,39 @@ from math import pi, sqrt
import pytz
import app.services.simulation as simulation
from app.algorithms.api_ex.run_simulation import (
from app.algorithms.simulation.runner import (
run_simulation_ex,
from_clock_to_seconds_2,
)
from app.native.api.project import copy_project
from app.services.epanet.epanet import Output
from app.services.scheme_management import store_scheme_info
from app.services.tjnetwork import *
from app.services.tjnetwork import (
ChangeSet,
OPTION_DEMAND_MODEL_PDA,
OPTION_QUALITY_CHEMICAL,
SOURCE_TYPE_SETPOINT,
add_pattern,
add_source,
close_project,
copy_project,
delete_project,
get_demand,
get_emitter,
get_node_links,
get_option,
get_pattern,
get_pipe,
get_source,
get_time,
have_project,
is_junction,
is_project_open,
open_project,
set_demand,
set_emitter,
set_option,
set_source,
set_time,
)
############################################################
@@ -637,24 +662,24 @@ def age_analysis(
new_name,
"realtime",
modify_pattern_start_time,
modify_total_duration,
duration=modify_total_duration,
downloading_prohibition=True,
)
simulation_result = json.loads(result)
output_data = simulation_result.get("output")
if not isinstance(output_data, dict):
raise RuntimeError("run_simulation_ex did not return JSON output content")
# step 2. restore the base model status
# execute_undo(name) #有疑惑
if is_project_open(new_name):
close_project(new_name)
delete_project(new_name)
output = Output("./temp/{}.db.out".format(new_name))
# element_name = output.element_name()
# node_name = element_name['nodes']
# link_name = element_name['links']
nodes_age = []
node_result = output.node_results()
node_result = output_data.get("node_results") or []
for node in node_result:
nodes_age.append(node["result"][-1]["quality"])
links_age = []
link_result = output.link_results()
link_result = output_data.get("link_results") or []
for link in link_result:
links_age.append(link["result"][-1]["quality"])
age_result = {"nodes": nodes_age, "links": links_age}
+37 -16
View File
@@ -3,28 +3,36 @@
仅管理员可访问
"""
from typing import List, Optional
from uuid import UUID
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Query, Path
from app.domain.schemas.audit import AuditLogResponse
from app.infra.repositories.audit_repository import AuditRepository
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from app.infra.db.metadata.database import get_metadata_session
from app.infra.db.metadb.database import get_metadata_session
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
async def get_audit_repository(
session: AsyncSession = Depends(get_metadata_session),
) -> AuditRepository:
"""获取审计日志仓储"""
return AuditRepository(session)
@router.get("/logs", response_model=List[AuditLogResponse])
@router.get(
"/logs",
summary="查询审计日志",
description="查询审计日志(仅管理员)",
response_model=List[AuditLogResponse],
)
async def get_audit_logs(
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
@@ -38,9 +46,9 @@ async def get_audit_logs(
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询审计日志(仅管理员)
支持按用户、时间、操作类型等条件过滤
查询审计日志
支持按用户、时间、操作类型等条件过滤,仅管理员可访问
"""
logs = await audit_repo.get_logs(
user_id=user_id,
@@ -50,11 +58,16 @@ async def get_audit_logs(
start_time=start_time,
end_time=end_time,
skip=skip,
limit=limit
limit=limit,
)
return logs
@router.get("/logs/count")
@router.get(
"/logs/count",
summary="获取审计日志总数",
description="获取审计日志总数(仅管理员)",
)
async def get_audit_logs_count(
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
@@ -66,7 +79,9 @@ async def get_audit_logs_count(
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> dict:
"""
获取审计日志总数(仅管理员)
获取审计日志总数
获取符合条件的审计日志的总数,仅管理员可访问
"""
count = await audit_repo.get_log_count(
user_id=user_id,
@@ -74,23 +89,29 @@ async def get_audit_logs_count(
action=action,
resource_type=resource_type,
start_time=start_time,
end_time=end_time
end_time=end_time,
)
return {"count": count}
@router.get("/logs/my", response_model=List[AuditLogResponse])
@router.get(
"/logs/my",
summary="查询我的审计日志",
description="查询当前用户的审计日志",
response_model=List[AuditLogResponse],
)
async def get_my_audit_logs(
action: Optional[str] = Query(None, description="按操作类型过滤"),
start_time: Optional[datetime] = Query(None, description="开始时间"),
end_time: Optional[datetime] = Query(None, description="结束时间"),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
current_user=Depends(get_current_metadata_user),
audit_repo: AuditRepository = Depends(get_audit_repository),
) -> List[AuditLogResponse]:
"""
查询当前用户的审计日志
普通用户只能查看自己的操作记录
"""
logs = await audit_repo.get_logs(
@@ -99,6 +120,6 @@ async def get_my_audit_logs(
start_time=start_time,
end_time=end_time,
skip=skip,
limit=limit
limit=limit,
)
return logs
+48 -44
View File
@@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from app.core.config import settings
from app.core.security import create_access_token, create_refresh_token, verify_password
from app.domain.schemas.user import UserCreate, UserResponse, UserLogin, Token
from app.infra.repositories.user_repository import UserRepository
from app.infra.db.metadb.repositories.user_repository import UserRepository
from app.auth.dependencies import get_user_repository, get_current_active_user
from app.domain.schemas.user import UserInDB
import logging
@@ -14,53 +14,55 @@ logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
async def register(
user_data: UserCreate,
user_repo: UserRepository = Depends(get_user_repository)
user_data: UserCreate, user_repo: UserRepository = Depends(get_user_repository)
) -> UserResponse:
"""
用户注册
创建新用户账号
"""
# 检查用户名和邮箱是否已存在
if await user_repo.user_exists(username=user_data.username):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Username already registered"
detail="Username already registered",
)
if await user_repo.user_exists(email=user_data.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered"
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
)
# 创建用户
try:
user = await user_repo.create_user(user_data)
if not user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create user"
detail="Failed to create user",
)
return UserResponse.model_validate(user)
except Exception as e:
logger.error(f"Error during user registration: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Registration failed"
detail="Registration failed",
)
@router.post("/login", response_model=Token)
async def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> Token:
"""
用户登录(OAuth2 标准格式)
返回 JWT Access Token 和 Refresh Token
"""
# 验证用户(支持用户名或邮箱登录)
@@ -68,119 +70,121 @@ async def login(
if not user:
# 尝试用邮箱登录
user = await user_repo.get_user_by_email(form_data.username)
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user account"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account"
)
# 生成 Token
access_token = create_access_token(subject=user.username)
refresh_token = create_refresh_token(subject=user.username)
return Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
@router.post("/login/simple", response_model=Token)
async def login_simple(
username: str,
password: str,
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> Token:
"""
简化版登录接口(保持向后兼容)
直接使用 username 和 password 参数
"""
# 验证用户
user = await user_repo.get_user_by_username(username)
if not user:
user = await user_repo.get_user_by_email(username)
if not user or not verify_password(password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password"
detail="Incorrect username or password",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Inactive user account"
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account"
)
# 生成 Token
access_token = create_access_token(subject=user.username)
refresh_token = create_refresh_token(subject=user.username)
return Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: UserInDB = Depends(get_current_active_user)
current_user: UserInDB = Depends(get_current_active_user),
) -> UserResponse:
"""
获取当前登录用户信息
"""
return UserResponse.model_validate(current_user)
@router.post("/refresh", response_model=Token)
async def refresh_token(
refresh_token: str,
user_repo: UserRepository = Depends(get_user_repository)
refresh_token: str, user_repo: UserRepository = Depends(get_user_repository)
) -> Token:
"""
刷新 Access Token
使用 Refresh Token 获取新的 Access Token
"""
from jose import jwt, JWTError
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate refresh token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
payload = jwt.decode(
refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
token_type: str = payload.get("type")
if username is None or token_type != "refresh":
raise credentials_exception
except JWTError:
raise credentials_exception
# 验证用户仍然存在且激活
user = await user_repo.get_user_by_username(username)
if not user or not user.is_active:
raise credentials_exception
# 生成新的 Access Token
new_access_token = create_access_token(subject=user.username)
return Token(
access_token=new_access_token,
refresh_token=refresh_token, # 保持原 refresh token
token_type="bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
+131
View File
@@ -0,0 +1,131 @@
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from pydantic import BaseModel, Field
from app.auth.keycloak_dependencies import get_current_keycloak_username
from app.services.burst_detection import (
get_burst_detection_scheme_detail,
list_burst_detection_schemes,
run_burst_detection,
)
router = APIRouter()
class BurstDetectionRequest(BaseModel):
"""爆管检测请求模型"""
network: str = Field(..., description="管网名称(或数据库名称)")
observed_pressure_data: (
dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] | None
) = Field(
default=None,
description=(
"压力观测数据。支持列式字典 {sensor_id: [values,...]}、"
"逐时刻对象数组 [{sensor_id: value,...}, ...]、"
"或二维数组 [[t1_s1, t1_s2], [t2_s1, t2_s2], ...]。"
),
)
points_per_day: int = Field(1440, description="每天的数据点数")
mu: int = Field(100, description="异常值检测的参数")
iforest_params: dict[str, Any] | None = Field(None, description="隔离森林算法参数")
scada_start: datetime | None = Field(None, description="SCADA数据起始时间")
scada_end: datetime | None = Field(None, description="SCADA数据结束时间")
sensor_nodes: list[str] | None = Field(None, description="传感器节点列表")
scheme_name: str | None = Field(None, description="方案名称")
data_source: str = Field("monitoring", description="数据来源:monitoring(监测)或simulation(模拟)")
simulation_scheme_name: str | None = Field(None, description="模拟方案名称")
simulation_scheme_type: str | None = Field(None, description="模拟方案类型")
@router.post(
"/detect/",
summary="执行爆管检测",
description="基于压力观测数据和其他参数执行爆管检测分析"
)
async def detect_burst(
data: BurstDetectionRequest = Body(..., description="爆管检测请求数据"),
username: str = Depends(get_current_keycloak_username),
) -> dict[str, Any]:
"""
执行爆管检测分析。
使用异常检测算法(隔离森林)识别压力时间序列中的异常,
将其作为潜在的爆管事件。
Args:
data: 包含管网名称(或数据库名称)、压力数据及相关参数的请求体
username: 当前认证用户名
Returns:
包含检测结果的字典
Raises:
HTTPException: 当处理过程中发生错误时
"""
try:
return run_burst_detection(**data.model_dump(), username=username)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/",
summary="查询爆管检测方案列表",
description="获取指定网络的所有爆管检测方案"
)
async def query_burst_detection_schemes(
network: str = Query(..., description="管网名称(或数据库名称)"),
query_date: datetime | None = Query(None, description="查询日期(可选)"),
) -> list[dict[str, Any]]:
"""
获取爆管检测方案列表。
查询指定网络的所有已配置的爆管检测方案,
可按日期进行筛选。
Args:
network: 管网名称(或数据库名称)
query_date: 查询日期(可选)
Returns:
爆管检测方案列表
Raises:
HTTPException: 当查询失败时
"""
try:
return list_burst_detection_schemes(network=network, query_date=query_date)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/{scheme_name}",
summary="获取爆管检测方案详情",
description="获取指定爆管检测方案的详细信息"
)
async def query_burst_detection_scheme_detail(
network: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Path(..., description="爆管检测方案名称"),
) -> dict[str, Any]:
"""
获取爆管检测方案详情。
查询指定爆管检测方案的完整配置和参数信息。
Args:
network: 管网名称(或数据库名称)
scheme_name: 爆管检测方案名称
Returns:
包含方案详情的字典
Raises:
HTTPException: 当查询失败时
"""
try:
return get_burst_detection_scheme_detail(network=network, scheme_name=scheme_name)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
+129
View File
@@ -0,0 +1,129 @@
from typing import Any
from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from pydantic import BaseModel, Field
from app.auth.keycloak_dependencies import get_current_keycloak_username
from app.services.burst_location import (
get_burst_location_scheme_detail,
list_burst_location_schemes,
run_burst_location_by_network,
)
router = APIRouter()
class BurstLocationRequest(BaseModel):
"""爆管定位请求模型"""
network: str = Field(..., description="管网名称(或数据库名称)")
data_source: Literal["monitoring", "simulation"] = Field("monitoring", description="数据来源:monitoring(监测)或simulation(模拟)")
pressure_scada_ids: list[str] | None = Field(None, description="压力SCADA传感器ID列表")
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="爆管时的压力数据")
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="正常时的压力数据")
burst_leakage: float = Field(..., description="爆管时的漏水量")
flow_scada_ids: list[str] | None = Field(None, description="流量SCADA传感器ID列表")
burst_flow: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="爆管时的流量数据")
normal_flow: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="正常时的流量数据")
min_dpressure: float = Field(2.0, description="最小压力差(bar")
basic_pressure: float = Field(10.0, description="基准压力(bar")
scada_burst_start: datetime | None = Field(None, description="SCADA爆管开始时间")
scada_burst_end: datetime | None = Field(None, description="SCADA爆管结束时间")
use_scada_flow: bool = Field(False, description="是否使用SCADA流量数据")
scheme_name: str | None = Field(None, description="方案名称")
simulation_scheme_name: str | None = Field(None, description="模拟方案名称")
simulation_scheme_type: str | None = Field(None, description="模拟方案类型")
@router.post(
"/locate/",
summary="执行爆管定位",
description="基于压力和流量数据定位管网中的爆管位置"
)
async def locate_burst(
data: BurstLocationRequest = Body(..., description="爆管定位请求数据"),
username: str = Depends(get_current_keycloak_username),
) -> dict[str, Any]:
"""
执行爆管定位分析。
使用压力和流量SCADA数据,通过对比爆管和正常状态下的数据差异,
定位管网中的爆管位置。
Args:
data: 包含管网名称(或数据库名称)、压力、流量数据及相关参数的请求体
username: 当前认证用户名
Returns:
包含定位结果的字典
Raises:
HTTPException: 当数据类型或值不正确时
"""
try:
return run_burst_location_by_network(**data.model_dump(), username=username)
except (TypeError, ValueError) as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/",
summary="查询爆管定位方案列表",
description="获取指定网络的所有爆管定位方案"
)
async def query_burst_schemes(
network: str = Query(..., description="管网名称(或数据库名称)"),
query_date: datetime | None = Query(None, description="查询日期(可选)")
) -> list[dict[str, Any]]:
"""
获取爆管定位方案列表。
查询指定网络的所有已配置的爆管定位方案,
可按日期进行筛选。
Args:
network: 管网名称(或数据库名称)
query_date: 查询日期(可选)
Returns:
爆管定位方案列表
Raises:
HTTPException: 当查询失败时
"""
try:
return list_burst_location_schemes(network=network, query_date=query_date)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/{scheme_name}",
summary="获取爆管定位方案详情",
description="获取指定爆管定位方案的详细信息"
)
async def query_burst_scheme_detail(
network: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Path(..., description="爆管定位方案名称")
) -> dict[str, Any]:
"""
获取爆管定位方案详情。
查询指定爆管定位方案的完整配置和参数信息。
Args:
network: 管网名称(或数据库名称)
scheme_name: 爆管定位方案名称
Returns:
包含方案详情的字典
Raises:
HTTPException: 当查询失败时
"""
try:
return get_burst_location_scheme_detail(network=network, scheme_name=scheme_name)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
+27 -7
View File
@@ -1,16 +1,26 @@
from fastapi import APIRouter
from fastapi import APIRouter, Query
from app.infra.cache.redis_client import redis_client
router = APIRouter()
@router.post("/clearrediskey/")
async def fastapi_clear_redis_key(key: str):
@router.post("/clearrediskey/", summary="清除单个缓存键", description="根据键名清除单个Redis缓存")
async def fastapi_clear_redis_key(key: str = Query(..., description="缓存键名")):
"""
清除单个缓存键
根据指定的键名删除Redis中对应的缓存
"""
redis_client.delete(key)
return True
@router.post("/clearrediskeys/")
async def fastapi_clear_redis_keys(keys: str):
@router.post("/clearrediskeys/", summary="清除匹配的缓存键", description="根据模式清除匹配的Redis缓存键")
async def fastapi_clear_redis_keys(keys: str = Query(..., description="缓存键模式(支持通配符)")):
"""
清除匹配的缓存键
根据指定的模式删除Redis中所有匹配的缓存键
"""
# delete keys contains the key
matched_keys = redis_client.keys(f"*{keys}*")
if matched_keys:
@@ -19,14 +29,24 @@ async def fastapi_clear_redis_keys(keys: str):
return True
@router.post("/clearallredis/")
@router.post("/clearallredis/", summary="清除所有缓存", description="清空整个Redis数据库的所有缓存")
async def fastapi_clear_all_redis():
"""
清除所有缓存
清空Redis数据库中的所有缓存键值对
"""
redis_client.flushdb()
return True
@router.get("/queryredis/")
@router.get("/queryredis/", summary="查询缓存键列表", description="获取Redis中所有的缓存键")
async def fastapi_query_redis():
"""
查询缓存键列表
获取Redis数据库中所有的缓存键列表
"""
# Helper to decode bytes to str for JSON response if needed,
# but original just returned keys (which might be bytes in redis-py unless decode_responses=True)
# create_redis_client usually sets decode_responses=False by default.
+53 -14
View File
@@ -1,31 +1,70 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
get_control,
get_control_schema,
get_rule,
get_rule_schema,
set_control,
set_rule,
)
router = APIRouter()
@router.get("/getcontrolschema/")
async def fastapi_get_control_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getcontrolschema/", summary="获取控制架构", description="获取网络中控制对象的架构定义")
async def fastapi_get_control_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取控制架构。
返回指定网络中控制对象的属性架构定义。
"""
return get_control_schema(network)
@router.get("/getcontrolproperties/")
async def fastapi_get_control_properties(network: str) -> dict[str, Any]:
@router.get("/getcontrolproperties/", summary="获取控制属性", description="获取指定网络中的控制属性信息")
async def fastapi_get_control_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取控制属性。
返回指定网络中的控制对象属性信息。
"""
return get_control(network)
@router.post("/setcontrolproperties/", response_model=None)
async def fastapi_set_control_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setcontrolproperties/", response_model=None, summary="设置控制属性", description="更新指定网络中的控制属性")
async def fastapi_set_control_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置控制属性。
更新指定网络中的控制属性值。
"""
props = await req.json()
return set_control(network, ChangeSet(props))
@router.get("/getruleschema/")
async def fastapi_get_rule_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getruleschema/", summary="获取规则架构", description="获取网络中规则对象的架构定义")
async def fastapi_get_rule_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取规则架构。
返回指定网络中规则对象的属性架构定义。
"""
return get_rule_schema(network)
@router.get("/getruleproperties/")
async def fastapi_get_rule_properties(network: str) -> dict[str, Any]:
@router.get("/getruleproperties/", summary="获取规则属性", description="获取指定网络中的规则属性信息")
async def fastapi_get_rule_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取规则属性。
返回指定网络中的规则对象属性信息。
"""
return get_rule(network)
@router.post("/setruleproperties/", response_model=None)
async def fastapi_set_rule_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setruleproperties/", response_model=None, summary="设置规则属性", description="更新指定网络中的规则属性")
async def fastapi_set_rule_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置规则属性。
更新指定网络中的规则属性值。
"""
props = await req.json()
return set_rule(network, ChangeSet(props))
+69 -16
View File
@@ -1,42 +1,95 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_curve,
delete_curve,
get_curve,
get_curve_schema,
get_curves,
is_curve,
set_curve,
)
router = APIRouter()
@router.get("/getcurveschema")
async def fastapi_get_curve_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getcurveschema", summary="获取曲线架构", description="获取网络中曲线对象的架构定义")
async def fastapi_get_curve_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取曲线架构。
返回指定网络中曲线对象的属性架构定义。
"""
return get_curve_schema(network)
@router.post("/addcurve/", response_model=None)
async def fastapi_add_curve(network: str, curve: str, req: Request) -> ChangeSet:
@router.post("/addcurve/", response_model=None, summary="添加曲线", description="在网络中添加一条新的曲线")
async def fastapi_add_curve(
network: str = Query(..., description="管网名称(或数据库名称)"),
curve: str = Query(..., description="曲线ID"),
req: Request = None
) -> ChangeSet:
"""添加曲线。
在指定网络中创建一条新的曲线,并设置其初始属性。
"""
props = await req.json()
ps = {
"id": curve,
} | props
return add_curve(network, ChangeSet(ps))
@router.post("/deletecurve/", response_model=None)
async def fastapi_delete_curve(network: str, curve: str) -> ChangeSet:
@router.post("/deletecurve/", response_model=None, summary="删除曲线", description="从网络中删除指定的曲线")
async def fastapi_delete_curve(
network: str = Query(..., description="管网名称(或数据库名称)"),
curve: str = Query(..., description="曲线ID")
) -> ChangeSet:
"""删除曲线。
从指定网络中删除指定的曲线及其相关数据。
"""
ps = {"id": curve}
return delete_curve(network, ChangeSet(ps))
@router.get("/getcurveproperties/")
async def fastapi_get_curve_properties(network: str, curve: str) -> dict[str, Any]:
@router.get("/getcurveproperties/", summary="获取曲线属性", description="获取指定曲线的属性信息")
async def fastapi_get_curve_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
curve: str = Query(..., description="曲线ID")
) -> dict[str, Any]:
"""获取曲线属性。
返回指定曲线的所有属性信息。
"""
return get_curve(network, curve)
@router.post("/setcurveproperties/", response_model=None)
@router.post("/setcurveproperties/", response_model=None, summary="设置曲线属性", description="更新指定曲线的属性")
async def fastapi_set_curve_properties(
network: str, curve: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
curve: str = Query(..., description="曲线ID"),
req: Request = None
) -> ChangeSet:
"""设置曲线属性。
更新指定曲线的属性值。
"""
props = await req.json()
ps = {"id": curve} | props
return set_curve(network, ChangeSet(ps))
@router.get("/getcurves/")
async def fastapi_get_curves(network: str) -> list[str]:
@router.get("/getcurves/", summary="获取所有曲线", description="获取网络中的所有曲线列表")
async def fastapi_get_curves(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
"""获取所有曲线。
返回指定网络中的所有曲线ID列表。
"""
return get_curves(network)
@router.get("/iscurve/")
async def fastapi_is_curve(network: str, curve: str) -> bool:
@router.get("/iscurve/", summary="检查曲线存在性", description="检查指定的曲线是否存在")
async def fastapi_is_curve(
network: str = Query(..., description="管网名称(或数据库名称)"),
curve: str = Query(..., description="曲线ID")
) -> bool:
"""检查曲线是否存在。
判断指定的曲线是否在网络中存在。
"""
return is_curve(network, curve)
+103 -26
View File
@@ -1,60 +1,137 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
get_energy,
get_energy_schema,
get_option_v3,
get_option_v3_schema,
get_pump_energy,
get_pump_energy_schema,
get_time,
get_time_schema,
set_energy,
set_option_v3,
set_pump_energy,
set_time,
)
router = APIRouter()
@router.get("/gettimeschema")
async def fastapi_get_time_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/gettimeschema", summary="获取时间选项架构", description="获取网络中时间选项的架构定义")
async def fastapi_get_time_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取时间选项架构。
返回指定网络中时间相关选项的属性架构定义。
"""
return get_time_schema(network)
@router.get("/gettimeproperties/")
async def fastapi_get_time_properties(network: str) -> dict[str, Any]:
@router.get("/gettimeproperties/", summary="获取时间选项属性", description="获取指定网络中的时间选项属性信息")
async def fastapi_get_time_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取时间选项属性。
返回指定网络中的时间相关选项属性。
"""
return get_time(network)
@router.post("/settimeproperties/", response_model=None)
async def fastapi_set_time_properties(network: str, req: Request) -> ChangeSet:
@router.post("/settimeproperties/", response_model=None, summary="设置时间选项属性", description="更新指定网络中的时间选项属性")
async def fastapi_set_time_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置时间选项属性。
更新指定网络中的时间相关选项属性值。
"""
props = await req.json()
return set_time(network, ChangeSet(props))
@router.get("/getenergyschema/")
async def fastapi_get_energy_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getenergyschema/", summary="获取能耗选项架构", description="获取网络中能耗选项的架构定义")
async def fastapi_get_energy_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取能耗选项架构。
返回指定网络中能耗相关选项的属性架构定义。
"""
return get_energy_schema(network)
@router.get("/getenergyproperties/")
async def fastapi_get_energy_properties(network: str) -> dict[str, Any]:
@router.get("/getenergyproperties/", summary="获取能耗选项属性", description="获取指定网络中的能耗选项属性信息")
async def fastapi_get_energy_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取能耗选项属性。
返回指定网络中的能耗相关选项属性。
"""
return get_energy(network)
@router.post("/setenergyproperties/", response_model=None)
async def fastapi_set_energy_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setenergyproperties/", response_model=None, summary="设置能耗选项属性", description="更新指定网络中的能耗选项属性")
async def fastapi_set_energy_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置能耗选项属性。
更新指定网络中的能耗相关选项属性值。
"""
props = await req.json()
return set_energy(network, ChangeSet(props))
@router.get("/getpumpenergyschema/")
async def fastapi_get_pump_energy_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getpumpenergyschema/", summary="获取泵能耗选项架构", description="获取网络中泵能耗选项的架构定义")
async def fastapi_get_pump_energy_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取泵能耗选项架构。
返回指定网络中泵能耗相关选项的属性架构定义。
"""
return get_pump_energy_schema(network)
@router.get("/getpumpenergyproperties//")
async def fastapi_get_pump_energy_proeprties(network: str, pump: str) -> dict[str, Any]:
@router.get("/getpumpenergyproperties//", summary="获取泵能耗属性", description="获取指定泵的能耗属性信息")
async def fastapi_get_pump_energy_proeprties(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="泵ID")
) -> dict[str, Any]:
"""获取泵能耗属性。
返回指定泵的能耗相关属性。
"""
return get_pump_energy(network, pump)
@router.get("/setpumpenergyproperties//", response_model=None)
@router.get("/setpumpenergyproperties//", response_model=None, summary="设置泵能耗属性", description="更新指定泵的能耗属性")
async def fastapi_set_pump_energy_properties(
network: str, pump: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="泵ID"),
req: Request = None
) -> ChangeSet:
"""设置泵能耗属性。
更新指定泵的能耗相关属性值。
"""
props = await req.json()
ps = {"id": pump} | props
return set_pump_energy(network, ChangeSet(ps))
@router.get("/getoptionschema/")
async def fastapi_get_option_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getoptionschema/", summary="获取选项架构", description="获取网络中选项对象的架构定义")
async def fastapi_get_option_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取选项架构。
返回指定网络中选项对象的属性架构定义。
"""
return get_option_v3_schema(network)
@router.get("/getoptionproperties/")
async def fastapi_get_option_properties(network: str) -> dict[str, Any]:
@router.get("/getoptionproperties/", summary="获取选项属性", description="获取指定网络中的选项属性信息")
async def fastapi_get_option_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取选项属性。
返回指定网络中的选项对象属性信息。
"""
return get_option_v3(network)
@router.post("/setoptionproperties/", response_model=None)
async def fastapi_set_option_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setoptionproperties/", response_model=None, summary="设置选项属性", description="更新指定网络中的选项属性")
async def fastapi_set_option_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置选项属性。
更新指定网络中的选项属性值。
"""
props = await req.json()
return set_option_v3(network, ChangeSet(props))
+69 -16
View File
@@ -1,42 +1,95 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_pattern,
delete_pattern,
get_pattern,
get_pattern_schema,
get_patterns,
is_pattern,
set_pattern,
)
router = APIRouter()
@router.get("/getpatternschema")
async def fastapi_get_pattern_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getpatternschema", summary="获取模式架构", description="获取网络中模式对象的架构定义")
async def fastapi_get_pattern_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取模式架构。
返回指定网络中模式对象的属性架构定义。
"""
return get_pattern_schema(network)
@router.post("/addpattern/", response_model=None)
async def fastapi_add_pattern(network: str, pattern: str, req: Request) -> ChangeSet:
@router.post("/addpattern/", response_model=None, summary="添加模式", description="在网络中添加一个新的模式")
async def fastapi_add_pattern(
network: str = Query(..., description="管网名称(或数据库名称)"),
pattern: str = Query(..., description="模式ID"),
req: Request = None
) -> ChangeSet:
"""添加模式。
在指定网络中创建一个新的模式,并设置其初始属性。
"""
props = await req.json()
ps = {
"id": pattern,
} | props
return add_pattern(network, ChangeSet(ps))
@router.post("/deletepattern/", response_model=None)
async def fastapi_delete_pattern(network: str, pattern: str) -> ChangeSet:
@router.post("/deletepattern/", response_model=None, summary="删除模式", description="从网络中删除指定的模式")
async def fastapi_delete_pattern(
network: str = Query(..., description="管网名称(或数据库名称)"),
pattern: str = Query(..., description="模式ID")
) -> ChangeSet:
"""删除模式。
从指定网络中删除指定的模式及其相关数据。
"""
ps = {"id": pattern}
return delete_pattern(network, ChangeSet(ps))
@router.get("/getpatternproperties/")
async def fastapi_get_pattern_properties(network: str, pattern: str) -> dict[str, Any]:
@router.get("/getpatternproperties/", summary="获取模式属性", description="获取指定模式的属性信息")
async def fastapi_get_pattern_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
pattern: str = Query(..., description="模式ID")
) -> dict[str, Any]:
"""获取模式属性。
返回指定模式的所有属性信息。
"""
return get_pattern(network, pattern)
@router.post("/setpatternproperties/", response_model=None)
@router.post("/setpatternproperties/", response_model=None, summary="设置模式属性", description="更新指定模式的属性")
async def fastapi_set_pattern_properties(
network: str, pattern: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
pattern: str = Query(..., description="模式ID"),
req: Request = None
) -> ChangeSet:
"""设置模式属性。
更新指定模式的属性值。
"""
props = await req.json()
ps = {"id": pattern} | props
return set_pattern(network, ChangeSet(ps))
@router.get("/ispattern/")
async def fastapi_is_pattern(network: str, pattern: str) -> bool:
@router.get("/ispattern/", summary="检查模式存在性", description="检查指定的模式是否存在")
async def fastapi_is_pattern(
network: str = Query(..., description="管网名称(或数据库名称)"),
pattern: str = Query(..., description="模式ID")
) -> bool:
"""检查模式是否存在。
判断指定的模式是否在网络中存在。
"""
return is_pattern(network, pattern)
@router.get("/getpatterns/")
async def fastapi_get_patterns(network: str) -> list[str]:
@router.get("/getpatterns/", summary="获取所有模式", description="获取网络中的所有模式列表")
async def fastapi_get_patterns(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
"""获取所有模式。
返回指定网络中的所有模式ID列表。
"""
return get_patterns(network)
+230 -52
View File
@@ -1,119 +1,297 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_mixing,
add_source,
api,
delete_mixing,
delete_source,
get_emitter,
get_emitter_schema,
get_mixing,
get_mixing_schema,
get_pipe_reaction,
get_pipe_reaction_schema,
get_quality,
get_quality_schema,
get_reaction,
get_reaction_schema,
get_source,
get_source_schema,
get_tank_reaction,
get_tank_reaction_schema,
set_emitter,
set_pipe_reaction,
set_quality,
set_reaction,
set_source,
set_tank_reaction,
)
router = APIRouter()
@router.get("/getqualityschema/")
async def fastapi_get_quality_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getqualityschema/", summary="获取水质架构", description="获取网络中水质对象的架构定义")
async def fastapi_get_quality_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取水质架构。
返回指定网络中水质对象的属性架构定义。
"""
return get_quality_schema(network)
@router.get("/getqualityproperties/")
async def fastapi_get_quality_properties(network: str, node: str) -> dict[str, Any]:
@router.get("/getqualityproperties/", summary="获取水质属性", description="获取指定节点的水质属性信息")
async def fastapi_get_quality_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> dict[str, Any]:
"""获取水质属性。
返回指定节点的水质属性信息。
"""
return get_quality(network, node)
@router.post("/setqualityproperties/", response_model=None)
async def fastapi_set_quality_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setqualityproperties/", response_model=None, summary="设置水质属性", description="更新指定节点的水质属性")
async def fastapi_set_quality_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置水质属性。
更新指定节点的水质属性值。
"""
props = await req.json()
return set_quality(network, ChangeSet(props))
@router.get("/getemitterschema")
async def fastapi_get_emitter_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getemitterschema", summary="获取发射器架构", description="获取网络中发射器对象的架构定义")
async def fastapi_get_emitter_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取发射器架构。
返回指定网络中发射器对象的属性架构定义。
"""
return get_emitter_schema(network)
@router.get("/getemitterproperties/")
async def fastapi_get_emitter_properties(network: str, junction: str) -> dict[str, Any]:
@router.get("/getemitterproperties/", summary="获取发射器属性", description="获取指定连接点的发射器属性信息")
async def fastapi_get_emitter_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="连接点ID")
) -> dict[str, Any]:
"""获取发射器属性。
返回指定连接点的发射器属性信息。
"""
return get_emitter(network, junction)
@router.post("/setemitterproperties/", response_model=None)
@router.post("/setemitterproperties/", response_model=None, summary="设置发射器属性", description="更新指定连接点的发射器属性")
async def fastapi_set_emitter_properties(
network: str, junction: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="连接点ID"),
req: Request = None
) -> ChangeSet:
"""设置发射器属性。
更新指定连接点的发射器属性值。
"""
props = await req.json()
ps = {"junction": junction} | props
return set_emitter(network, ChangeSet(ps))
@router.get("/getsourcechema/")
async def fastapi_get_source_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getsourcechema/", summary="获取水源架构", description="获取网络中水源对象的架构定义")
async def fastapi_get_source_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取水源架构。
返回指定网络中水源对象的属性架构定义。
"""
return get_source_schema(network)
@router.get("/getsource/")
async def fastapi_get_source(network: str, node: str) -> dict[str, Any]:
@router.get("/getsource/", summary="获取水源属性", description="获取指定节点的水源属性信息")
async def fastapi_get_source(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> dict[str, Any]:
"""获取水源属性。
返回指定节点的水源属性信息。
"""
return get_source(network, node)
@router.post("/setsource/", response_model=None)
async def fastapi_set_source(network: str, req: Request) -> ChangeSet:
@router.post("/setsource/", response_model=None, summary="设置水源属性", description="更新指定节点的水源属性")
async def fastapi_set_source(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置水源属性。
更新指定节点的水源属性值。
"""
props = await req.json()
return set_source(network, ChangeSet(props))
@router.post("/addsource/", response_model=None)
async def fastapi_add_source(network: str, req: Request) -> ChangeSet:
@router.post("/addsource/", response_model=None, summary="添加水源", description="在网络中添加一个新的水源")
async def fastapi_add_source(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加水源。
在指定网络中创建一个新的水源,并设置其初始属性。
"""
props = await req.json()
return add_source(network, ChangeSet(props))
@router.post("/deletesource/", response_model=None)
async def fastapi_delete_source(network: str, node: str) -> ChangeSet:
@router.post("/deletesource/", response_model=None, summary="删除水源", description="从网络中删除指定节点的水源")
async def fastapi_delete_source(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> ChangeSet:
"""删除水源。
从指定网络中删除指定节点的水源。
"""
props = {"node": node}
return delete_source(network, ChangeSet(props))
@router.get("/getreactionschema/")
async def fastapi_get_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getreactionschema/", summary="获取反应架构", description="获取网络中反应对象的架构定义")
async def fastapi_get_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取反应架构。
返回指定网络中反应对象的属性架构定义。
"""
return get_reaction_schema(network)
@router.get("/getreaction/")
async def fastapi_get_reaction(network: str) -> dict[str, Any]:
@router.get("/getreaction/", summary="获取反应属性", description="获取指定网络中的反应属性信息")
async def fastapi_get_reaction(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取反应属性。
返回指定网络中的反应属性信息。
"""
return get_reaction(network)
@router.post("/setreaction/", response_model=None)
async def fastapi_set_reaction(network: str, req: Request) -> ChangeSet:
@router.post("/setreaction/", response_model=None, summary="设置反应属性", description="更新指定网络中的反应属性")
async def fastapi_set_reaction(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置反应属性。
更新指定网络中的反应属性值。
"""
props = await req.json()
return set_reaction(network, ChangeSet(props))
@router.get("/getpipereactionschema/")
async def fastapi_get_pipe_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getpipereactionschema/", summary="获取管道反应架构", description="获取网络中管道反应对象的架构定义")
async def fastapi_get_pipe_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取管道反应架构。
返回指定网络中管道反应对象的属性架构定义。
"""
return get_pipe_reaction_schema(network)
@router.get("/getpipereaction/")
async def fastapi_get_pipe_reaction(network: str, pipe: str) -> dict[str, Any]:
@router.get("/getpipereaction/", summary="获取管道反应属性", description="获取指定管道的反应属性信息")
async def fastapi_get_pipe_reaction(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> dict[str, Any]:
"""获取管道反应属性。
返回指定管道的反应属性信息。
"""
return get_pipe_reaction(network, pipe)
@router.post("/setpipereaction/", response_model=None)
async def fastapi_set_pipe_reaction(network: str, req: Request) -> ChangeSet:
@router.post("/setpipereaction/", response_model=None, summary="设置管道反应属性", description="更新指定管道的反应属性")
async def fastapi_set_pipe_reaction(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置管道反应属性。
更新指定管道的反应属性值。
"""
props = await req.json()
return set_pipe_reaction(network, ChangeSet(props))
@router.get("/gettankreactionschema/")
async def fastapi_get_tank_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/gettankreactionschema/", summary="获取水池反应架构", description="获取网络中水池反应对象的架构定义")
async def fastapi_get_tank_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取水池反应架构。
返回指定网络中水池反应对象的属性架构定义。
"""
return get_tank_reaction_schema(network)
@router.get("/gettankreaction/")
async def fastapi_get_tank_reaction(network: str, tank: str) -> dict[str, Any]:
@router.get("/gettankreaction/", summary="获取水池反应属性", description="获取指定水池的反应属性信息")
async def fastapi_get_tank_reaction(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水池ID")
) -> dict[str, Any]:
"""获取水池反应属性。
返回指定水池的反应属性信息。
"""
return get_tank_reaction(network, tank)
@router.post("/settankreaction/", response_model=None)
async def fastapi_set_tank_reaction(network: str, req: Request) -> ChangeSet:
@router.post("/settankreaction/", response_model=None, summary="设置水池反应属性", description="更新指定水池的反应属性")
async def fastapi_set_tank_reaction(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置水池反应属性。
更新指定水池的反应属性值。
"""
props = await req.json()
return set_tank_reaction(network, ChangeSet(props))
@router.get("/getmixingschema/")
async def fastapi_get_mixing_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getmixingschema/", summary="获取混合架构", description="获取网络中混合对象的架构定义")
async def fastapi_get_mixing_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取混合架构。
返回指定网络中混合对象的属性架构定义。
"""
return get_mixing_schema(network)
@router.get("/getmixing/")
async def fastapi_get_mixing(network: str, tank: str) -> dict[str, Any]:
@router.get("/getmixing/", summary="获取混合属性", description="获取指定水池的混合属性信息")
async def fastapi_get_mixing(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水池ID")
) -> dict[str, Any]:
"""获取混合属性。
返回指定水池的混合属性信息。
"""
return get_mixing(network, tank)
@router.post("/setmixing/", response_model=None)
async def fastapi_set_mixing(network: str, req: Request) -> ChangeSet:
@router.post("/setmixing/", response_model=None, summary="设置混合属性", description="更新指定水池的混合属性")
async def fastapi_set_mixing(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置混合属性。
更新指定水池的混合属性值。
"""
props = await req.json()
return api.set_mixing(network, ChangeSet(props))
@router.post("/addmixing/", response_model=None)
async def fastapi_add_mixing(network: str, req: Request) -> ChangeSet:
@router.post("/addmixing/", response_model=None, summary="添加混合", description="在网络中添加一个新的混合")
async def fastapi_add_mixing(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加混合。
在指定网络中创建一个新的混合,并设置其初始属性。
"""
props = await req.json()
return add_mixing(network, ChangeSet(props))
@router.post("/deletemixing/", response_model=None)
async def fastapi_delete_mixing(network: str, req: Request) -> ChangeSet:
@router.post("/deletemixing/", response_model=None, summary="删除混合", description="从网络中删除指定的混合")
async def fastapi_delete_mixing(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除混合。
从指定网络中删除指定的混合及其相关数据。
"""
props = await req.json()
return delete_mixing(network, ChangeSet(props))
+136 -32
View File
@@ -1,76 +1,180 @@
from fastapi import APIRouter, Request, Response
from fastapi import APIRouter, Request, Query, Path, Body, Response
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_label,
add_vertex,
delete_label,
delete_vertex,
get_all_vertex_links,
get_all_vertices,
get_backdrop,
get_backdrop_schema,
get_label,
get_label_schema,
get_vertex,
get_vertex_schema,
set_backdrop,
set_label,
set_vertex,
)
from fastapi.responses import PlainTextResponse
import json
router = APIRouter()
@router.get("/getvertexschema/")
async def fastapi_get_vertex_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getvertexschema/", summary="获取图形元素架构", description="获取网络中图形元素对象的架构定义")
async def fastapi_get_vertex_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取图形元素架构。
返回指定网络中图形元素对象的属性架构定义。
"""
return get_vertex_schema(network)
@router.get("/getvertexproperties/")
async def fastapi_get_vertex_properties(network: str, link: str) -> dict[str, Any]:
@router.get("/getvertexproperties/", summary="获取图形元素属性", description="获取指定图形元素的属性信息")
async def fastapi_get_vertex_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="图形元素链接")
) -> dict[str, Any]:
"""获取图形元素属性。
返回指定图形元素的所有属性信息。
"""
return get_vertex(network, link)
@router.post("/setvertexproperties/", response_model=None)
async def fastapi_set_vertex_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setvertexproperties/", response_model=None, summary="设置图形元素属性", description="更新指定图形元素的属性")
async def fastapi_set_vertex_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置图形元素属性。
更新指定图形元素的属性值。
"""
props = await req.json()
return set_vertex(network, ChangeSet(props))
@router.post("/addvertex/", response_model=None)
async def fastapi_add_vertex(network: str, req: Request) -> ChangeSet:
@router.post("/addvertex/", response_model=None, summary="添加图形元素", description="在网络中添加一个新的图形元素")
async def fastapi_add_vertex(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加图形元素。
在指定网络中创建一个新的图形元素,并设置其初始属性。
"""
props = await req.json()
return add_vertex(network, ChangeSet(props))
@router.post("/deletevertex/", response_model=None)
async def fastapi_delete_vertex(network: str, req: Request) -> ChangeSet:
@router.post("/deletevertex/", response_model=None, summary="删除图形元素", description="从网络中删除指定的图形元素")
async def fastapi_delete_vertex(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除图形元素。
从指定网络中删除指定的图形元素及其相关数据。
"""
props = await req.json()
return delete_vertex(network, ChangeSet(props))
@router.get("/getallvertexlinks/", response_class=PlainTextResponse)
async def fastapi_get_all_vertex_links(network: str) -> list[str]:
@router.get("/getallvertexlinks/", response_class=PlainTextResponse, summary="获取所有图形元素链接", description="获取网络中的所有图形元素链接列表")
async def fastapi_get_all_vertex_links(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
"""获取所有图形元素链接。
返回指定网络中的所有图形元素链接列表。
"""
return json.dumps(get_all_vertex_links(network))
@router.get("/getallvertices/", response_class=PlainTextResponse)
async def fastapi_get_all_vertices(network: str) -> list[dict[str, Any]]:
@router.get("/getallvertices/", response_class=PlainTextResponse, summary="获取所有图形元素", description="获取网络中的所有图形元素详细信息")
async def fastapi_get_all_vertices(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[str, Any]]:
"""获取所有图形元素。
返回指定网络中的所有图形元素详细信息。
"""
return json.dumps(get_all_vertices(network))
@router.get("/getlabelschema/")
async def fastapi_get_label_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getlabelschema/", summary="获取标签架构", description="获取网络中标签对象的架构定义")
async def fastapi_get_label_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取标签架构。
返回指定网络中标签对象的属性架构定义。
"""
return get_label_schema(network)
@router.get("/getlabelproperties/")
@router.get("/getlabelproperties/", summary="获取标签属性", description="获取指定坐标处的标签属性信息")
async def fastapi_get_label_properties(
network: str, x: float, y: float
network: str = Query(..., description="管网名称(或数据库名称)"),
x: float = Query(..., description="X坐标"),
y: float = Query(..., description="Y坐标")
) -> dict[str, Any]:
"""获取标签属性。
返回指定坐标处的标签属性信息。
"""
return get_label(network, x, y)
@router.post("/setlabelproperties/", response_model=None)
async def fastapi_set_label_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setlabelproperties/", response_model=None, summary="设置标签属性", description="更新指定标签的属性")
async def fastapi_set_label_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置标签属性。
更新指定标签的属性值。
"""
props = await req.json()
return set_label(network, ChangeSet(props))
@router.post("/addlabel/", response_model=None)
async def fastapi_add_label(network: str, req: Request) -> ChangeSet:
@router.post("/addlabel/", response_model=None, summary="添加标签", description="在网络中添加一个新的标签")
async def fastapi_add_label(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加标签。
在指定网络中创建一个新的标签,并设置其初始属性。
"""
props = await req.json()
return add_label(network, ChangeSet(props))
@router.post("/deletelabel/", response_model=None)
async def fastapi_delete_label(network: str, req: Request) -> ChangeSet:
@router.post("/deletelabel/", response_model=None, summary="删除标签", description="从网络中删除指定的标签")
async def fastapi_delete_label(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除标签。
从指定网络中删除指定的标签及其相关数据。
"""
props = await req.json()
return delete_label(network, ChangeSet(props))
@router.get("/getbackdropschema/")
async def fastapi_get_backdrop_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getbackdropschema/", summary="获取背景架构", description="获取网络中背景对象的架构定义")
async def fastapi_get_backdrop_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""获取背景架构。
返回指定网络中背景对象的属性架构定义。
"""
return get_backdrop_schema(network)
@router.get("/getbackdropproperties/")
async def fastapi_get_backdrop_properties(network: str) -> dict[str, Any]:
@router.get("/getbackdropproperties/", summary="获取背景属性", description="获取指定网络的背景属性信息")
async def fastapi_get_backdrop_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取背景属性。
返回指定网络的背景属性信息。
"""
return get_backdrop(network)
@router.post("/setbackdropproperties/", response_model=None)
async def fastapi_set_backdrop_properties(network: str, req: Request) -> ChangeSet:
@router.post("/setbackdropproperties/", response_model=None, summary="设置背景属性", description="更新指定网络的背景属性")
async def fastapi_set_backdrop_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置背景属性。
更新指定网络的背景属性值。
"""
props = await req.json()
return set_backdrop(network, ChangeSet(props))
-388
View File
@@ -1,388 +0,0 @@
from typing import Any, List, Dict, Optional
import logging
from datetime import datetime, timedelta, timezone, time as dt_time
import msgpack
from fastapi import APIRouter
from pydantic import BaseModel
from py_linq import Enumerable
import app.infra.db.influxdb.api as influxdb_api
import app.services.time_api as time_api
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
router = APIRouter()
logger = logging.getLogger(__name__)
# Basic Node/Link Latest Record Queries
@router.get("/querynodelatestrecordbyid/")
async def fastapi_query_node_latest_record_by_id(id: str) -> Any:
return influxdb_api.query_latest_record_by_ID(id, type="node")
@router.get("/querylinklatestrecordbyid/")
async def fastapi_query_link_latest_record_by_id(id: str) -> Any:
return influxdb_api.query_latest_record_by_ID(id, type="link")
@router.get("/queryscadalatestrecordbyid/")
async def fastapi_query_scada_latest_record_by_id(id: str) -> Any:
return influxdb_api.query_latest_record_by_ID(id, type="scada")
# Time-based Queries
@router.get("/queryallrecordsbytime/")
async def fastapi_query_all_records_by_time(querytime: str) -> dict[str, list]:
results: tuple = influxdb_api.query_all_records_by_time(query_time=querytime)
return {"nodes": results[0], "links": results[1]}
@router.get("/queryallrecordsbytimeproperty/")
async def fastapi_query_all_record_by_time_property(
querytime: str, type: str, property: str, bucket: str = "realtime_simulation_result"
) -> dict[str, list]:
results: tuple = influxdb_api.query_all_record_by_time_property(
query_time=querytime, type=type, property=property, bucket=bucket
)
return {"results": results}
@router.get("/queryallschemerecordsbytimeproperty/")
async def fastapi_query_all_scheme_record_by_time_property(
querytime: str,
type: str,
property: str,
schemename: str,
bucket: str = "scheme_simulation_result",
) -> dict[str, list]:
"""
查询指定方案某一时刻的所有记录,查询 'node''link' 的某一属性值
"""
results: list = influxdb_api.query_all_scheme_record_by_time_property(
query_time=querytime,
type=type,
property=property,
scheme_name=schemename,
bucket=bucket,
)
return {"results": results}
@router.get("/querysimulationrecordsbyidtime/")
async def fastapi_query_simulation_record_by_ids_time(
id: str, querytime: str, type: str, bucket: str = "realtime_simulation_result"
) -> dict[str, list]:
results: tuple = influxdb_api.query_simulation_result_by_ID_time(
ID=id, type=type, query_time=querytime, bucket=bucket
)
return {"results": results}
@router.get("/queryschemesimulationrecordsbyidtime/")
async def fastapi_query_scheme_simulation_record_by_ids_time(
scheme_name: str,
id: str,
querytime: str,
type: str,
bucket: str = "scheme_simulation_result",
) -> dict[str, list]:
results: tuple = influxdb_api.query_scheme_simulation_result_by_ID_time(
scheme_name=scheme_name, ID=id, type=type, query_time=querytime, bucket=bucket
)
return {"results": results}
# Date-based Queries with Caching
@router.get("/queryallrecordsbydate/")
async def fastapi_query_all_records_by_date(querydate: str) -> dict:
is_today_or_future = time_api.is_today_or_future(querydate)
logger.info(f"isToday or future: {is_today_or_future}")
cache_key = f"queryallrecordsbydate_{querydate}"
if not is_today_or_future:
data = redis_client.get(cache_key)
if data:
results = msgpack.unpackb(data, object_hook=decode_datetime)
logger.info("return from cache redis")
return results
logger.info("query from influxdb")
nodes_links: tuple = influxdb_api.query_all_records_by_date(query_date=querydate)
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
if not is_today_or_future:
logger.info("save to cache redis")
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
logger.info("return results")
return results
@router.get("/queryallrecordsbytimerange/")
async def fastapi_query_all_records_by_time_range(
starttime: str, endtime: str
) -> dict[str, list]:
cache_key = f"queryallrecordsbytimerange_{starttime}_{endtime}"
if not time_api.is_today_or_future(starttime):
data = redis_client.get(cache_key)
if data:
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
return loaded_dict
nodes_links: tuple = influxdb_api.query_all_records_by_time_range(
starttime=starttime, endtime=endtime
)
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
if not time_api.is_today_or_future(starttime):
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
return results
@router.get("/queryallrecordsbydatewithtype/")
async def fastapi_query_all_records_by_date_with_type(
querydate: str, querytype: str
) -> list:
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
data = redis_client.get(cache_key)
if data:
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
return loaded_dict
results = influxdb_api.query_all_records_by_date_with_type(
query_date=querydate, query_type=querytype
)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
return results
@router.get("/queryallrecordsbyidsdatetype/")
async def fastapi_query_all_records_by_ids_date_type(
ids: str, querydate: str, querytype: str
) -> list:
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
data = redis_client.get(cache_key)
results = []
if data:
results = msgpack.unpackb(data, object_hook=decode_datetime)
else:
results = influxdb_api.query_all_records_by_date_with_type(
query_date=querydate, query_type=querytype
)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
query_ids = ids.split(",")
# Using Enumerable from py_linq as in original code
e_results = Enumerable(results)
lst_results = e_results.where(lambda x: x["ID"] in query_ids).to_list()
return lst_results
@router.get("/queryallrecordsbydateproperty/")
async def fastapi_query_all_records_by_date_property(
querydate: str, querytype: str, property: str
) -> list[dict]:
cache_key = f"queryallrecordsbydateproperty_{querydate}_{querytype}_{property}"
data = redis_client.get(cache_key)
if data:
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
return loaded_dict
result_dict = influxdb_api.query_all_record_by_date_property(
query_date=querydate, type=querytype, property=property
)
packed = msgpack.packb(result_dict, default=encode_datetime)
redis_client.set(cache_key, packed)
return result_dict
# Curve Queries
@router.get("/querynodecurvebyidpropertydaterange/")
async def fastapi_query_node_curve_by_id_property_daterange(
id: str, prop: str, startdate: str, enddate: str
):
return influxdb_api.query_curve_by_ID_property_daterange(
id, type="node", property=prop, start_date=startdate, end_date=enddate
)
@router.get("/querylinkcurvebyidpropertydaterange/")
async def fastapi_query_link_curve_by_id_property_daterange(
id: str, prop: str, startdate: str, enddate: str
):
return influxdb_api.query_curve_by_ID_property_daterange(
id, type="link", property=prop, start_date=startdate, end_date=enddate
)
# SCADA Data Queries
@router.get("/queryscadadatabydeviceidandtime/")
async def fastapi_query_scada_data_by_device_id_and_time(ids: str, querytime: str):
query_ids = ids.split(",")
logger.info(querytime)
return influxdb_api.query_SCADA_data_by_device_ID_and_time(
query_ids_list=query_ids, query_time=querytime
)
@router.get("/queryscadadatabydeviceidandtimerange/")
async def fastapi_query_scada_data_by_device_id_and_time_range(
ids: str, starttime: str, endtime: str
):
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
query_ids = ids.split(",")
return influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
@router.get("/queryfillingscadadatabydeviceidandtimerange/")
async def fastapi_query_filling_scada_data_by_device_id_and_time_range(
ids: str, starttime: str, endtime: str
):
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
query_ids = ids.split(",")
return influxdb_api.query_filling_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
@router.get("/querycleaningscadadatabydeviceidandtimerange/")
async def fastapi_query_cleaning_scada_data_by_device_id_and_time_range(
ids: str, starttime: str, endtime: str
):
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
query_ids = ids.split(",")
return influxdb_api.query_cleaning_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
@router.get("/querysimulationscadadatabydeviceidandtimerange/")
async def fastapi_query_simulation_scada_data_by_device_id_and_time_range(
ids: str, starttime: str, endtime: str
):
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
query_ids = ids.split(",")
return influxdb_api.query_simulation_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
@router.get("/querycleanedscadadatabydeviceidandtimerange/")
async def fastapi_query_cleaned_scada_data_by_device_id_and_time_range(
ids: str, starttime: str, endtime: str
):
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
query_ids = ids.split(",")
return influxdb_api.query_cleaned_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
@router.get("/queryscadadatabydeviceidanddate/")
async def fastapi_query_scada_data_by_device_id_and_date(ids: str, querydate: str):
query_ids = ids.split(",")
return influxdb_api.query_SCADA_data_by_device_ID_and_date(
query_ids_list=query_ids, query_date=querydate
)
@router.get("/queryallscadarecordsbydate/")
async def fastapi_query_all_scada_records_by_date(querydate: str):
is_today_or_future = time_api.is_today_or_future(querydate)
logger.info(f"isToday or future: {is_today_or_future}")
cache_key = f"queryallscadarecordsbydate_{querydate}"
if not is_today_or_future:
data = redis_client.get(cache_key)
if data:
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
logger.info("return from cache redis")
return loaded_dict
logger.info("query from influxdb")
result_dict = influxdb_api.query_all_SCADA_records_by_date(query_date=querydate)
if not is_today_or_future:
logger.info("save to cache redis")
packed = msgpack.packb(result_dict, default=encode_datetime)
redis_client.set(cache_key, packed)
logger.info("return results")
return result_dict
@router.get("/queryallschemeallrecords/")
async def fastapi_query_all_scheme_all_records(
schemetype: str, schemename: str, querydate: str
) -> tuple:
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
data = redis_client.get(cache_key)
if data:
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
return loaded_dict
results = influxdb_api.query_scheme_all_record(
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(results, default=encode_datetime)
redis_client.set(cache_key, packed)
return results
@router.get("/queryschemeallrecordsproperty/")
async def fastapi_query_all_scheme_all_records_property(
schemetype: str, schemename: str, querydate: str, querytype: str, queryproperty: str
) -> Optional[List]:
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
data = redis_client.get(cache_key)
all_results = None
if data:
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
else:
all_results = influxdb_api.query_scheme_all_record(
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
)
packed = msgpack.packb(all_results, default=encode_datetime)
redis_client.set(cache_key, packed)
results = None
if querytype == "node":
results = all_results[0]
elif querytype == "link":
results = all_results[1]
return results
@router.get("/queryinfluxdbbuckets/")
async def fastapi_query_influxdb_buckets():
return influxdb_api.query_buckets()
@router.get("/queryinfluxdbbucketmeasurements/")
async def fastapi_query_influxdb_bucket_measurements(bucket: str):
return influxdb_api.query_measurements(bucket=bucket)
############################################################
# download history data
############################################################
class Download_History_Data_Manually(BaseModel):
"""
download_date:样式如 datetime(2025, 5, 4)
"""
download_date: datetime
@router.post("/download_history_data_manually/")
async def fastapi_download_history_data_manually(
data: Download_History_Data_Manually,
) -> None:
item = data.dict()
tz = timezone(timedelta(hours=8))
begin_dt = datetime.combine(item.get("download_date").date(), dt_time.min).replace(
tzinfo=tz
)
end_dt = datetime.combine(item.get("download_date").date(), dt_time(23, 59, 59)).replace(
tzinfo=tz
)
begin_time = begin_dt.isoformat()
end_time = end_dt.isoformat()
influxdb_api.download_history_data_manually(
begin_time=begin_time, end_time=end_time
)
+83 -10
View File
@@ -1,7 +1,7 @@
from typing import List, Any
from fastapi import APIRouter, Request, HTTPException
from app.native.api import ChangeSet
from fastapi import APIRouter, Request, HTTPException, Query, Body
from app.services.tjnetwork import (
ChangeSet,
get_all_extension_data_keys,
get_all_extension_data,
get_extension_data,
@@ -10,20 +10,93 @@ from app.services.tjnetwork import (
router = APIRouter()
@router.get("/getallextensiondatakeys/")
async def get_all_extension_data_keys_endpoint(network: str) -> list[str]:
@router.get(
"/getallextensiondatakeys/",
summary="获取所有扩展数据键",
description="获取指定网络的所有扩展数据的键列表"
)
async def get_all_extension_data_keys_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[str]:
"""
获取所有扩展数据键。
返回指定网络中所有可用的扩展数据键。
Args:
network: 管网名称(或数据库名称)
Returns:
扩展数据键列表
"""
return get_all_extension_data_keys(network)
@router.get("/getallextensiondata/")
async def get_all_extension_data_endpoint(network: str) -> dict[str, Any]:
@router.get(
"/getallextensiondata/",
summary="获取所有扩展数据",
description="获取指定网络的所有扩展数据"
)
async def get_all_extension_data_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, Any]:
"""
获取所有扩展数据。
返回指定网络的所有扩展数据及其值。
Args:
network: 管网名称(或数据库名称)
Returns:
扩展数据字典
"""
return get_all_extension_data(network)
@router.get("/getextensiondata/")
async def get_extension_data_endpoint(network: str, key: str) -> str | None:
@router.get(
"/getextensiondata/",
summary="获取指定扩展数据",
description="获取指定网络中指定键的扩展数据值"
)
async def get_extension_data_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
key: str = Query(..., description="扩展数据键")
) -> str | None:
"""
获取指定扩展数据。
返回指定网络中指定键对应的扩展数据值。
Args:
network: 管网名称(或数据库名称)
key: 扩展数据键
Returns:
扩展数据值,如果不存在返回None
"""
return get_extension_data(network, key)
@router.post("/setextensiondata/", response_model=None)
async def set_extension_data_endpoint(network: str, req: Request) -> ChangeSet:
@router.post(
"/setextensiondata/",
response_model=None,
summary="设置扩展数据",
description="设置指定网络中的扩展数据"
)
async def set_extension_data_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
设置扩展数据。
在指定网络中设置扩展数据,并返回变更集信息。
Args:
network: 管网名称(或数据库名称)
req: 包含扩展数据的请求体
Returns:
变更集信息
"""
props = await req.json()
print(props)
cs = set_extension_data(network, ChangeSet(props))
+133
View File
@@ -0,0 +1,133 @@
import os
from typing import Any
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from pydantic import BaseModel, Field
from app.auth.keycloak_dependencies import get_current_keycloak_username
from app.services.leakage_identifier import (
get_leakage_identify_scheme_detail,
list_leakage_identify_schemes,
run_leakage_identification,
)
router = APIRouter()
DEFAULT_N_WORKERS = max(1, min((os.cpu_count() or 1) - 1, 4))
class LeakageIdentifyRequest(BaseModel):
"""漏损识别请求模型"""
network: str = Field(..., description="管网名称(或数据库名称)")
observed_pressure_data: str | dict[str, list[Any]] | list[dict[str, Any]] | None = Field(
None, description="观测的压力数据"
)
start_time: float = Field(0, description="起始时间(小时)")
duration: float = Field(24, description="持续时间(小时)")
timestep: float = Field(5, description="时间步长(分钟)")
q_sum: float = Field(0.2, description="总流量(m3/s")
q_sum_unit: str = Field("m3/s", description="流量单位")
output_dir: str = Field("db_inp", description="输出目录")
pop_size: int = Field(50, description="种群大小")
max_gen: int = Field(100, description="最大代数")
n_workers: int = Field(DEFAULT_N_WORKERS, description="工作线程数")
output_flow_unit: str = Field("m3/s", description="输出流量单位")
dma_count: int | None = Field(None, description="DMA区域数量")
scada_start: datetime | None = Field(None, description="SCADA数据起始时间")
scada_end: datetime | None = Field(None, description="SCADA数据结束时间")
sensor_nodes: list[str] | None = Field(None, description="传感器节点列表")
scheme_name: str | None = Field(None, description="方案名称")
@router.post(
"/identify/",
summary="执行漏损识别",
description="基于压力观测数据和遗传算法识别管网中的漏损位置和大小"
)
async def identify_leakage(
data: LeakageIdentifyRequest = Body(..., description="漏损识别请求数据"),
username: str = Depends(get_current_keycloak_username),
) -> dict[str, Any]:
"""
执行漏损识别分析。
使用遗传算法对比模型计算和实测压力数据,
识别管网中的漏损节点和漏水量。
Args:
data: 包含管网名称(或数据库名称)、压力数据及优化参数的请求体
username: 当前认证用户名
Returns:
包含识别结果的字典
Raises:
HTTPException: 当处理过程中发生错误时
"""
try:
return run_leakage_identification(**data.model_dump(), username=username)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/",
summary="查询漏损识别方案列表",
description="获取指定网络的所有漏损识别方案"
)
async def query_leakage_schemes(
network: str = Query(..., description="管网名称(或数据库名称)"),
query_date: datetime | None = Query(None, description="查询日期(可选)")
) -> list[dict[str, Any]]:
"""
获取漏损识别方案列表。
查询指定网络的所有已配置的漏损识别方案,
可按日期进行筛选。
Args:
network: 管网名称(或数据库名称)
query_date: 查询日期(可选)
Returns:
漏损识别方案列表
Raises:
HTTPException: 当查询失败时
"""
try:
return list_leakage_identify_schemes(network=network, query_date=query_date)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@router.get(
"/schemes/{scheme_name}",
summary="获取漏损识别方案详情",
description="获取指定漏损识别方案的详细信息"
)
async def query_leakage_scheme_detail(
network: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Path(..., description="漏损识别方案名称")
) -> dict[str, Any]:
"""
获取漏损识别方案详情。
查询指定漏损识别方案的完整配置和参数信息。
Args:
network: 管网名称(或数据库名称)
scheme_name: 漏损识别方案名称
Returns:
包含方案详情的字典
Raises:
HTTPException: 当查询失败时
"""
try:
return get_leakage_identify_scheme_detail(
network=network, scheme_name=scheme_name
)
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
+41 -8
View File
@@ -1,5 +1,6 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
import psycopg
from psycopg import AsyncConnection
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
@@ -19,17 +20,22 @@ from app.domain.schemas.metadata import (
ProjectMetaResponse,
ProjectSummaryResponse,
)
from app.infra.repositories.metadata_repository import MetadataRepository
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/meta/project", response_model=ProjectMetaResponse)
@router.get("/meta/project", summary="获取项目元数据", description="获取当前项目的元数据和配置信息", response_model=ProjectMetaResponse)
async def get_project_metadata(
ctx: ProjectContext = Depends(get_project_context),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
"""
获取项目元数据
返回当前项目的完整元数据,包括项目基本信息和GeoServer配置
"""
project = await metadata_repo.get_project_by_id(ctx.project_id)
if not project:
raise HTTPException(
@@ -53,17 +59,23 @@ async def get_project_metadata(
code=project.code,
description=project.description,
gs_workspace=project.gs_workspace,
map_extent=project.map_extent,
status=project.status,
project_role=ctx.project_role,
geoserver=geoserver_payload,
)
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
@router.get("/meta/projects", summary="列出用户项目", description="获取当前用户有权限的所有项目列表", response_model=list[ProjectSummaryResponse])
async def list_user_projects(
current_user=Depends(get_current_metadata_user),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
"""
列出用户的所有项目
返回当前用户有权限访问的项目摘要列表
"""
try:
projects = await metadata_repo.list_projects_for_user(current_user.id)
except SQLAlchemyError as exc:
@@ -90,12 +102,33 @@ async def list_user_projects(
]
@router.get("/meta/db/health")
@router.get("/meta/db/health", summary="检查数据库健康状态", description="检查项目数据库连接的健康状况")
async def project_db_health(
pg_session: AsyncSession = Depends(get_project_pg_session),
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
):
await pg_session.execute(text("SELECT 1"))
async with ts_conn.cursor() as cur:
await cur.execute("SELECT 1")
"""
检查数据库健康状态
检查PostgreSQL和TimescaleDB数据库的连接状态
"""
try:
await pg_session.execute(text("SELECT 1"))
except SQLAlchemyError as exc:
logger.error("Project PostgreSQL health check failed", exc_info=True)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project PostgreSQL health check failed: {exc}",
) from exc
try:
async with ts_conn.cursor() as cur:
await cur.execute("SELECT 1")
except psycopg.Error as exc:
logger.error("Project TimescaleDB health check failed", exc_info=True)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Project TimescaleDB health check failed: {exc}",
) from exc
return {"postgres": "ok", "timescale": "ok"}
+28 -19
View File
@@ -1,6 +1,5 @@
from typing import Any
import random
from fastapi import APIRouter
from fastapi import APIRouter, Query
from fastapi.responses import JSONResponse
from fastapi import status
from pydantic import BaseModel
@@ -12,8 +11,13 @@ from app.services.tjnetwork import (
router = APIRouter()
@router.get("/getjson/")
@router.get("/getjson/", summary="获取JSON示例", description="获取JSON格式响应示例")
async def fastapi_get_json():
"""
获取JSON示例
返回示例JSON格式的响应
"""
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
@@ -24,32 +28,37 @@ async def fastapi_get_json():
)
@router.get("/getallsensorplacements/")
async def fastapi_get_all_sensor_placements(network: str) -> list[dict[Any, Any]]:
@router.get("/getallsensorplacements/", summary="获取所有传感器位置", description="获取网络中所有传感器的放置位置信息")
async def fastapi_get_all_sensor_placements(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
"""
获取所有传感器位置
返回网络中所有传感器的放置位置及其配置信息
"""
return get_all_sensor_placements(network)
@router.get("/getallburstlocateresults/")
async def fastapi_get_all_burst_locate_results(network: str) -> list[dict[Any, Any]]:
@router.get("/getallburstlocateresults/", summary="获取所有爆管定位结果", description="获取网络中所有爆管定位的分析结果")
async def fastapi_get_all_burst_locate_results(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
"""
获取所有爆管定位结果
返回网络中所有的爆管定位分析结果
"""
return get_all_burst_locate_results(network)
class Item(BaseModel):
"""测试数据模型"""
str_info: str
@router.post("/test_dict/")
@router.post("/test_dict/", summary="测试字典处理", description="测试处理字典类型数据")
async def fastapi_test_dict(data: Item) -> dict[str, str]:
"""
测试字典处理
接收Item模型,返回其字典格式
"""
item = data.dict()
return item
@router.get("/getrealtimedata/")
async def fastapi_get_realtimedata():
data = [random.randint(0, 100) for _ in range(100)]
return data
@router.get("/getsimulationresult/")
async def fastapi_get_simulationresult():
data = [random.randint(0, 100) for _ in range(100)]
return data
+98 -14
View File
@@ -1,6 +1,15 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
calculate_demand_to_network,
calculate_demand_to_nodes,
calculate_demand_to_region,
get_demand,
get_demand_schema,
set_demand,
)
router = APIRouter()
@@ -8,21 +17,54 @@ router = APIRouter()
# demand 9.[DEMANDS]
############################################################
@router.get("/getdemandschema")
async def fastapi_get_demand_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getdemandschema",
summary="获取需水量属性架构",
description="获取指定水网中需水量(Demand)的属性架构定义"
)
async def fastapi_get_demand_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""
获取需水量属性架构。
返回指定水网的需水量属性架构,包括所有可配置的属性及其类型定义。
"""
return get_demand_schema(network)
@router.get("/getdemandproperties/")
async def fastapi_get_demand_properties(network: str, junction: str) -> dict[str, Any]:
@router.get(
"/getdemandproperties/",
summary="获取需水量属性",
description="获取指定水网中节点的需水量属性信息"
)
async def fastapi_get_demand_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点ID")
) -> dict[str, Any]:
"""
获取节点的需水量属性。
返回指定节点的所有需水量信息,包括需水量值、水压等级等。
"""
return get_demand(network, junction)
# example: set_demand(p, ChangeSet({'junction': 'j1', 'demands': [{'demand': 10.0, 'pattern': None, 'category': 'x'}, {'demand': 20.0, 'pattern': None, 'category': None}]}))
@router.post("/setdemandproperties/", response_model=None)
@router.post(
"/setdemandproperties/",
response_model=None,
summary="设置需水量属性",
description="设置指定水网中节点的需水量属性信息"
)
async def fastapi_set_demand_properties(
network: str, junction: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点ID"),
req: Request = None
) -> ChangeSet:
"""
设置节点的需水量属性。
修改指定节点的需水量信息。请求体应包含需水量值、水压等级等属性。
"""
props = await req.json()
ps = {"junction": junction} | props
return set_demand(network, ChangeSet(ps))
@@ -30,26 +72,68 @@ async def fastapi_set_demand_properties(
############################################################
# water distribution 36.[Water Distribution]
############################################################
@router.get("/calculatedemandtonodes/")
@router.get(
"/calculatedemandtonodes/",
summary="计算需水量到节点分配",
description="将总需水量按指定方式分配到多个节点"
)
async def fastapi_calculate_demand_to_nodes(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> dict[str, float]:
"""
计算需水量到节点分配。
将指定的总需水量均匀或按比例分配到指定的节点列表中。
请求体格式:
{
"demand": 需水量值(float),
"nodes": 节点ID列表(list[str])
}
"""
props = await req.json()
demand = props["demand"]
nodes = props["nodes"]
return calculate_demand_to_nodes(network, demand, nodes)
@router.get("/calculatedemandtoregion/")
@router.get(
"/calculatedemandtoregion/",
summary="计算需水量到区域分配",
description="将总需水量按区域特征分配到该区域内的节点"
)
async def fastapi_calculate_demand_to_region(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> dict[str, float]:
"""
计算需水量到区域分配。
根据区域内节点的特征(如面积、人口等)将总需水量分配到该区域的各个节点。
请求体格式:
{
"demand": 需水量值(float),
"region": 区域ID(str)
}
"""
props = await req.json()
demand = props["demand"]
region = props["region"]
return calculate_demand_to_region(network, demand, region)
@router.get("/calculatedemandtonetwork/")
@router.get(
"/calculatedemandtonetwork/",
summary="计算需水量到整网分配",
description="将需水量均匀分配到整个水网的所有需水节点"
)
async def fastapi_calculate_demand_to_network(
network: str, demand: float
network: str = Query(..., description="管网名称(或数据库名称)"),
demand: float = Query(..., description="总需水量(m³/h)", gt=0)
) -> dict[str, float]:
"""
计算需水量到整网分配。
将指定的需水量均匀分配到整个水网的所有需水节点。
"""
return calculate_demand_to_network(network, demand)
+318 -60
View File
@@ -1,6 +1,42 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
delete_junction,
delete_pipe,
delete_pump,
delete_reservoir,
delete_tank,
delete_valve,
get_all_scada_info,
get_element_properties,
get_element_properties_with_type,
get_element_type,
get_element_type_value,
get_link_properties,
get_link_type,
get_links,
get_node_links,
get_node_properties,
get_node_type,
get_nodes,
get_scada_info,
get_status,
get_status_schema,
get_title,
get_title_schema,
is_junction,
is_link,
is_node,
is_pipe,
is_pump,
is_reservoir,
is_tank,
is_valve,
set_status,
set_title,
)
router = APIRouter()
@@ -8,110 +44,291 @@ router = APIRouter()
# type
############################################################
@router.get("/isnode/")
async def fastapi_is_node(network: str, node: str) -> bool:
@router.get(
"/isnode/",
summary="检查节点有效性",
description="检查指定ID是否为水网中的有效节点"
)
async def fastapi_is_node(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> bool:
"""检查指定ID是否为节点。"""
return is_node(network, node)
@router.get("/isjunction/")
async def fastapi_is_junction(network: str, node: str) -> bool:
@router.get(
"/isjunction/",
summary="检查是否为接点",
description="检查指定ID是否为水网中的接点(需求点)"
)
async def fastapi_is_junction(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> bool:
"""检查指定ID是否为接点。"""
return is_junction(network, node)
@router.get("/isreservoir/")
async def fastapi_is_reservoir(network: str, node: str) -> bool:
@router.get(
"/isreservoir/",
summary="检查是否为水源",
description="检查指定ID是否为水网中的水源(水库/河流)"
)
async def fastapi_is_reservoir(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> bool:
"""检查指定ID是否为水源。"""
return is_reservoir(network, node)
@router.get("/istank/")
async def fastapi_is_tank(network: str, node: str) -> bool:
@router.get(
"/istank/",
summary="检查是否为蓄水池",
description="检查指定ID是否为水网中的蓄水池"
)
async def fastapi_is_tank(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> bool:
"""检查指定ID是否为蓄水池。"""
return is_tank(network, node)
@router.get("/islink/")
async def fastapi_is_link(network: str, link: str) -> bool:
@router.get(
"/islink/",
summary="检查管线有效性",
description="检查指定ID是否为水网中的有效管线"
)
async def fastapi_is_link(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> bool:
"""检查指定ID是否为管线。"""
return is_link(network, link)
@router.get("/ispipe/")
async def fastapi_is_pipe(network: str, link: str) -> bool:
@router.get(
"/ispipe/",
summary="检查是否为管道",
description="检查指定ID是否为水网中的管道"
)
async def fastapi_is_pipe(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> bool:
"""检查指定ID是否为管道。"""
return is_pipe(network, link)
@router.get("/ispump/")
async def fastapi_is_pump(network: str, link: str) -> bool:
@router.get(
"/ispump/",
summary="检查是否为泵",
description="检查指定ID是否为水网中的泵"
)
async def fastapi_is_pump(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> bool:
"""检查指定ID是否为泵。"""
return is_pump(network, link)
@router.get("/isvalve/")
async def fastapi_is_valve(network: str, link: str) -> bool:
@router.get(
"/isvalve/",
summary="检查是否为阀门",
description="检查指定ID是否为水网中的阀门"
)
async def fastapi_is_valve(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> bool:
"""检查指定ID是否为阀门。"""
return is_valve(network, link)
@router.get("/getnodetype/")
async def fastapi_get_node_type(network: str, node: str) -> str:
@router.get(
"/getnodetype/",
summary="获取节点类型",
description="获取指定节点的类型(接点/水源/蓄水池)"
)
async def fastapi_get_node_type(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> str:
"""获取节点的类型标识。"""
return get_node_type(network, node)
@router.get("/getlinktype/")
async def fastapi_get_link_type(network: str, link: str) -> str:
@router.get(
"/getlinktype/",
summary="获取管线类型",
description="获取指定管线的类型(管道/泵/阀门)"
)
async def fastapi_get_link_type(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> str:
"""获取管线的类型标识。"""
return get_link_type(network, link)
@router.get("/getelementtype/")
async def fastapi_get_element_type(network: str, element: str) -> str:
@router.get(
"/getelementtype/",
summary="获取元素类型",
description="获取指定元素的类型(节点或管线)"
)
async def fastapi_get_element_type(
network: str = Query(..., description="管网名称(或数据库名称)"),
element: str = Query(..., description="元素ID")
) -> str:
"""获取元素的类型标识。"""
return get_element_type(network, element)
@router.get("/getelementtypevalue/")
async def fastapi_get_element_type_value(network: str, element: str) -> int:
@router.get(
"/getelementtypevalue/",
summary="获取元素类型值",
description="获取指定元素的类型数值标识"
)
async def fastapi_get_element_type_value(
network: str = Query(..., description="管网名称(或数据库名称)"),
element: str = Query(..., description="元素ID")
) -> int:
"""获取元素的类型数值。"""
return get_element_type_value(network, element)
@router.get("/getnodes/")
async def fastapi_get_nodes(network: str) -> list[str]:
@router.get(
"/getnodes/",
summary="获取所有节点",
description="获取指定水网中的所有节点ID列表"
)
async def fastapi_get_nodes(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
"""获取水网中所有节点的ID列表。"""
return get_nodes(network)
@router.get("/getlinks/")
async def fastapi_get_links(network: str) -> list[str]:
@router.get(
"/getlinks/",
summary="获取所有管线",
description="获取指定水网中的所有管线ID列表"
)
async def fastapi_get_links(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
"""获取水网中所有管线的ID列表。"""
return get_links(network)
@router.get("/getnodelinks/")
def get_node_links_endpoint(network: str, node: str) -> list[str]:
@router.get(
"/getnodelinks/",
summary="获取节点的关联管线",
description="获取指定节点连接的所有管线ID列表"
)
def get_node_links_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> list[str]:
"""获取节点关联的所有管线。"""
return get_node_links(network, node)
############################################################
# Node & Link properties
############################################################
@router.get("/getnodeproperties/")
async def fast_get_node_properties(network: str, node: str) -> dict[str, Any]:
@router.get(
"/getnodeproperties/",
summary="获取节点属性",
description="获取指定节点的所有属性信息"
)
async def fast_get_node_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> dict[str, Any]:
"""获取节点的完整属性信息。"""
return get_node_properties(network, node)
@router.get("/getlinkproperties/")
async def fast_get_link_properties(network: str, link: str) -> dict[str, Any]:
@router.get(
"/getlinkproperties/",
summary="获取管线属性",
description="获取指定管线的所有属性信息"
)
async def fast_get_link_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> dict[str, Any]:
"""获取管线的完整属性信息。"""
return get_link_properties(network, link)
@router.get("/getscadaproperties/")
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
@router.get(
"/getscadaproperties/",
summary="获取SCADA点属性",
description="获取指定SCADA点的属性信息"
)
async def fast_get_scada_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
scada: str = Query(..., description="SCADA点ID")
) -> dict[str, Any]:
"""获取SCADA点的属性信息。"""
return get_scada_info(network, scada)
@router.get("/getallscadaproperties/")
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
@router.get(
"/getallscadaproperties/",
summary="获取所有SCADA点属性",
description="获取指定水网中所有SCADA点的属性信息"
)
async def fast_get_all_scada_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取水网中所有SCADA点的属性列表。"""
return get_all_scada_info(network)
@router.get("/getelementpropertieswithtype/")
@router.get(
"/getelementpropertieswithtype/",
summary="获取指定类型元素属性",
description="获取指定类型的元素属性信息"
)
async def fast_get_element_properties_with_type(
network: str, elementtype: str, element: str
network: str = Query(..., description="管网名称(或数据库名称)"),
elementtype: str = Query(..., description="元素类型"),
element: str = Query(..., description="元素ID")
) -> dict[str, Any]:
"""获取指定类型元素的属性。"""
return get_element_properties_with_type(network, elementtype, element)
@router.get("/getelementproperties/")
async def fast_get_element_properties(network: str, element: str) -> dict[str, Any]:
@router.get(
"/getelementproperties/",
summary="获取元素属性",
description="获取指定元素的属性信息"
)
async def fast_get_element_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
element: str = Query(..., description="元素ID")
) -> dict[str, Any]:
"""获取元素的完整属性信息。"""
return get_element_properties(network, element)
############################################################
# title 1.[TITLE]
############################################################
@router.get("/gettitleschema/")
async def fast_get_title_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/gettitleschema/",
summary="获取标题属性架构",
description="获取指定水网的标题(标题)属性架构定义"
)
async def fast_get_title_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""获取水网标题的属性架构。"""
return get_title_schema(network)
@router.get("/gettitle/")
async def fast_get_title(network: str) -> dict[str, Any]:
@router.get(
"/gettitle/",
summary="获取水网标题属性",
description="获取指定水网的标题(Title)信息"
)
async def fast_get_title(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""获取水网的标题属性。"""
return get_title(network)
@router.get("/settitle/", response_model=None)
async def fastapi_set_title(network: str, req: Request) -> ChangeSet:
@router.get(
"/settitle/",
response_model=None,
summary="设置水网标题属性",
description="设置指定水网的标题(Title)信息"
)
async def fastapi_set_title(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置水网的标题属性。"""
props = await req.json()
return set_title(network, ChangeSet(props))
@@ -119,18 +336,41 @@ async def fastapi_set_title(network: str, req: Request) -> ChangeSet:
# status 10.[STATUS]
############################################################
@router.get("/getstatusschema")
async def fastapi_get_status_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getstatusschema",
summary="获取状态属性架构",
description="获取指定水网的状态(Status)属性架构定义"
)
async def fastapi_get_status_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""获取水网状态的属性架构。"""
return get_status_schema(network)
@router.get("/getstatus/")
async def fastapi_get_status(network: str, link: str) -> dict[str, Any]:
@router.get(
"/getstatus/",
summary="获取管线状态",
description="获取指定管线的状态信息"
)
async def fastapi_get_status(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> dict[str, Any]:
"""获取管线的状态属性。"""
return get_status(network, link)
@router.post("/setstatus/", response_model=None)
@router.post(
"/setstatus/",
response_model=None,
summary="设置管线状态",
description="设置指定管线的状态信息"
)
async def fastapi_set_status_properties(
network: str, link: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID"),
req: Request = None
) -> ChangeSet:
"""设置管线的状态属性。"""
props = await req.json()
ps = {"link": link} | props
return set_status(network, ChangeSet(ps))
@@ -139,8 +379,17 @@ async def fastapi_set_status_properties(
# General Deletion
############################################################
@router.post("/deletenode/", response_model=None)
async def fastapi_delete_node(network: str, node: str) -> ChangeSet:
@router.post(
"/deletenode/",
response_model=None,
summary="删除节点",
description="删除指定的节点(接点/水源/蓄水池)"
)
async def fastapi_delete_node(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> ChangeSet:
"""删除指定的节点。自动识别节点类型并调用相应的删除操作。"""
ps = {"id": node}
if is_junction(network, node):
return delete_junction(network, ChangeSet(ps))
@@ -150,8 +399,17 @@ async def fastapi_delete_node(network: str, node: str) -> ChangeSet:
return delete_tank(network, ChangeSet(ps))
return ChangeSet() # Should probably raise error or return empty
@router.post("/deletelink/", response_model=None)
async def fastapi_delete_link(network: str, link: str) -> ChangeSet:
@router.post(
"/deletelink/",
response_model=None,
summary="删除管线",
description="删除指定的管线(管道/泵/阀门)"
)
async def fastapi_delete_link(
network: str = Query(..., description="管网名称(或数据库名称)"),
link: str = Query(..., description="管线ID")
) -> ChangeSet:
"""删除指定的管线。自动识别管线类型并调用相应的删除操作。"""
ps = {"id": link}
if is_pipe(network, link):
return delete_pipe(network, ChangeSet(ps))
+69 -14
View File
@@ -1,6 +1,15 @@
from fastapi import APIRouter, Request, Depends
from fastapi import APIRouter, Request, Depends, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
get_all_scada_info,
get_major_node_coords,
get_major_pipe_nodes,
get_network_in_extent,
get_network_link_nodes,
get_network_node_coords,
get_node_coord,
)
from app.auth.dependencies import get_current_user as verify_token
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
import msgpack
@@ -25,19 +34,44 @@ router = APIRouter()
# props = await req.json()
# return set_coord(network, ChangeSet(props))
@router.get("/getnodecoord/")
async def fastapi_get_node_coord(network: str, node: str) -> dict[str, float] | None:
@router.get(
"/getnodecoord/",
summary="获取节点坐标",
description="获取指定节点的地理坐标(X, Y)"
)
async def fastapi_get_node_coord(
network: str = Query(..., description="管网名称(或数据库名称)"),
node: str = Query(..., description="节点ID")
) -> dict[str, float] | None:
"""获取节点的地理坐标信息。"""
return get_node_coord(network, node)
# Additional geometry queries found in main.py logic (implicit or explicit)
@router.get("/getnetworkinextent/")
@router.get(
"/getnetworkinextent/",
summary="获取范围内的网络元素",
description="获取指定地理范围内的网络节点和管线"
)
async def fastapi_get_network_in_extent(
network: str, x1: float, y1: float, x2: float, y2: float
network: str = Query(..., description="管网名称(或数据库名称)"),
x1: float = Query(..., description="范围左下角X坐标", alias="x1"),
y1: float = Query(..., description="范围左下角Y坐标", alias="y1"),
x2: float = Query(..., description="范围右上角X坐标", alias="x2"),
y2: float = Query(..., description="范围右上角Y坐标", alias="y2")
) -> dict[str, Any]:
"""获取地理范围内的网络几何信息。"""
return get_network_in_extent(network, x1, y1, x2, y2)
@router.get("/getnetworkgeometries/", dependencies=[Depends(verify_token)])
async def fastapi_get_network_geometries(network: str) -> dict[str, Any] | None:
@router.get(
"/getnetworkgeometries/",
dependencies=[Depends(verify_token)],
summary="获取完整网络几何信息",
description="获取整个水网的所有节点、管线和SCADA点的几何信息(需要身份验证)"
)
async def fastapi_get_network_geometries(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, Any] | None:
"""获取完整的网络几何信息,包括所有节点、管线和SCADA点。结果从缓存返回。"""
cache_key = f"getnetworkgeometries_{network}"
data = redis_client.get(cache_key)
if data:
@@ -55,18 +89,39 @@ async def fastapi_get_network_geometries(network: str) -> dict[str, Any] | None:
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
return results
@router.get("/getmajornodecoords/")
@router.get(
"/getmajornodecoords/",
summary="获取主要节点坐标",
description="获取直径大于等于指定值的节点坐标"
)
async def fastapi_get_majornode_coords(
network: str, diameter: int
network: str = Query(..., description="管网名称(或数据库名称)"),
diameter: int = Query(..., description="最小直径(mm)", gt=0)
) -> dict[str, dict[str, float]]:
"""获取主要节点的坐标。只返回直径大于等于指定值的节点。"""
return get_major_node_coords(network, diameter)
@router.get("/getmajorpipenodes/")
async def fastapi_get_major_pipe_nodes(network: str, diameter: int) -> list[str] | None:
@router.get(
"/getmajorpipenodes/",
summary="获取主要管道节点",
description="获取直径大于等于指定值的管道的节点ID"
)
async def fastapi_get_major_pipe_nodes(
network: str = Query(..., description="管网名称(或数据库名称)"),
diameter: int = Query(..., description="最小直径(mm)", gt=0)
) -> list[str] | None:
"""获取主要管道节点。只返回直径大于等于指定值的管道。"""
return get_major_pipe_nodes(network, diameter)
@router.get("/getnetworklinknodes/")
async def fastapi_get_network_link_nodes(network: str) -> list[str] | None:
@router.get(
"/getnetworklinknodes/",
summary="获取网络管线节点",
description="获取指定水网所有管线的起点和终点节点"
)
async def fastapi_get_network_link_nodes(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[str] | None:
"""获取网络中所有管线的连接节点。"""
return get_network_link_nodes(network)
# @router.get("/getallcoords/")
+288 -38
View File
@@ -1,111 +1,361 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_junction,
delete_junction,
get_all_junctions,
get_junction,
get_junction_schema,
set_junction,
)
router = APIRouter()
@router.get("/getjunctionschema")
async def fast_get_junction_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getjunctionschema", summary="获取节点架构", description="获取指定项目的节点属性架构和数据类型定义。")
async def fast_get_junction_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取节点架构信息。
返回指定项目的节点属性架构,包括所有属性的类型和约束信息。
Args:
network: 管网名称(或数据库名称)
"""
return get_junction_schema(network)
@router.post("/addjunction/", response_model=None)
@router.post("/addjunction/", response_model=None, summary="添加节点", description="在供水网络中添加新的节点,指定节点ID和空间坐标。")
async def fastapi_add_junction(
network: str, junction: str, x: float, y: float, z: float
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
x: float = Query(..., description="X 坐标"),
y: float = Query(..., description="Y 坐标"),
z: float = Query(..., description="标高(海拔高度)")
) -> ChangeSet:
"""
添加新节点到供水网络。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
x: X 坐标值
y: Y 坐标值
z: 标高(海拔高度)
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "x": x, "y": y, "elevation": z}
return add_junction(network, ChangeSet(ps))
@router.post("/deletejunction/", response_model=None)
async def fastapi_delete_junction(network: str, junction: str) -> ChangeSet:
@router.post("/deletejunction/", response_model=None, summary="删除节点", description="从供水网络中删除指定的节点。")
async def fastapi_delete_junction(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> ChangeSet:
"""
删除指定的节点。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction}
return delete_junction(network, ChangeSet(ps))
@router.get("/getjunctionelevation/")
async def fastapi_get_junction_elevation(network: str, junction: str) -> float:
@router.get("/getjunctionelevation/", summary="获取节点标高", description="获取指定节点的标高(海拔高度)。")
async def fastapi_get_junction_elevation(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> float:
"""
获取节点的标高值。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
float: 节点标高值
"""
ps = get_junction(network, junction)
return ps["elevation"]
@router.get("/getjunctionx/")
async def fastapi_get_junction_x(network: str, junction: str) -> float:
@router.get("/getjunctionx/", summary="获取节点 X 坐标", description="获取指定节点的 X 坐标值。")
async def fastapi_get_junction_x(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> float:
"""
获取节点的 X 坐标。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
float: 节点 X 坐标值
"""
ps = get_junction(network, junction)
return ps["x"]
@router.get("/getjunctiony/")
async def fastapi_get_junction_y(network: str, junction: str) -> float:
@router.get("/getjunctiony/", summary="获取节点 Y 坐标", description="获取指定节点的 Y 坐标值。")
async def fastapi_get_junction_y(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> float:
"""
获取节点的 Y 坐标。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
float: 节点 Y 坐标值
"""
ps = get_junction(network, junction)
return ps["y"]
@router.get("/getjunctioncoord/")
async def fastapi_get_junction_coord(network: str, junction: str) -> dict[str, float]:
@router.get("/getjunctioncoord/", summary="获取节点坐标", description="获取指定节点的 X 和 Y 坐标。")
async def fastapi_get_junction_coord(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> dict[str, float]:
"""
获取节点的坐标信息。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
dict: 包含 x 和 y 坐标的字典
"""
ps = get_junction(network, junction)
coord = {"x": ps["x"], "y": ps["y"]}
return coord
@router.get("/getjunctiondemand/")
async def fastapi_get_junction_demand(network: str, junction: str) -> float:
@router.get("/getjunctiondemand/", summary="获取节点需水量", description="获取指定节点的需水量。")
async def fastapi_get_junction_demand(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> float:
"""
获取节点的需水量。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
float: 节点的需水量值
"""
ps = get_junction(network, junction)
return ps["demand"]
@router.get("/getjunctionpattern/")
async def fastapi_get_junction_pattern(network: str, junction: str) -> str:
@router.get("/getjunctionpattern/", summary="获取节点需水模式", description="获取指定节点的需水模式标识。")
async def fastapi_get_junction_pattern(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> str:
"""
获取节点的需水模式。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
str: 节点的需水模式标识
"""
ps = get_junction(network, junction)
return ps["pattern"]
@router.post("/setjunctionelevation/", response_model=None)
@router.post("/setjunctionelevation/", response_model=None, summary="设置节点标高", description="设置指定节点的标高值。")
async def fastapi_set_junction_elevation(
network: str, junction: str, elevation: float
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
elevation: float = Query(..., description="标高(海拔高度)")
) -> ChangeSet:
"""
设置节点的标高。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
elevation: 标高值
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "elevation": elevation}
return set_junction(network, ChangeSet(ps))
@router.post("/setjunctionx/", response_model=None)
async def fastapi_set_junction_x(network: str, junction: str, x: float) -> ChangeSet:
@router.post("/setjunctionx/", response_model=None, summary="设置节点 X 坐标", description="设置指定节点的 X 坐标值。")
async def fastapi_set_junction_x(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
x: float = Query(..., description="X 坐标值")
) -> ChangeSet:
"""
设置节点的 X 坐标。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
x: X 坐标值
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "x": x}
return set_junction(network, ChangeSet(ps))
@router.post("/setjunctiony/", response_model=None)
async def fastapi_set_junction_y(network: str, junction: str, y: float) -> ChangeSet:
@router.post("/setjunctiony/", response_model=None, summary="设置节点 Y 坐标", description="设置指定节点的 Y 坐标值。")
async def fastapi_set_junction_y(
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
y: float = Query(..., description="Y 坐标值")
) -> ChangeSet:
"""
设置节点的 Y 坐标。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
y: Y 坐标值
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "y": y}
return set_junction(network, ChangeSet(ps))
@router.post("/setjunctioncoord/", response_model=None)
@router.post("/setjunctioncoord/", response_model=None, summary="设置节点坐标", description="设置指定节点的 X 和 Y 坐标。")
async def fastapi_set_junction_coord(
network: str, junction: str, x: float, y: float
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
x: float = Query(..., description="X 坐标值"),
y: float = Query(..., description="Y 坐标值")
) -> ChangeSet:
"""
设置节点的坐标。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
x: X 坐标值
y: Y 坐标值
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "x": x, "y": y}
return set_junction(network, ChangeSet(ps))
@router.post("/setjunctiondemand/", response_model=None)
@router.post("/setjunctiondemand/", response_model=None, summary="设置节点需水量", description="设置指定节点的需水量。")
async def fastapi_set_junction_demand(
network: str, junction: str, demand: float
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
demand: float = Query(..., description="需水量值")
) -> ChangeSet:
"""
设置节点的需水量。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
demand: 需水量值
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "demand": demand}
return set_junction(network, ChangeSet(ps))
@router.post("/setjunctionpattern/", response_model=None)
@router.post("/setjunctionpattern/", response_model=None, summary="设置节点需水模式", description="设置指定节点的需水模式标识。")
async def fastapi_set_junction_pattern(
network: str, junction: str, pattern: str
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
pattern: str = Query(..., description="需水模式标识")
) -> ChangeSet:
"""
设置节点的需水模式。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
pattern: 需水模式标识
Returns:
ChangeSet: 包含变更信息的结果
"""
ps = {"id": junction, "pattern": pattern}
return set_junction(network, ChangeSet(ps))
@router.get("/getjunctionproperties/")
@router.get("/getjunctionproperties/", summary="获取节点属性", description="获取指定节点的所有属性信息。")
async def fastapi_get_junction_properties(
network: str, junction: str
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID")
) -> dict[str, Any]:
"""
获取节点的完整属性信息。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
Returns:
dict: 包含节点所有属性的字典
"""
return get_junction(network, junction)
@router.get("/getalljunctionproperties/")
async def fastapi_get_all_junction_properties(network: str) -> list[dict[str, Any]]:
@router.get("/getalljunctionproperties/", summary="获取所有节点属性", description="获取指定项目中所有节点的属性信息。")
async def fastapi_get_all_junction_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取所有节点的属性信息列表。
此端点返回指定项目中所有节点的详细属性。缓存查询结果以提高性能。
Args:
network: 管网名称(或数据库名称)
Returns:
list: 包含所有节点属性的列表
"""
# 缓存查询结果提高性能
# global redis_client # Redis logic removed for clean split, can be re-added if needed or imported
results = get_all_junctions(network)
return results
@router.post("/setjunctionproperties/", response_model=None)
@router.post("/setjunctionproperties/", response_model=None, summary="批量设置节点属性", description="批量设置指定节点的多个属性。")
async def fastapi_set_junction_properties(
network: str, junction: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
junction: str = Query(..., description="节点 ID"),
req: Request = None
) -> ChangeSet:
"""
批量设置节点属性。
允许一次性设置节点的多个属性,如坐标、标高、需水量等。
Args:
network: 管网名称(或数据库名称)
junction: 节点 ID
req: 包含属性和值的 JSON 请求体
Returns:
ChangeSet: 包含变更信息的结果
"""
props = await req.json()
ps = {"id": junction} | props
return set_junction(network, ChangeSet(ps))
+328 -50
View File
@@ -1,25 +1,63 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
PIPE_STATUS_OPEN,
add_pipe,
delete_pipe,
get_all_pipes,
get_pipe,
get_pipe_schema,
set_pipe,
)
router = APIRouter()
@router.get("/getpipeschema")
async def fastapi_get_pipe_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getpipeschema", summary="获取管道模式", description="获取管道对象的模式定义,包含所有可用字段及其类型")
async def fastapi_get_pipe_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取管道数据模式定义。
Args:
network: 管网名称(或数据库名称)
Returns:
包含管道模式信息的字典
"""
return get_pipe_schema(network)
@router.post("/addpipe/", response_model=None)
@router.post("/addpipe/", response_model=None, summary="添加管道", description="向网络中添加新的管道,需要提供管道的基本参数如长度、管径、粗糙度等")
async def fastapi_add_pipe(
network: str,
pipe: str,
node1: str,
node2: str,
length: float = 0,
diameter: float = 0,
roughness: float = 0,
minor_loss: float = 0,
status: str = PIPE_STATUS_OPEN,
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道标识符"),
node1: str = Query(..., description="管道起始节点ID"),
node2: str = Query(..., description="管道终止节点ID"),
length: float = Query(0, description="管道长度(单位:米)"),
diameter: float = Query(0, description="管道管径(单位:毫米)"),
roughness: float = Query(0, description="管道粗糙度"),
minor_loss: float = Query(0, description="管道局部阻力系数"),
status: str = Query(PIPE_STATUS_OPEN, description="管道状态(开启/关闭)"),
) -> ChangeSet:
"""
添加新管道到网络。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
node1: 起始节点ID
node2: 终止节点ID
length: 管道长度
diameter: 管道管径
roughness: 管道粗糙度
minor_loss: 局部阻力系数
status: 管道状态
Returns:
ChangeSet对象,包含本次操作的变更信息
"""
ps = {
"id": pipe,
"node1": node1,
@@ -32,102 +70,342 @@ async def fastapi_add_pipe(
}
return add_pipe(network, ChangeSet(ps))
@router.post("/deletepipe/", response_model=None)
async def fastapi_delete_pipe(network: str, pipe: str) -> ChangeSet:
@router.post("/deletepipe/", response_model=None, summary="删除管道", description="从网络中删除指定的管道")
async def fastapi_delete_pipe(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="要删除的管道ID")
) -> ChangeSet:
"""
删除管道。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
ChangeSet对象,包含本次删除操作的变更信息
"""
ps = {"id": pipe}
return delete_pipe(network, ChangeSet(ps))
@router.get("/getpipenode1/")
async def fastapi_get_pipe_node1(network: str, pipe: str) -> str | None:
@router.get("/getpipenode1/", summary="获取管道起始节点", description="获取指定管道的起始节点ID")
async def fastapi_get_pipe_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> str | None:
"""
获取管道的起始节点。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
起始节点ID,如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["node1"]
@router.get("/getpipenode2/")
async def fastapi_get_pipe_node2(network: str, pipe: str) -> str | None:
@router.get("/getpipenode2/", summary="获取管道终止节点", description="获取指定管道的终止节点ID")
async def fastapi_get_pipe_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> str | None:
"""
获取管道的终止节点。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
终止节点ID,如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["node2"]
@router.get("/getpipelength/")
async def fastapi_get_pipe_length(network: str, pipe: str) -> float | None:
@router.get("/getpipelength/", summary="获取管道长度", description="获取指定管道的长度")
async def fastapi_get_pipe_length(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> float | None:
"""
获取管道长度。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
管道长度(单位:米),如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["length"]
@router.get("/getpipediameter/")
async def fastapi_get_pipe_diameter(network: str, pipe: str) -> float | None:
@router.get("/getpipediameter/", summary="获取管道管径", description="获取指定管道的管径")
async def fastapi_get_pipe_diameter(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> float | None:
"""
获取管道管径。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
管道管径(单位:毫米),如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["diameter"]
@router.get("/getpiperoughness/")
async def fastapi_get_pipe_roughness(network: str, pipe: str) -> float | None:
@router.get("/getpiperoughness/", summary="获取管道粗糙度", description="获取指定管道的粗糙度")
async def fastapi_get_pipe_roughness(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> float | None:
"""
获取管道粗糙度。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
管道粗糙度值,如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["roughness"]
@router.get("/getpipeminorloss/")
async def fastapi_get_pipe_minor_loss(network: str, pipe: str) -> float | None:
@router.get("/getpipeminorloss/", summary="获取管道局部阻力系数", description="获取指定管道的局部阻力系数")
async def fastapi_get_pipe_minor_loss(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> float | None:
"""
获取管道局部阻力系数。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
局部阻力系数,如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["minor_loss"]
@router.get("/getpipestatus/")
async def fastapi_get_pipe_status(network: str, pipe: str) -> str | None:
@router.get("/getpipestatus/", summary="获取管道状态", description="获取指定管道的状态(开启或关闭)")
async def fastapi_get_pipe_status(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> str | None:
"""
获取管道状态。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
管道状态(开启/关闭),如果不存在则返回None
"""
ps = get_pipe(network, pipe)
return ps["status"]
@router.post("/setpipenode1/", response_model=None)
async def fastapi_set_pipe_node1(network: str, pipe: str, node1: str) -> ChangeSet:
@router.post("/setpipenode1/", response_model=None, summary="设置管道起始节点", description="设置指定管道的起始节点")
async def fastapi_set_pipe_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
node1: str = Query(..., description="新的起始节点ID")
) -> ChangeSet:
"""
设置管道起始节点。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
node1: 新的起始节点ID
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "node1": node1}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpipenode2/", response_model=None)
async def fastapi_set_pipe_node2(network: str, pipe: str, node2: str) -> ChangeSet:
@router.post("/setpipenode2/", response_model=None, summary="设置管道终止节点", description="设置指定管道的终止节点")
async def fastapi_set_pipe_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
node2: str = Query(..., description="新的终止节点ID")
) -> ChangeSet:
"""
设置管道终止节点。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
node2: 新的终止节点ID
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "node2": node2}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpipelength/", response_model=None)
async def fastapi_set_pipe_length(network: str, pipe: str, length: float) -> ChangeSet:
@router.post("/setpipelength/", response_model=None, summary="设置管道长度", description="设置指定管道的长度")
async def fastapi_set_pipe_length(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
length: float = Query(..., description="新的管道长度(单位:米)")
) -> ChangeSet:
"""
设置管道长度。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
length: 新的管道长度
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "length": length}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpipediameter/", response_model=None)
@router.post("/setpipediameter/", response_model=None, summary="设置管道管径", description="设置指定管道的管径")
async def fastapi_set_pipe_diameter(
network: str, pipe: str, diameter: float
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
diameter: float = Query(..., description="新的管道管径(单位:毫米)")
) -> ChangeSet:
"""
设置管道管径。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
diameter: 新的管道管径
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "diameter": diameter}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpiperoughness/", response_model=None)
@router.post("/setpiperoughness/", response_model=None, summary="设置管道粗糙度", description="设置指定管道的粗糙度")
async def fastapi_set_pipe_roughness(
network: str, pipe: str, roughness: float
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
roughness: float = Query(..., description="新的管道粗糙度值")
) -> ChangeSet:
"""
设置管道粗糙度。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
roughness: 新的管道粗糙度
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "roughness": roughness}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpipeminorloss/", response_model=None)
@router.post("/setpipeminorloss/", response_model=None, summary="设置管道局部阻力系数", description="设置指定管道的局部阻力系数")
async def fastapi_set_pipe_minor_loss(
network: str, pipe: str, minor_loss: float
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
minor_loss: float = Query(..., description="新的局部阻力系数值")
) -> ChangeSet:
"""
设置管道局部阻力系数。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
minor_loss: 新的局部阻力系数
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "minor_loss": minor_loss}
return set_pipe(network, ChangeSet(ps))
@router.post("/setpipestatus/", response_model=None)
async def fastapi_set_pipe_status(network: str, pipe: str, status: str) -> ChangeSet:
@router.post("/setpipestatus/", response_model=None, summary="设置管道状态", description="设置指定管道的状态(开启或关闭)")
async def fastapi_set_pipe_status(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
status: str = Query(..., description="新的管道状态(开启/关闭)")
) -> ChangeSet:
"""
设置管道状态。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
status: 新的管道状态
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pipe, "status": status}
return set_pipe(network, ChangeSet(ps))
@router.get("/getpipeproperties/")
async def fastapi_get_pipe_properties(network: str, pipe: str) -> dict[str, Any]:
@router.get("/getpipeproperties/", summary="获取管道属性", description="获取指定管道的所有属性信息")
async def fastapi_get_pipe_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID")
) -> dict[str, Any]:
"""
获取管道的所有属性。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
Returns:
包含管道所有属性的字典
"""
return get_pipe(network, pipe)
@router.get("/getallpipeproperties/")
async def fastapi_get_all_pipe_properties(network: str) -> list[dict[str, Any]]:
@router.get("/getallpipeproperties/", summary="获取所有管道属性", description="获取网络中所有管道的属性信息列表")
async def fastapi_get_all_pipe_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取网络中所有管道的属性。
Args:
network: 管网名称(或数据库名称)
Returns:
包含所有管道属性的字典列表
"""
# 缓存查询结果提高性能
# global redis_client
results = get_all_pipes(network)
return results
@router.post("/setpipeproperties/", response_model=None)
@router.post("/setpipeproperties/", response_model=None, summary="设置管道属性", description="批量设置指定管道的多个属性")
async def fastapi_set_pipe_properties(
network: str, pipe: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe: str = Query(..., description="管道ID"),
req: Request = None
) -> ChangeSet:
"""
批量设置管道属性。
Args:
network: 管网名称(或数据库名称)
pipe: 管道ID
req: 请求体,包含要设置的属性及其值
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
props = await req.json()
ps = {"id": pipe} | props
return set_pipe(network, ChangeSet(ps))
+165 -22
View File
@@ -1,60 +1,203 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_pump,
delete_pump,
get_all_pumps,
get_pump,
get_pump_schema,
set_pump,
)
router = APIRouter()
@router.get("/getpumpschema")
async def fastapi_get_pump_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getpumpschema", summary="获取水泵模式", description="获取水泵对象的模式定义,包含所有可用字段及其类型")
async def fastapi_get_pump_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取水泵数据模式定义。
Args:
network: 管网名称(或数据库名称)
Returns:
包含水泵模式信息的字典
"""
return get_pump_schema(network)
@router.post("/addpump/", response_model=None)
@router.post("/addpump/", response_model=None, summary="添加水泵", description="向网络中添加新的水泵,需要提供水泵的基本参数如功率等")
async def fastapi_add_pump(
network: str, pump: str, node1: str, node2: str, power: float = 0.0
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵标识符"),
node1: str = Query(..., description="水泵起始节点ID"),
node2: str = Query(..., description="水泵终止节点ID"),
power: float = Query(0.0, description="水泵功率(单位:千瓦)")
) -> ChangeSet:
"""
添加新水泵到网络。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
node1: 起始节点ID
node2: 终止节点ID
power: 水泵功率
Returns:
ChangeSet对象,包含本次操作的变更信息
"""
ps = {"id": pump, "node1": node1, "node2": node2, "power": power}
return add_pump(network, ChangeSet(ps))
@router.post("/deletepump/", response_model=None)
async def fastapi_delete_pump(network: str, pump: str) -> ChangeSet:
@router.post("/deletepump/", response_model=None, summary="删除水泵", description="从网络中删除指定的水泵")
async def fastapi_delete_pump(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="要删除的水泵ID")
) -> ChangeSet:
"""
删除水泵。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
Returns:
ChangeSet对象,包含本次删除操作的变更信息
"""
ps = {"id": pump}
return delete_pump(network, ChangeSet(ps))
@router.get("/getpumpnode1/")
async def fastapi_get_pump_node1(network: str, pump: str) -> str | None:
@router.get("/getpumpnode1/", summary="获取水泵起始节点", description="获取指定水泵的起始节点ID")
async def fastapi_get_pump_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID")
) -> str | None:
"""
获取水泵的起始节点。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
Returns:
起始节点ID,如果不存在则返回None
"""
ps = get_pump(network, pump)
return ps["node1"]
@router.get("/getpumpnode2/")
async def fastapi_get_pump_node2(network: str, pump: str) -> str | None:
@router.get("/getpumpnode2/", summary="获取水泵终止节点", description="获取指定水泵的终止节点ID")
async def fastapi_get_pump_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID")
) -> str | None:
"""
获取水泵的终止节点。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
Returns:
终止节点ID,如果不存在则返回None
"""
ps = get_pump(network, pump)
return ps["node2"]
@router.post("/setpumpnode1/", response_model=None)
async def fastapi_set_pump_node1(network: str, pump: str, node1: str) -> ChangeSet:
@router.post("/setpumpnode1/", response_model=None, summary="设置水泵起始节点", description="设置指定水泵的起始节点")
async def fastapi_set_pump_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID"),
node1: str = Query(..., description="新的起始节点ID")
) -> ChangeSet:
"""
设置水泵起始节点。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
node1: 新的起始节点ID
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pump, "node1": node1}
return set_pump(network, ChangeSet(ps))
@router.post("/setpumpnode2/", response_model=None)
async def fastapi_set_pump_node2(network: str, pump: str, node2: str) -> ChangeSet:
@router.post("/setpumpnode2/", response_model=None, summary="设置水泵终止节点", description="设置指定水泵的终止节点")
async def fastapi_set_pump_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID"),
node2: str = Query(..., description="新的终止节点ID")
) -> ChangeSet:
"""
设置水泵终止节点。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
node2: 新的终止节点ID
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
ps = {"id": pump, "node2": node2}
return set_pump(network, ChangeSet(ps))
@router.get("/getpumpproperties/")
async def fastapi_get_pump_properties(network: str, pump: str) -> dict[str, Any]:
@router.get("/getpumpproperties/", summary="获取水泵属性", description="获取指定水泵的所有属性信息")
async def fastapi_get_pump_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID")
) -> dict[str, Any]:
"""
获取水泵的所有属性。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
Returns:
包含水泵所有属性的字典
"""
return get_pump(network, pump)
@router.get("/getallpumpproperties/")
async def fastapi_get_all_pump_properties(network: str) -> list[dict[str, Any]]:
@router.get("/getallpumpproperties/", summary="获取所有水泵属性", description="获取网络中所有水泵的属性信息列表")
async def fastapi_get_all_pump_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取网络中所有水泵的属性。
Args:
network: 管网名称(或数据库名称)
Returns:
包含所有水泵属性的字典列表
"""
# 缓存查询结果提高性能
# global redis_client
results = get_all_pumps(network)
return results
@router.post("/setpumpproperties/", response_model=None)
@router.post("/setpumpproperties/", response_model=None, summary="设置水泵属性", description="批量设置指定水泵的多个属性")
async def fastapi_set_pump_properties(
network: str, pump: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
pump: str = Query(..., description="水泵ID"),
req: Request = None
) -> ChangeSet:
"""
批量设置水泵属性。
Args:
network: 管网名称(或数据库名称)
pump: 水泵ID
req: 请求体,包含要设置的属性及其值
Returns:
ChangeSet对象,包含本次修改的变更信息
"""
props = await req.json()
ps = {"id": pump} | props
return set_pump(network, ChangeSet(ps))
+385 -96
View File
@@ -1,6 +1,42 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_district_metering_area,
add_region,
add_service_area,
add_virtual_district,
calculate_district_metering_area_for_network,
calculate_district_metering_area_for_nodes,
calculate_district_metering_area_for_region,
calculate_service_area,
calculate_virtual_district,
delete_district_metering_area,
delete_region,
delete_service_area,
delete_virtual_district,
generate_district_metering_area,
generate_service_area,
generate_sub_district_metering_area,
generate_virtual_district,
get_all_district_metering_area_ids,
get_all_district_metering_areas,
get_all_service_areas,
get_all_virtual_districts,
get_district_metering_area,
get_district_metering_area_schema,
get_region,
get_region_schema,
get_service_area,
get_service_area_schema,
get_virtual_district,
get_virtual_district_schema,
set_district_metering_area,
set_region,
set_service_area,
set_virtual_district,
)
router = APIRouter()
@@ -8,64 +44,95 @@ router = APIRouter()
# region 32
############################################################
@router.get("/calculateregion/")
async def fastapi_calculate_region(network: str, time_index: int) -> dict[str, Any]:
return calculate_region(network, time_index)
@router.get("/getregionschema/")
async def fastapi_get_region_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getregionschema/",
summary="获取区域属性架构",
description="获取指定水网的区域属性架构定义"
)
async def fastapi_get_region_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""获取区域的属性架构。"""
return get_region_schema(network)
@router.get("/getregion/")
async def fastapi_get_region(network: str, id: str) -> dict[str, Any]:
@router.get(
"/getregion/",
summary="获取区域信息",
description="获取指定ID的区域详细信息"
)
async def fastapi_get_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="区域ID")
) -> dict[str, Any]:
"""获取区域的详细信息。"""
return get_region(network, id)
@router.post("/setregion/", response_model=None)
async def fastapi_set_region(network: str, req: Request) -> ChangeSet:
@router.post(
"/setregion/",
response_model=None,
summary="设置区域属性",
description="修改指定区域的属性信息"
)
async def fastapi_set_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置区域属性。"""
props = await req.json()
return set_region(network, ChangeSet(props))
@router.post("/addregion/", response_model=None)
async def fastapi_add_region(network: str, req: Request) -> ChangeSet:
@router.post(
"/addregion/",
response_model=None,
summary="添加新区域",
description="向水网添加一个新的区域"
)
async def fastapi_add_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加新的区域。"""
props = await req.json()
return add_region(network, ChangeSet(props))
@router.post("/deleteregion/", response_model=None)
async def fastapi_delete_region(network: str, req: Request) -> ChangeSet:
@router.post(
"/deleteregion/",
response_model=None,
summary="删除区域",
description="删除指定的区域"
)
async def fastapi_delete_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除区域。"""
props = await req.json()
return delete_region(network, ChangeSet(props))
@router.get("/getallregions/")
async def fastapi_get_all_regions(network: str) -> list[dict[str, Any]]:
return get_all_regions(network)
@router.post("/generateregion/", response_model=None)
async def fastapi_generate_region(
network: str, inflate_delta: float
) -> ChangeSet:
return generate_region(network, inflate_delta)
############################################################
# district_metering_area 33
############################################################
@router.get("/calculatedistrictmeteringarea/")
async def fastapi_calculate_district_metering_area(
network: str, req: Request
) -> list[list[str]]:
props = await req.json()
nodes = props["nodes"]
part_count = props["part_count"]
part_type = props["part_type"]
return calculate_district_metering_area(
network, nodes, part_count, part_type
)
@router.get("/calculatedistrictmeteringareaforregion/")
@router.get(
"/calculatedistrictmeteringareaforregion/",
summary="计算区域内DMA分区",
description="为指定区域计算区域计量(DMA)分区方案"
)
async def fastapi_calculate_district_metering_area_for_region(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> list[list[str]]:
"""
计算区域内DMA分区。
请求体格式:
{
"region": 区域ID(str),
"part_count": 分区数量(int),
"part_type": 分区类型(int)
}
"""
props = await req.json()
region = props["region"]
part_count = props["part_count"]
@@ -74,32 +141,77 @@ async def fastapi_calculate_district_metering_area_for_region(
network, region, part_count, part_type
)
@router.get("/calculatedistrictmeteringareafornetwork/")
@router.get(
"/calculatedistrictmeteringareafornetwork/",
summary="计算整网DMA分区",
description="为整个水网计算区域计量(DMA)分区方案"
)
async def fastapi_calculate_district_metering_area_for_network(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> list[list[str]]:
"""
计算整网DMA分区。
请求体格式:
{
"part_count": 分区数量(int),
"part_type": 分区类型(int)
}
"""
props = await req.json()
part_count = props["part_count"]
part_type = props["part_type"]
return calculate_district_metering_area_for_network(network, part_count, part_type)
@router.get("/getdistrictmeteringareaschema/")
@router.get(
"/getdistrictmeteringareaschema/",
summary="获取DMA属性架构",
description="获取指定水网的区域计量(DMA)属性架构定义"
)
async def fastapi_get_district_metering_area_schema(
network: str,
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> dict[str, dict[str, Any]]:
"""获取DMA的属性架构。"""
return get_district_metering_area_schema(network)
@router.get("/getdistrictmeteringarea/")
async def fastapi_get_district_metering_area(network: str, id: str) -> dict[str, Any]:
@router.get(
"/getdistrictmeteringarea/",
summary="获取DMA信息",
description="获取指定ID的区域计量(DMA)详细信息"
)
async def fastapi_get_district_metering_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="DMA ID")
) -> dict[str, Any]:
"""获取DMA的详细信息。"""
return get_district_metering_area(network, id)
@router.post("/setdistrictmeteringarea/", response_model=None)
async def fastapi_set_district_metering_area(network: str, req: Request) -> ChangeSet:
@router.post(
"/setdistrictmeteringarea/",
response_model=None,
summary="设置DMA属性",
description="修改指定DMA的属性信息"
)
async def fastapi_set_district_metering_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置DMA属性。"""
props = await req.json()
return set_district_metering_area(network, ChangeSet(props))
@router.post("/adddistrictmeteringarea/", response_model=None)
async def fastapi_add_district_metering_area(network: str, req: Request) -> ChangeSet:
@router.post(
"/adddistrictmeteringarea/",
response_model=None,
summary="添加新DMA",
description="向水网添加一个新的区域计量(DMA)"
)
async def fastapi_add_district_metering_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加新的DMA。"""
props = await req.json()
# boundary should be [(x,y), (x,y)]
boundary = props.get("boundary", [])
@@ -110,33 +222,73 @@ async def fastapi_add_district_metering_area(network: str, req: Request) -> Chan
props["boundary"] = newBoundary
return add_district_metering_area(network, ChangeSet(props))
@router.post("/deletedistrictmeteringarea/", response_model=None)
@router.post(
"/deletedistrictmeteringarea/",
response_model=None,
summary="删除DMA",
description="删除指定的区域计量(DMA)"
)
async def fastapi_delete_district_metering_area(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除DMA。"""
props = await req.json()
return delete_district_metering_area(network, ChangeSet(props))
@router.get("/getalldistrictmeteringareaids/")
async def fastapi_get_all_district_metering_area_ids(network: str) -> list[str]:
@router.get(
"/getalldistrictmeteringareaids/",
summary="获取所有DMA ID",
description="获取指定水网中所有DMA的ID列表"
)
async def fastapi_get_all_district_metering_area_ids(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[str]:
"""获取所有DMA的ID列表。"""
return get_all_district_metering_area_ids(network)
@router.get("/getalldistrictmeteringareas/")
async def getalldistrictmeteringareas(network: str) -> list[dict[str, Any]]:
@router.get(
"/getalldistrictmeteringareas/",
summary="获取所有DMA",
description="获取指定水网中所有DMA的详细信息"
)
async def getalldistrictmeteringareas(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取所有DMA的详细信息列表。"""
return get_all_district_metering_areas(network)
@router.post("/generatedistrictmeteringarea/", response_model=None)
@router.post(
"/generatedistrictmeteringarea/",
response_model=None,
summary="生成DMA分区",
description="根据参数自动生成水网的DMA分区方案"
)
async def fastapi_generate_district_metering_area(
network: str, part_count: int, part_type: int, inflate_delta: float
network: str = Query(..., description="管网名称(或数据库名称)"),
part_count: int = Query(..., description="分区数量", gt=0),
part_type: int = Query(..., description="分区类型"),
inflate_delta: float = Query(..., description="膨胀参数")
) -> ChangeSet:
"""生成DMA分区。"""
return generate_district_metering_area(
network, part_count, part_type, inflate_delta
)
@router.post("/generatesubdistrictmeteringarea/", response_model=None)
@router.post(
"/generatesubdistrictmeteringarea/",
response_model=None,
summary="生成DMA子分区",
description="为指定DMA生成子DMA分区"
)
async def fastapi_generate_sub_district_metering_area(
network: str, dma: str, part_count: int, part_type: int, inflate_delta: float
network: str = Query(..., description="管网名称(或数据库名称)"),
dma: str = Query(..., description="DMA ID"),
part_count: int = Query(..., description="分区数量", gt=0),
part_type: int = Query(..., description="分区类型"),
inflate_delta: float = Query(..., description="膨胀参数")
) -> ChangeSet:
"""生成DMA子分区。"""
return generate_sub_district_metering_area(
network, dma, part_count, part_type, inflate_delta
)
@@ -146,43 +298,104 @@ async def fastapi_generate_sub_district_metering_area(
# service_area 34
############################################################
@router.get("/calculateservicearea/")
@router.get(
"/calculateservicearea/",
summary="计算服务区",
description="计算指定水网的服务区分区,返回全部时间步结果"
)
async def fastapi_calculate_service_area(
network: str, time_index: int
) -> dict[str, Any]:
return calculate_service_area(network, time_index)
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> list[dict[str, list[str]]]:
"""计算服务区分区,返回全部时间步结果。"""
return calculate_service_area(network)
@router.get("/getserviceareaschema/")
async def fastapi_get_service_area_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getserviceareaschema/",
summary="获取服务区属性架构",
description="获取指定水网的服务区属性架构定义"
)
async def fastapi_get_service_area_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""获取服务区的属性架构。"""
return get_service_area_schema(network)
@router.get("/getservicearea/")
async def fastapi_get_service_area(network: str, id: str) -> dict[str, Any]:
@router.get(
"/getservicearea/",
summary="获取服务区信息",
description="获取指定ID的服务区详细信息"
)
async def fastapi_get_service_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="服务区ID")
) -> dict[str, Any]:
"""获取服务区的详细信息。"""
return get_service_area(network, id)
@router.post("/setservicearea/", response_model=None)
async def fastapi_set_service_area(network: str, req: Request) -> ChangeSet:
@router.post(
"/setservicearea/",
response_model=None,
summary="设置服务区属性",
description="修改指定服务区的属性信息"
)
async def fastapi_set_service_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置服务区属性。"""
props = await req.json()
return set_service_area(network, ChangeSet(props))
@router.post("/addservicearea/", response_model=None)
async def fastapi_add_service_area(network: str, req: Request) -> ChangeSet:
@router.post(
"/addservicearea/",
response_model=None,
summary="添加新服务区",
description="向水网添加一个新的服务区"
)
async def fastapi_add_service_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加新的服务区。"""
props = await req.json()
return add_service_area(network, ChangeSet(props))
@router.post("/deleteservicearea/", response_model=None)
async def fastapi_delete_service_area(network: str, req: Request) -> ChangeSet:
@router.post(
"/deleteservicearea/",
response_model=None,
summary="删除服务区",
description="删除指定的服务区"
)
async def fastapi_delete_service_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除服务区。"""
props = await req.json()
return delete_service_area(network, ChangeSet(props))
@router.get("/getallserviceareas/")
async def fastapi_get_all_service_areas(network: str) -> list[dict[str, Any]]:
@router.get(
"/getallserviceareas/",
summary="获取所有服务区",
description="获取指定水网中的所有服务区信息"
)
async def fastapi_get_all_service_areas(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取所有服务区的信息列表。"""
return get_all_service_areas(network)
@router.post("/generateservicearea/", response_model=None)
@router.post(
"/generateservicearea/",
response_model=None,
summary="生成服务区分区",
description="根据参数自动生成水网的服务区分区"
)
async def fastapi_generate_service_area(
network: str, inflate_delta: float
network: str = Query(..., description="管网名称(或数据库名称)"),
inflate_delta: float = Query(..., description="膨胀参数")
) -> ChangeSet:
"""生成服务区分区。"""
return generate_service_area(network, inflate_delta)
@@ -190,52 +403,128 @@ async def fastapi_generate_service_area(
# virtual_district 35
############################################################
@router.get("/calculatevirtualdistrict/")
@router.get(
"/calculatevirtualdistrict/",
summary="计算虚拟分区",
description="根据指定的压力监测节点作为中心节点计算虚拟分区方案"
)
async def fastapi_calculate_virtual_district(
network: str, centers: list[str]
network: str = Query(..., description="管网名称(或数据库名称)"),
centers: list[str] = Query(..., description="压力监测节点ID列表")
) -> dict[str, list[Any]]:
"""计算虚拟分区。"""
return calculate_virtual_district(network, centers)
@router.get("/getvirtualdistrictschema/")
@router.get(
"/getvirtualdistrictschema/",
summary="获取虚拟分区属性架构",
description="获取指定水网的虚拟分区属性架构定义"
)
async def fastapi_get_virtual_district_schema(
network: str,
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> dict[str, dict[str, Any]]:
"""获取虚拟分区的属性架构。"""
return get_virtual_district_schema(network)
@router.get("/getvirtualdistrict/")
async def fastapi_get_virtual_district(network: str, id: str) -> dict[str, Any]:
@router.get(
"/getvirtualdistrict/",
summary="获取虚拟分区信息",
description="获取指定ID的虚拟分区详细信息"
)
async def fastapi_get_virtual_district(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="虚拟分区ID")
) -> dict[str, Any]:
"""获取虚拟分区的详细信息。"""
return get_virtual_district(network, id)
@router.post("/setvirtualdistrict/", response_model=None)
async def fastapi_set_virtual_district(network: str, req: Request) -> ChangeSet:
@router.post(
"/setvirtualdistrict/",
response_model=None,
summary="设置虚拟分区属性",
description="修改指定虚拟分区的属性信息"
)
async def fastapi_set_virtual_district(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置虚拟分区属性。"""
props = await req.json()
return set_virtual_district(network, ChangeSet(props))
@router.post("/addvirtualdistrict/", response_model=None)
async def fastapi_add_virtual_district(network: str, req: Request) -> ChangeSet:
@router.post(
"/addvirtualdistrict/",
response_model=None,
summary="添加新虚拟分区",
description="向水网添加一个新的虚拟分区"
)
async def fastapi_add_virtual_district(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""添加新的虚拟分区。"""
props = await req.json()
return add_virtual_district(network, ChangeSet(props))
@router.post("/deletevirtualdistrict/", response_model=None)
async def fastapi_delete_virtual_district(network: str, req: Request) -> ChangeSet:
@router.post(
"/deletevirtualdistrict/",
response_model=None,
summary="删除虚拟分区",
description="删除指定的虚拟分区"
)
async def fastapi_delete_virtual_district(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""删除虚拟分区。"""
props = await req.json()
return delete_virtual_district(network, ChangeSet(props))
@router.get("/getallvirtualdistrict/")
async def fastapi_get_all_virtual_district(network: str) -> list[dict[str, Any]]:
@router.get(
"/getallvirtualdistrict/",
summary="获取所有虚拟分区",
description="获取指定水网中的所有虚拟分区信息"
)
async def fastapi_get_all_virtual_district(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取所有虚拟分区的信息列表。"""
return get_all_virtual_districts(network)
@router.post("/generatevirtualdistrict/", response_model=None)
@router.post(
"/generatevirtualdistrict/",
response_model=None,
summary="生成虚拟分区",
description="根据参数自动生成虚拟分区方案"
)
async def fastapi_generate_virtual_district(
network: str, inflate_delta: float, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
inflate_delta: float = Query(..., description="膨胀参数"),
req: Request = None
) -> ChangeSet:
"""生成虚拟分区。"""
props = await req.json()
return generate_virtual_district(network, props["centers"], inflate_delta)
@router.get("/calculatedistrictmeteringareafornodes/")
@router.get(
"/calculatedistrictmeteringareafornodes/",
summary="计算节点DMA分区",
description="为指定节点集计算区域计量(DMA)分区方案"
)
async def fastapi_calculate_district_metering_area_for_nodes(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> list[list[str]]:
"""
计算节点DMA分区。
请求体格式:
{
"nodes": 节点ID列表(list[str]),
"part_count": 分区数量(int),
"part_type": 分区类型(int)
}
"""
props = await req.json()
nodes = props["nodes"]
part_count = props["part_count"]
+353 -36
View File
@@ -1,105 +1,422 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_reservoir,
delete_reservoir,
get_all_reservoirs,
get_reservoir,
get_reservoir_schema,
set_reservoir,
)
router = APIRouter()
@router.get("/getreservoirschema")
async def fast_get_reservoir_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getreservoirschema",
summary="获取水库模式",
description="获取指定供水网络中所有水库的模式/属性字段定义"
)
async def fast_get_reservoir_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取水库模式定义。
该端点返回指定网络中水库对象的模式定义,包括所有可用的属性字段。
Args:
network: 管网名称(或数据库名称)
Returns:
水库属性的模式定义字典
"""
return get_reservoir_schema(network)
@router.post("/addreservoir/", response_model=None)
@router.post(
"/addreservoir/",
response_model=None,
summary="添加水库",
description="在指定供水网络中添加新的水库/水源节点"
)
async def fastapi_add_reservoir(
network: str, reservoir: str, x: float, y: float, head: float
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
x: float = Query(..., description="水库的X坐标"),
y: float = Query(..., description="水库的Y坐标"),
head: float = Query(..., description="水库的水头/总水头(米)")
) -> ChangeSet:
"""
添加新的水库/水源节点。
在指定的供水网络中创建一个新的水库,并设置其坐标和水头参数。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
x: 水库的X坐标位置
y: 水库的Y坐标位置
head: 水库的供水水头(以米为单位)
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "x": x, "y": y, "head": head}
return add_reservoir(network, ChangeSet(ps))
@router.post("/deletereservoir/", response_model=None)
async def fastapi_delete_reservoir(network: str, reservoir: str) -> ChangeSet:
@router.post(
"/deletereservoir/",
response_model=None,
summary="删除水库",
description="从指定供水网络中删除指定的水库/水源节点"
)
async def fastapi_delete_reservoir(
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="要删除的水库的唯一标识符")
) -> ChangeSet:
"""
删除指定的水库节点。
从指定的供水网络中删除一个水库及其相关的所有连接关系。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir}
return delete_reservoir(network, ChangeSet(ps))
@router.get("/getreservoirhead/")
async def fastapi_get_reservoir_head(network: str, reservoir: str) -> float | None:
@router.get(
"/getreservoirhead/",
summary="获取水库水头",
description="获取指定水库的供水水头/总水头值"
)
async def fastapi_get_reservoir_head(
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> float | None:
"""
获取水库的水头参数。
返回指定水库的供水水头(总水头),单位为米。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
水库的水头值(米),如果水库不存在则返回None
"""
ps = get_reservoir(network, reservoir)
return ps["head"]
@router.get("/getreservoirpattern/")
async def fastapi_get_reservoir_pattern(network: str, reservoir: str) -> str | None:
@router.get(
"/getreservoirpattern/",
summary="获取水库模式",
description="获取指定水库的运行模式/供水模式"
)
async def fastapi_get_reservoir_pattern(
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> str | None:
"""
获取水库的运行模式。
返回指定水库的供水模式,如固定水头模式、时间序列模式等。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
水库的运行模式字符串,如果水库不存在则返回None
"""
ps = get_reservoir(network, reservoir)
return ps["pattern"]
@router.get("/getreservoirx/")
@router.get(
"/getreservoirx/",
summary="获取水库X坐标",
description="获取指定水库的X坐标位置"
)
async def fastapi_get_reservoir_x(
network: str, reservoir: str
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> dict[str, float] | None:
"""
获取水库的X坐标。
返回指定水库的X轴坐标值。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
水库的X坐标值,如果水库不存在则返回None
"""
ps = get_reservoir(network, reservoir)
return ps["x"]
@router.get("/getreservoiry/")
@router.get(
"/getreservoiry/",
summary="获取水库Y坐标",
description="获取指定水库的Y坐标位置"
)
async def fastapi_get_reservoir_y(
network: str, reservoir: str
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> dict[str, float] | None:
"""
获取水库的Y坐标。
返回指定水库的Y轴坐标值。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
水库的Y坐标值,如果水库不存在则返回None
"""
ps = get_reservoir(network, reservoir)
return ps["y"]
@router.get("/getreservoircoord/")
@router.get(
"/getreservoircoord/",
summary="获取水库坐标",
description="获取指定水库的平面坐标(X和Y坐标)"
)
async def fastapi_get_reservoir_coord(
network: str, reservoir: str
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> dict[str, float] | None:
"""
获取水库的坐标。
返回指定水库的平面坐标,包含水库ID、X坐标和Y坐标。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
包含water库ID和X、Y坐标的字典,如果水库不存在则返回None
"""
ps = get_reservoir(network, reservoir)
coord = {"id": reservoir, "x": ps["x"], "y": ps["y"]}
return coord
@router.post("/setreservoirhead/", response_model=None)
@router.post(
"/setreservoirhead/",
response_model=None,
summary="设置水库水头",
description="更新指定水库的供水水头/总水头值"
)
async def fastapi_set_reservoir_head(
network: str, reservoir: str, head: float
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
head: float = Query(..., description="新的水头值(米)")
) -> ChangeSet:
"""
设置水库的水头参数。
更新指定水库的供水水头(总水头)值。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
head: 新的水头值(以米为单位)
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "head": head}
return set_reservoir(network, ChangeSet(ps))
@router.post("/setreservoirpattern/", response_model=None)
@router.post(
"/setreservoirpattern/",
response_model=None,
summary="设置水库模式",
description="更新指定水库的运行模式/供水模式"
)
async def fastapi_set_reservoir_pattern(
network: str, reservoir: str, pattern: str
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
pattern: str = Query(..., description="新的运行模式")
) -> ChangeSet:
"""
设置水库的运行模式。
更新指定水库的供水模式,如固定水头模式、时间序列模式等。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
pattern: 新的运行模式字符串
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "pattern": pattern}
return set_reservoir(network, ChangeSet(ps))
@router.post("/setreservoirx/", response_model=None)
async def fastapi_set_reservoir_x(network: str, reservoir: str, x: float) -> ChangeSet:
@router.post(
"/setreservoirx/",
response_model=None,
summary="设置水库X坐标",
description="更新指定水库的X坐标位置"
)
async def fastapi_set_reservoir_x(
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
x: float = Query(..., description="新的X坐标值")
) -> ChangeSet:
"""
设置水库的X坐标。
更新指定水库的X轴坐标值。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
x: 新的X坐标值
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "x": x}
return set_reservoir(network, ChangeSet(ps))
@router.post("/setreservoiry/", response_model=None)
async def fastapi_set_reservoir_y(network: str, reservoir: str, y: float) -> ChangeSet:
@router.post(
"/setreservoiry/",
response_model=None,
summary="设置水库Y坐标",
description="更新指定水库的Y坐标位置"
)
async def fastapi_set_reservoir_y(
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
y: float = Query(..., description="新的Y坐标值")
) -> ChangeSet:
"""
设置水库的Y坐标。
更新指定水库的Y轴坐标值。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
y: 新的Y坐标值
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "y": y}
return set_reservoir(network, ChangeSet(ps))
@router.post("/setreservoircoord/", response_model=None)
@router.post(
"/setreservoircoord/",
response_model=None,
summary="设置水库坐标",
description="更新指定水库的平面坐标(X和Y坐标)"
)
async def fastapi_set_reservoir_coord(
network: str, reservoir: str, x: float, y: float
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
x: float = Query(..., description="新的X坐标值"),
y: float = Query(..., description="新的Y坐标值")
) -> ChangeSet:
"""
设置水库的坐标。
更新指定水库的平面坐标,包括X和Y坐标。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
x: 新的X坐标值
y: 新的Y坐标值
Returns:
包含操作变更集的ChangeSet对象
"""
ps = {"id": reservoir, "x": x, "y": y}
return set_reservoir(network, ChangeSet(ps))
@router.get("/getreservoirproperties/")
@router.get(
"/getreservoirproperties/",
summary="获取水库属性",
description="获取指定水库的所有属性"
)
async def fastapi_get_reservoir_properties(
network: str, reservoir: str
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符")
) -> dict[str, Any]:
"""
获取水库的所有属性。
返回指定水库的完整属性信息,包括ID、坐标、水头、模式等所有属性。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
Returns:
包含水库所有属性的字典
"""
return get_reservoir(network, reservoir)
@router.get("/getallreservoirproperties/")
async def fastapi_get_all_reservoir_properties(network: str) -> list[dict[str, Any]]:
# 缓存查询结果提高性能
# global redis_client
@router.get(
"/getallreservoirproperties/",
summary="获取所有水库属性",
description="获取指定供水网络中所有水库的属性"
)
async def fastapi_get_all_reservoir_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取所有水库的属性。
返回指定供水网络中所有水库的完整属性信息列表。
Args:
network: 管网名称(或数据库名称)
Returns:
包含所有水库属性的字典列表
"""
results = get_all_reservoirs(network)
return results
@router.post("/setreservoirproperties/", response_model=None)
@router.post(
"/setreservoirproperties/",
response_model=None,
summary="设置水库属性",
description="批量更新指定水库的多个属性"
)
async def fastapi_set_reservoir_properties(
network: str, reservoir: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
reservoir: str = Query(..., description="水库的唯一标识符"),
req: Request = None
) -> ChangeSet:
"""
设置水库的多个属性。
批量更新指定水库的属性。属性通过JSON请求体传递。
Args:
network: 管网名称(或数据库名称)
reservoir: 水库的唯一标识符
req: HTTP请求对象,包含JSON格式的属性数据
Returns:
包含操作变更集的ChangeSet对象
"""
props = await req.json()
ps = {"id": reservoir} | props
return set_reservoir(network, ChangeSet(ps))
+49 -10
View File
@@ -1,6 +1,13 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
get_tag,
get_tag_schema,
get_tags,
set_tag,
)
router = APIRouter()
@@ -8,20 +15,52 @@ router = APIRouter()
# tag 8.[TAGS]
############################################################
@router.get("/gettagschema/")
async def fastapi_get_tag_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/gettagschema/",
summary="获取标签属性架构",
description="获取指定水网的标签(Tag)属性架构定义"
)
async def fastapi_get_tag_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""获取标签的属性架构。"""
return get_tag_schema(network)
@router.get("/gettag/")
async def fastapi_get_tag(network: str, t_type: str, id: str) -> dict[str, Any]:
@router.get(
"/gettag/",
summary="获取标签信息",
description="获取指定类型和ID的标签信息"
)
async def fastapi_get_tag(
network: str = Query(..., description="管网名称(或数据库名称)"),
t_type: str = Query(..., description="标签类型"),
id: str = Query(..., description="元素ID")
) -> dict[str, Any]:
"""获取标签信息。"""
return get_tag(network, t_type, id)
@router.get("/gettags/")
async def fastapi_get_tags(network: str) -> list[dict[str, Any]]:
@router.get(
"/gettags/",
summary="获取所有标签",
description="获取指定水网中的所有标签信息"
)
async def fastapi_get_tags(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取水网中所有标签的列表。"""
tags = get_tags(network)
return tags
@router.post("/settag/", response_model=None)
async def fastapi_set_tag(network: str, req: Request) -> ChangeSet:
@router.post(
"/settag/",
response_model=None,
summary="设置标签",
description="为指定元素设置或修改标签信息"
)
async def fastapi_set_tag(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""设置标签信息。"""
props = await req.json()
return set_tag(network, ChangeSet(props))
+445 -67
View File
@@ -1,26 +1,62 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
add_tank,
delete_tank,
get_all_tanks,
get_tank,
get_tank_schema,
set_tank,
)
router = APIRouter()
@router.get("/gettankschema")
async def fast_get_tank_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/gettankschema", summary="获取水箱模式", description="获取指定网络的水箱数据结构模式定义")
async def fast_get_tank_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
"""
获取水箱的数据结构模式。
Args:
network: 管网名称(或数据库名称)
Returns:
包含水箱属性的模式定义字典
"""
return get_tank_schema(network)
@router.post("/addtank/", response_model=None)
@router.post("/addtank/", summary="新增水箱", description="向指定网络中新增一个水箱", response_model=None)
async def fastapi_add_tank(
network: str,
tank: str,
x: float,
y: float,
elevation: float,
init_level: float = 0,
min_level: float = 0,
max_level: float = 0,
diameter: float = 0,
min_vol: float = 0,
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
x: float = Query(..., description="X坐标"),
y: float = Query(..., description="Y坐标"),
elevation: float = Query(..., description="标高"),
init_level: float = Query(0, description="初始水位"),
min_level: float = Query(0, description="最小水位"),
max_level: float = Query(0, description="最大水位"),
diameter: float = Query(0, description="直径"),
min_vol: float = Query(0, description="最小体积"),
) -> ChangeSet:
"""
创建新水箱。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
x: X坐标
y: Y坐标
elevation: 水箱标高
init_level: 初始水位,默认为0
min_level: 最小水位,默认为0
max_level: 最大水位,默认为0
diameter: 水箱直径,默认为0
min_vol: 最小体积,默认为0
Returns:
包含变更信息的ChangeSet对象
"""
ps = {
"id": tank,
"x": x,
@@ -34,155 +70,497 @@ async def fastapi_add_tank(
}
return add_tank(network, ChangeSet(ps))
@router.post("/deletetank/", response_model=None)
async def fastapi_delete_tank(network: str, tank: str) -> ChangeSet:
@router.post("/deletetank/", summary="删除水箱", description="删除指定网络中的水箱", response_model=None)
async def fastapi_delete_tank(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> ChangeSet:
"""
删除指定的水箱。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank}
return delete_tank(network, ChangeSet(ps))
@router.get("/gettankelevation/")
async def fastapi_get_tank_elevation(network: str, tank: str) -> float | None:
@router.get("/gettankelevation/", summary="获取水箱标高", description="获取指定水箱的标高值")
async def fastapi_get_tank_elevation(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的标高。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱标高值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["elevation"]
@router.get("/gettankinitlevel/")
async def fastapi_get_tank_init_level(network: str, tank: str) -> float | None:
@router.get("/gettankinitlevel/", summary="获取水箱初始水位", description="获取指定水箱的初始水位值")
async def fastapi_get_tank_init_level(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的初始水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱初始水位值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["init_level"]
@router.get("/gettankminlevel/")
async def fastapi_get_tank_min_level(network: str, tank: str) -> float | None:
@router.get("/gettankminlevel/", summary="获取水箱最小水位", description="获取指定水箱的最小水位值")
async def fastapi_get_tank_min_level(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的最小水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱最小水位值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["min_level"]
@router.get("/gettankmaxlevel/")
async def fastapi_get_tank_max_level(network: str, tank: str) -> float | None:
@router.get("/gettankmaxlevel/", summary="获取水箱最大水位", description="获取指定水箱的最大水位值")
async def fastapi_get_tank_max_level(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的最大水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱最大水位值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["max_level"]
@router.get("/gettankdiameter/")
async def fastapi_get_tank_diameter(network: str, tank: str) -> float | None:
@router.get("/gettankdiameter/", summary="获取水箱直径", description="获取指定水箱的直径值")
async def fastapi_get_tank_diameter(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的直径。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱直径值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["diameter"]
@router.get("/gettankminvol/")
async def fastapi_get_tank_min_vol(network: str, tank: str) -> float | None:
@router.get("/gettankminvol/", summary="获取水箱最小体积", description="获取指定水箱的最小体积值")
async def fastapi_get_tank_min_vol(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float | None:
"""
获取水箱的最小体积。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱最小体积值,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["min_vol"]
@router.get("/gettankvolcurve/")
async def fastapi_get_tank_vol_curve(network: str, tank: str) -> str | None:
@router.get("/gettankvolcurve/", summary="获取水箱容积曲线", description="获取指定水箱的容积曲线标识")
async def fastapi_get_tank_vol_curve(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> str | None:
"""
获取水箱的容积曲线。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱容积曲线标识,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["vol_curve"]
@router.get("/gettankoverflow/")
async def fastapi_get_tank_overflow(network: str, tank: str) -> str | None:
@router.get("/gettankoverflow/", summary="获取水箱溢流口", description="获取指定水箱的溢流口配置")
async def fastapi_get_tank_overflow(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> str | None:
"""
获取水箱的溢流口配置。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱溢流口配置,如果不存在则返回None
"""
ps = get_tank(network, tank)
return ps["overflow"]
@router.get("/gettankx/")
async def fastapi_get_tank_x(network: str, tank: str) -> float:
@router.get("/gettankx/", summary="获取水箱X坐标", description="获取指定水箱的X坐标值")
async def fastapi_get_tank_x(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float:
"""
获取水箱的X坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱X坐标值
"""
ps = get_tank(network, tank)
return ps["x"]
@router.get("/gettanky/")
async def fastapi_get_tank_y(network: str, tank: str) -> float:
@router.get("/gettanky/", summary="获取水箱Y坐标", description="获取指定水箱的Y坐标值")
async def fastapi_get_tank_y(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> float:
"""
获取水箱的Y坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
水箱Y坐标值
"""
ps = get_tank(network, tank)
return ps["y"]
@router.get("/gettankcoord/")
async def fastapi_get_tank_coord(network: str, tank: str) -> dict[str, float]:
@router.get("/gettankcoord/", summary="获取水箱坐标", description="获取指定水箱的X和Y坐标")
async def fastapi_get_tank_coord(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> dict[str, float]:
"""
获取水箱的坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
包含x和y坐标的字典
"""
ps = get_tank(network, tank)
coord = {"x": ps["x"], "y": ps["y"]}
return coord
@router.post("/settankelevation/", response_model=None)
@router.post("/settankelevation/", summary="设置水箱标高", description="设置指定水箱的标高值", response_model=None)
async def fastapi_set_tank_elevation(
network: str, tank: str, elevation: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
elevation: float = Query(..., description="新的标高值")
) -> ChangeSet:
"""
设置水箱的标高。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
elevation: 新的标高值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "elevation": elevation}
return set_tank(network, ChangeSet(ps))
@router.post("/settankinitlevel/", response_model=None)
@router.post("/settankinitlevel/", summary="设置水箱初始水位", description="设置指定水箱的初始水位值", response_model=None)
async def fastapi_set_tank_init_level(
network: str, tank: str, init_level: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
init_level: float = Query(..., description="新的初始水位值")
) -> ChangeSet:
"""
设置水箱的初始水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
init_level: 新的初始水位值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "init_level": init_level}
return set_tank(network, ChangeSet(ps))
@router.post("/settankminlevel/", response_model=None)
@router.post("/settankminlevel/", summary="设置水箱最小水位", description="设置指定水箱的最小水位值", response_model=None)
async def fastapi_set_tank_min_level(
network: str, tank: str, min_level: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
min_level: float = Query(..., description="新的最小水位值")
) -> ChangeSet:
"""
设置水箱的最小水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
min_level: 新的最小水位值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "min_level": min_level}
return set_tank(network, ChangeSet(ps))
@router.post("/settankmaxlevel/", response_model=None)
@router.post("/settankmaxlevel/", summary="设置水箱最大水位", description="设置指定水箱的最大水位值", response_model=None)
async def fastapi_set_tank_max_level(
network: str, tank: str, max_level: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
max_level: float = Query(..., description="新的最大水位值")
) -> ChangeSet:
"""
设置水箱的最大水位。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
max_level: 新的最大水位值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "max_level": max_level}
return set_tank(network, ChangeSet(ps))
@router.post("settankdiameter//", response_model=None)
@router.post("/settankdiameter/", summary="设置水箱直径", description="设置指定水箱的直径值", response_model=None)
async def fastapi_set_tank_diameter(
network: str, tank: str, diameter: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
diameter: float = Query(..., description="新的直径值")
) -> ChangeSet:
"""
设置水箱的直径。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
diameter: 新的直径值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "diameter": diameter}
return set_tank(network, ChangeSet(ps))
@router.post("/settankminvol/", response_model=None)
@router.post("/settankminvol/", summary="设置水箱最小体积", description="设置指定水箱的最小体积值", response_model=None)
async def fastapi_set_tank_min_vol(
network: str, tank: str, min_vol: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
min_vol: float = Query(..., description="新的最小体积值")
) -> ChangeSet:
"""
设置水箱的最小体积。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
min_vol: 新的最小体积值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "min_vol": min_vol}
return set_tank(network, ChangeSet(ps))
@router.post("/settankvolcurve/", response_model=None)
@router.post("/settankvolcurve/", summary="设置水箱容积曲线", description="设置指定水箱的容积曲线标识", response_model=None)
async def fastapi_set_tank_vol_curve(
network: str, tank: str, vol_curve: str
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
vol_curve: str = Query(..., description="新的容积曲线标识")
) -> ChangeSet:
"""
设置水箱的容积曲线。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
vol_curve: 新的容积曲线标识
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "vol_curve": vol_curve}
return set_tank(network, ChangeSet(ps))
@router.post("/settankoverflow/", response_model=None)
@router.post("/settankoverflow/", summary="设置水箱溢流口", description="设置指定水箱的溢流口配置", response_model=None)
async def fastapi_set_tank_overflow(
network: str, tank: str, overflow: str
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
overflow: str = Query(..., description="新的溢流口配置")
) -> ChangeSet:
"""
设置水箱的溢流口配置。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
overflow: 新的溢流口配置
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "overflow": overflow}
return set_tank(network, ChangeSet(ps))
@router.post("/settankx/", response_model=None)
async def fastapi_set_tank_x(network: str, tank: str, x: float) -> ChangeSet:
@router.post("/settankx/", summary="设置水箱X坐标", description="设置指定水箱的X坐标值", response_model=None)
async def fastapi_set_tank_x(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
x: float = Query(..., description="新的X坐标值")
) -> ChangeSet:
"""
设置水箱的X坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
x: 新的X坐标值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "x": x}
return set_tank(network, ChangeSet(ps))
@router.post("/settanky/", response_model=None)
async def fastapi_set_tank_y(network: str, tank: str, y: float) -> ChangeSet:
@router.post("/settanky/", summary="设置水箱Y坐标", description="设置指定水箱的Y坐标值", response_model=None)
async def fastapi_set_tank_y(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
y: float = Query(..., description="新的Y坐标值")
) -> ChangeSet:
"""
设置水箱的Y坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
y: 新的Y坐标值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "y": y}
return set_tank(network, ChangeSet(ps))
@router.post("/settankcoord/", response_model=None)
@router.post("/settankcoord/", summary="设置水箱坐标", description="设置指定水箱的X和Y坐标", response_model=None)
async def fastapi_set_tank_coord(
network: str, tank: str, x: float, y: float
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
x: float = Query(..., description="新的X坐标值"),
y: float = Query(..., description="新的Y坐标值")
) -> ChangeSet:
"""
设置水箱的坐标。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
x: 新的X坐标值
y: 新的Y坐标值
Returns:
包含变更信息的ChangeSet对象
"""
ps = {"id": tank, "x": x, "y": y}
return set_tank(network, ChangeSet(ps))
@router.get("/gettankproperties/")
async def fastapi_get_tank_properties(network: str, tank: str) -> dict[str, Any]:
@router.get("/gettankproperties/", summary="获取水箱属性", description="获取指定水箱的所有属性")
async def fastapi_get_tank_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID")
) -> dict[str, Any]:
"""
获取水箱的所有属性。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
Returns:
包含水箱所有属性的字典
"""
return get_tank(network, tank)
@router.get("/getalltankproperties/")
async def fastapi_get_all_tank_properties(network: str) -> list[dict[str, Any]]:
@router.get("/getalltankproperties/", summary="获取所有水箱属性", description="获取指定网络中所有水箱的属性")
async def fastapi_get_all_tank_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取网络中所有水箱的属性。
Args:
network: 管网名称(或数据库名称)
Returns:
包含所有水箱属性的字典列表
"""
# 缓存查询结果提高性能
# global redis_client
results = get_all_tanks(network)
return results
@router.post("/settankproperties/", response_model=None)
@router.post("/settankproperties/", summary="设置水箱属性", description="批量设置指定水箱的多个属性", response_model=None)
async def fastapi_set_tank_properties(
network: str, tank: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
tank: str = Query(..., description="水箱ID"),
req: Request = None
) -> ChangeSet:
"""
批量设置水箱的属性。
Args:
network: 管网名称(或数据库名称)
tank: 水箱ID
req: 包含水箱属性的请求体(JSON格式)
Returns:
包含变更信息的ChangeSet对象
"""
props = await req.json()
ps = {"id": tank} | props
return set_tank(network, ChangeSet(ps))
+260 -43
View File
@@ -1,24 +1,55 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query, Path, Body
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import (
Any,
ChangeSet,
VALVES_TYPE_PRV,
add_valve,
delete_valve,
get_all_valves,
get_valve,
get_valve_schema,
set_valve,
)
router = APIRouter()
@router.get("/getvalveschema")
async def fastapi_get_valve_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get(
"/getvalveschema",
summary="获取阀门架构",
description="获取指定水网中所有阀门的架构和字段定义",
)
async def fastapi_get_valve_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取阀门架构。
返回指定水网中所有阀门类型的完整架构定义,包括字段名称、类型和默认值。
"""
return get_valve_schema(network)
@router.post("/addvalve/", response_model=None)
@router.post(
"/addvalve/",
response_model=None,
summary="添加阀门",
description="在指定的水网中添加新的阀门",
)
async def fastapi_add_valve(
network: str,
valve: str,
node1: str,
node2: str,
diameter: float = 0,
v_type: str = VALVES_TYPE_PRV,
setting: float = 0,
minor_loss: float = 0,
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
node1: str = Query(..., description="起点节点ID"),
node2: str = Query(..., description="终点节点ID"),
diameter: float = Query(0, description="阀门直径(mm"),
v_type: str = Query(VALVES_TYPE_PRV, description="阀门类型"),
setting: float = Query(0, description="阀门开度/设置值"),
minor_loss: float = Query(0, description="损失系数"),
) -> ChangeSet:
"""
添加新的阀门。
在指定的水网中创建一个新的阀门,设置其连接的两个节点、直径、类型、开度和损失系数。
"""
ps = {
"id": valve,
"node1": node1,
@@ -31,85 +62,271 @@ async def fastapi_add_valve(
return add_valve(network, ChangeSet(ps))
@router.post("/deletevalve/", response_model=None)
async def fastapi_delete_valve(network: str, valve: str) -> ChangeSet:
@router.post(
"/deletevalve/",
response_model=None,
summary="删除阀门",
description="从指定的水网中删除指定的阀门",
)
async def fastapi_delete_valve(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> ChangeSet:
"""
删除阀门。
从指定的水网中删除指定ID的阀门。
"""
ps = {"id": valve}
return delete_valve(network, ChangeSet(ps))
@router.get("/getvalvenode1/")
async def fastapi_get_valve_node1(network: str, valve: str) -> str | None:
@router.get(
"/getvalvenode1/",
summary="获取阀门起点节点",
description="获取指定阀门连接的起点节点ID",
)
async def fastapi_get_valve_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> str | None:
"""
获取阀门的起点节点。
返回指定阀门连接的起点(第一个)节点的ID。
"""
ps = get_valve(network, valve)
return ps["node1"]
@router.get("/getvalvenode2/")
async def fastapi_get_valve_node2(network: str, valve: str) -> str | None:
@router.get(
"/getvalvenode2/",
summary="获取阀门终点节点",
description="获取指定阀门连接的终点节点ID",
)
async def fastapi_get_valve_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> str | None:
"""
获取阀门的终点节点。
返回指定阀门连接的终点(第二个)节点的ID。
"""
ps = get_valve(network, valve)
return ps["node2"]
@router.get("/getvalvediameter/")
async def fastapi_get_valve_diameter(network: str, valve: str) -> float | None:
@router.get(
"/getvalvediameter/",
summary="获取阀门直径",
description="获取指定阀门的直径",
)
async def fastapi_get_valve_diameter(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> float | None:
"""
获取阀门的直径。
返回指定阀门的直径值(单位:mm)。
"""
ps = get_valve(network, valve)
return ps["diameter"]
@router.get("/getvalvetype/")
async def fastapi_get_valve_type(network: str, valve: str) -> str | None:
@router.get(
"/getvalvetype/",
summary="获取阀门类型",
description="获取指定阀门的类型",
)
async def fastapi_get_valve_type(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> str | None:
"""
获取阀门的类型。
返回指定阀门的类型(例如:减压阀、调节阀等)。
"""
ps = get_valve(network, valve)
return ps["type"]
@router.get("/getvalvesetting/")
async def fastapi_get_valve_setting(network: str, valve: str) -> float | None:
@router.get(
"/getvalvesetting/",
summary="获取阀门开度",
description="获取指定阀门的开度/设置值",
)
async def fastapi_get_valve_setting(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> float | None:
"""
获取阀门的开度。
返回指定阀门的开度/设置值。
"""
ps = get_valve(network, valve)
return ps["setting"]
@router.get("/getvalveminorloss/")
async def fastapi_get_valve_minor_loss(network: str, valve: str) -> float | None:
@router.get(
"/getvalveminorloss/",
summary="获取阀门损失系数",
description="获取指定阀门的损失系数",
)
async def fastapi_get_valve_minor_loss(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> float | None:
"""
获取阀门的损失系数。
返回指定阀门的损失系数值,用于计算流体通过阀门的压力损失。
"""
ps = get_valve(network, valve)
return ps["minor_loss"]
@router.post("/setvalvenode1/", response_model=None)
async def fastapi_set_valve_node1(network: str, valve: str, node1: str) -> ChangeSet:
@router.post(
"/setvalvenode1/",
response_model=None,
summary="设置阀门起点节点",
description="设置指定阀门的起点节点",
)
async def fastapi_set_valve_node1(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
node1: str = Query(..., description="新的起点节点ID"),
) -> ChangeSet:
"""
设置阀门的起点节点。
更新指定阀门的起点节点连接。
"""
ps = {"id": valve, "node1": node1}
return set_valve(network, ChangeSet(ps))
@router.post("/setvalvenode2/", response_model=None)
async def fastapi_set_valve_node2(network: str, valve: str, node2: str) -> ChangeSet:
@router.post(
"/setvalvenode2/",
response_model=None,
summary="设置阀门终点节点",
description="设置指定阀门的终点节点",
)
async def fastapi_set_valve_node2(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
node2: str = Query(..., description="新的终点节点ID"),
) -> ChangeSet:
"""
设置阀门的终点节点。
更新指定阀门的终点节点连接。
"""
ps = {"id": valve, "node2": node2}
return set_valve(network, ChangeSet(ps))
@router.post("/setvalvenodediameter/", response_model=None)
@router.post(
"/setvalvenodediameter/",
response_model=None,
summary="设置阀门直径",
description="设置指定阀门的直径",
)
async def fastapi_set_valve_diameter(
network: str, valve: str, diameter: float
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
diameter: float = Query(..., description="新的直径值(mm"),
) -> ChangeSet:
"""
设置阀门的直径。
更新指定阀门的直径值。
"""
ps = {"id": valve, "diameter": diameter}
return set_valve(network, ChangeSet(ps))
@router.post("/setvalvetype/", response_model=None)
async def fastapi_set_valve_type(network: str, valve: str, type: str) -> ChangeSet:
@router.post(
"/setvalvetype/",
response_model=None,
summary="设置阀门类型",
description="设置指定阀门的类型",
)
async def fastapi_set_valve_type(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
type: str = Query(..., description="新的阀门类型"),
) -> ChangeSet:
"""
设置阀门的类型。
更新指定阀门的类型(例如:减压阀、调节阀等)。
"""
ps = {"id": valve, "type": type}
return set_valve(network, ChangeSet(ps))
@router.post("/setvalvesetting/", response_model=None)
@router.post(
"/setvalvesetting/",
response_model=None,
summary="设置阀门开度",
description="设置指定阀门的开度/设置值",
)
async def fastapi_set_valve_setting(
network: str, valve: str, setting: float
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
setting: float = Query(..., description="新的开度值"),
) -> ChangeSet:
"""
设置阀门的开度。
更新指定阀门的开度/设置值。
"""
ps = {"id": valve, "setting": setting}
return set_valve(network, ChangeSet(ps))
@router.get("/getvalveproperties/")
async def fastapi_get_valve_properties(network: str, valve: str) -> dict[str, Any]:
@router.get(
"/getvalveproperties/",
summary="获取阀门所有属性",
description="获取指定阀门的所有属性",
)
async def fastapi_get_valve_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
) -> dict[str, Any]:
"""
获取阀门的所有属性。
返回指定阀门的完整属性集合,包括ID、节点、直径、类型、开度和损失系数。
"""
return get_valve(network, valve)
@router.get("/getallvalveproperties/")
async def fastapi_get_all_valve_properties(network: str) -> list[dict[str, Any]]:
@router.get(
"/getallvalveproperties/",
summary="获取所有阀门属性",
description="获取指定水网中所有阀门的属性",
)
async def fastapi_get_all_valve_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取所有阀门的属性。
返回指定水网中所有阀门的完整属性列表。
"""
# 缓存查询结果提高性能
# global redis_client
results = get_all_valves(network)
return results
@router.post("/setvalveproperties/", response_model=None)
@router.post(
"/setvalveproperties/",
response_model=None,
summary="批量设置阀门属性",
description="批量设置指定阀门的多个属性",
)
async def fastapi_set_valve_properties(
network: str, valve: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
valve: str = Query(..., description="阀门ID"),
req: Request = None,
) -> ChangeSet:
"""
批量设置阀门的属性。
更新指定阀门的一个或多个属性,通过JSON请求体传递要更新的属性。
"""
props = await req.json()
ps = {"id": valve} | props
return set_valve(network, ChangeSet(ps))
+401 -40
View File
@@ -1,12 +1,15 @@
import json
from fastapi import APIRouter, Request, HTTPException
from fastapi import APIRouter, Request, HTTPException, Query, Path, Body, Depends
from fastapi.responses import PlainTextResponse
from typing import Any, Dict
from typing import Any, Dict, List
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
from app.auth.project_dependencies import get_metadata_repository
from app.domain.schemas.metadata import ProjectMetaResponse, GeoServerConfigResponse
import app.services.project_info as project_info
from app.native.api import ChangeSet
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
from app.services.tjnetwork import (
ChangeSet,
list_project,
have_project,
create_project,
@@ -39,30 +42,106 @@ inpDir = "data/" # Assuming data directory exists or is defined somewhere.
router = APIRouter()
lockedPrjs: Dict[str, str] = {}
@router.get("/listprojects/")
@router.get("/project_info/", summary="获取项目信息", description="从数据库获取项目的详细信息,包括地图范围等。", response_model=ProjectMetaResponse)
async def get_project_info_endpoint(
network: str = Query(..., description="管网名称(或项目代码)"),
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
):
"""
获取项目信息
- **network**: 管网名称(或项目代码)
"""
project_detail = await metadata_repo.get_project_detail_by_code(network)
if not project_detail:
raise HTTPException(status_code=404, detail=f"Project {network} not found")
geoserver_payload = None
if project_detail.geoserver:
geoserver_payload = GeoServerConfigResponse(
gs_base_url=project_detail.geoserver.gs_base_url,
gs_admin_user=project_detail.geoserver.gs_admin_user,
gs_datastore_name=project_detail.geoserver.gs_datastore_name,
default_extent=project_detail.geoserver.default_extent,
srid=project_detail.geoserver.srid,
)
return ProjectMetaResponse(
project_id=project_detail.project_id,
name=project_detail.name,
code=project_detail.code,
description=project_detail.description,
gs_workspace=project_detail.gs_workspace,
map_extent=project_detail.map_extent,
status=project_detail.status,
project_role="viewer", # Default role for public access
geoserver=geoserver_payload
)
@router.get("/listprojects/", summary="获取项目列表", description="获取服务器上所有可用的供水管网项目名称列表。")
async def list_projects_endpoint() -> list[str]:
"""
获取项目列表
返回所有已创建项目的名称列表。
"""
return list_project()
@router.get("/haveproject/")
async def have_project_endpoint(network: str):
@router.get("/haveproject/", summary="检查项目是否存在", description="检查指定名称的项目是否存在。")
async def have_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
检查项目是否存在
- **network**: 管网名称(或数据库名称)
"""
return have_project(network)
@router.post("/createproject/")
async def create_project_endpoint(network: str):
@router.post("/createproject/", summary="创建新项目", description="创建一个新的供水管网项目。如果项目已存在,可能会覆盖或报错(取决于底层实现)。")
async def create_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
创建新项目
- **network**: 管网名称(或数据库名称)
"""
create_project(network)
return network
@router.post("/deleteproject/")
async def delete_project_endpoint(network: str):
@router.post("/deleteproject/", summary="删除项目", description="永久删除指定的供水管网项目。此操作不可恢复。")
async def delete_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
删除项目
- **network**: 管网名称(或数据库名称)
"""
delete_project(network)
return True
@router.get("/isprojectopen/")
async def is_project_open_endpoint(network: str):
@router.get("/isprojectopen/", summary="检查项目是否已打开", description="检查指定项目是否已被加载到内存中。")
async def is_project_open_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
检查项目是否已打开
- **network**: 管网名称(或数据库名称)
"""
return is_project_open(network)
@router.post("/openproject/")
async def open_project_endpoint(network: str):
@router.post("/openproject/", summary="打开项目", description="将指定项目加载到内存中,并初始化数据库连接池。")
async def open_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
打开项目
- **network**: 管网名称(或数据库名称)
"""
open_project(network)
# 尝试连接指定数据库
@@ -88,18 +167,43 @@ async def open_project_endpoint(network: str):
return network
@router.post("/closeproject/")
async def close_project_endpoint(network: str):
@router.post("/closeproject/", summary="关闭项目", description="将指定项目从内存中卸载,释放资源。")
async def close_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
关闭项目
- **network**: 管网名称(或数据库名称)
"""
close_project(network)
return True
@router.post("/copyproject/")
async def copy_project_endpoint(source: str, target: str):
@router.post("/copyproject/", summary="复制项目", description="将现有项目复制为新项目。")
async def copy_project_endpoint(
source: str = Query(..., description="管网名称(或数据库名称)"),
target: str = Query(..., description="管网名称(或数据库名称)")
):
"""
复制项目
- **source**: 管网名称(或数据库名称)
- **target**: 管网名称(或数据库名称)
"""
copy_project(source, target)
return True
@router.post("/importinp/")
async def import_inp_endpoint(network: str, req: Request):
@router.post("/importinp/", summary="导入 INP 文件内容", description="将 INP 格式的文本内容导入到指定项目中。")
async def import_inp_endpoint(
req: Request,
network: str = Query(..., description="管网名称(或数据库名称)")
):
"""
导入 INP 文件内容
- **network**: 管网名称(或数据库名称)
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
"""
jo_root = await req.json()
inp_text = jo_root["inp"]
ps = {"inp": inp_text}
@@ -107,8 +211,17 @@ async def import_inp_endpoint(network: str, req: Request):
print(ret)
return ret
@router.get("/exportinp/", response_model=None)
async def export_inp_endpoint(network: str, version: str) -> ChangeSet:
@router.get("/exportinp/", response_model=None, summary="导出项目为 ChangeSet", description="导出项目的变更集 (ChangeSet),包含顶点、SCADA 元素、DMA、SA、VD 等信息。")
async def export_inp_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
version: str = Query(..., description="版本号 (通常用于增量更新)")
) -> ChangeSet:
"""
导出项目为 ChangeSet
- **network**: 管网名称(或数据库名称)
- **version**: 版本号
"""
cs = export_inp(network, version)
op = cs.operations[0]
open_project(network)
@@ -131,30 +244,75 @@ async def export_inp_endpoint(network: str, version: str) -> ChangeSet:
return cs
@router.post("/readinp/")
async def read_inp_endpoint(network: str, inp: str) -> bool:
@router.post("/readinp/", summary="读取 INP 文件到项目", description="从服务器文件系统中读取指定的 INP 文件并加载到项目中。")
async def read_inp_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
inp: str = Query(..., description="INP 文件名 (不包含路径)")
) -> bool:
"""
读取 INP 文件到项目
- **network**: 管网名称(或数据库名称)
- **inp**: INP 文件名
"""
read_inp(network, inp)
return True
@router.get("/dumpinp/")
async def dump_inp_endpoint(network: str, inp: str) -> bool:
@router.get("/dumpinp/", summary="导出项目到 INP 文件", description="将项目当前状态保存为 INP 文件到服务器文件系统。")
async def dump_inp_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
inp: str = Query(..., description="目标文件名")
) -> bool:
"""
导出项目到 INP 文件
- **network**: 管网名称(或数据库名称)
- **inp**: 目标文件名
"""
dump_inp(network, inp)
return True
@router.get("/isprojectlocked/")
async def is_project_locked_endpoint(network: str, req: Request):
@router.get("/isprojectlocked/", summary="检查项目是否被锁定", description="检查指定项目是否处于锁定状态。")
async def is_project_locked_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
检查项目是否被锁定
- **network**: 管网名称(或数据库名称)
"""
return network in lockedPrjs.keys()
@router.get("/isprojectlockedbyme/")
async def is_project_locked_by_me_endpoint(network: str, req: Request):
@router.get("/isprojectlockedbyme/", summary="检查项目是否被当前用户锁定", description="检查指定项目是否被当前客户端 (IP) 锁定。")
async def is_project_locked_by_me_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
检查项目是否被当前用户锁定
- **network**: 管网名称(或数据库名称)
"""
client_host = req.client.host
return lockedPrjs.get(network) == client_host
# 0 successfully locked
# 1 already locked by you
# 2 locked by others
@router.post("/lockproject/")
async def lock_project_endpoint(network: str, req: Request):
@router.post("/lockproject/", summary="锁定项目", description="锁定指定项目以防止并发修改。")
async def lock_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
锁定项目
返回值:
- **0**: 锁定成功
- **1**: 已被当前用户锁定
- **2**: 已被其他用户锁定
"""
client_host = req.client.host
if not network in lockedPrjs.keys():
lockedPrjs[network] = client_host
@@ -165,8 +323,16 @@ async def lock_project_endpoint(network: str, req: Request):
else:
return 2
@router.post("/unlockproject/")
def unlock_project_endpoint(network: str, req: Request):
@router.post("/unlockproject/", summary="解锁项目", description="释放对项目的锁定。")
def unlock_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
解锁项目
只有锁定者才能解锁。
"""
client_host = req.client.host
if lockedPrjs.get(network) == client_host:
print("delete key")
@@ -176,8 +342,17 @@ def unlock_project_endpoint(network: str, req: Request):
return False
# inp file operations
@router.post("/uploadinp/", status_code=status.HTTP_200_OK)
async def fastapi_upload_inp(afile: bytes, name: str):
@router.post("/uploadinp/", status_code=status.HTTP_200_OK, summary="上传 INP 文件", description="上传 INP 文件到服务器数据目录。")
async def fastapi_upload_inp(
afile: bytes = Body(..., description="文件二进制内容"),
name: str = Query(..., description="保存的文件名")
):
"""
上传 INP 文件
- **afile**: 文件内容
- **name**: 文件名
"""
if not os.path.exists(inpDir):
os.makedirs(inpDir, exist_ok=True)
@@ -186,8 +361,16 @@ async def fastapi_upload_inp(afile: bytes, name: str):
f.write(afile)
return True
@router.get("/downloadinp/", status_code=status.HTTP_200_OK)
async def fastapi_download_inp(name: str, response: Response):
@router.get("/downloadinp/", status_code=status.HTTP_200_OK, summary="下载 INP 文件", description="从服务器数据目录下载指定的 INP 文件。")
async def fastapi_download_inp(
name: str = Query(..., description="文件名"),
response: Response = None
):
"""
下载 INP 文件
- **name**: 文件名
"""
filePath = inpDir + name
if os.path.exists(filePath):
return FileResponse(
@@ -198,8 +381,186 @@ async def fastapi_download_inp(name: str, response: Response):
return True
# DingZQ, 2024-12-28, convert v3 to v2
@router.get("/convertv3tov2/", response_model=None)
async def fastapi_convert_v3_to_v2(req: Request) -> ChangeSet:
@router.get("/convertv3tov2/", response_model=None, summary="转换 INP V3 为 V2", description="将 EPANET 3.0 格式的 INP 内容转换为 2.x 格式。")
async def fastapi_convert_v3_to_v2(
req: Request
) -> ChangeSet:
"""
转换 INP V3 为 V2
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
"""
network = "v3Tov2"
jo_root = await req.json()
inp = jo_root["inp"]
cs = convert_inp_v3_to_v2(inp)
op = cs.operations[0]
open_project(network)
op["vertex"] = json.dumps(get_all_vertices(network))
op["scada"] = json.dumps(get_all_scada_elements(network))
op["dma"] = json.dumps(get_all_district_metering_areas(network))
op["sa"] = json.dumps(get_all_service_areas(network))
op["vd"] = json.dumps(get_all_virtual_districts(network))
op["legend"] = get_extension_data(network, "legend")
db = get_extension_data(network, "scada_db")
print(db)
scada_db = ""
if db:
scada_db = db
print(scada_db)
op["scada_db"] = scada_db
close_project(network)
return cs
@router.post("/readinp/", summary="读取 INP 文件到项目", description="从服务器文件系统中读取指定的 INP 文件并加载到项目中。")
async def read_inp_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
inp: str = Query(..., description="INP 文件名 (不包含路径)")
) -> bool:
"""
读取 INP 文件到项目
- **network**: 管网名称(或数据库名称)
- **inp**: INP 文件名
"""
read_inp(network, inp)
return True
@router.get("/dumpinp/", summary="导出项目到 INP 文件", description="将项目当前状态保存为 INP 文件到服务器文件系统。")
async def dump_inp_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
inp: str = Query(..., description="目标文件名")
) -> bool:
"""
导出项目到 INP 文件
- **network**: 管网名称(或数据库名称)
- **inp**: 目标文件名
"""
dump_inp(network, inp)
return True
@router.get("/isprojectlocked/", summary="检查项目是否被锁定", description="检查指定项目是否处于锁定状态。")
async def is_project_locked_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
检查项目是否被锁定
- **network**: 管网名称(或数据库名称)
"""
return network in lockedPrjs.keys()
@router.get("/isprojectlockedbyme/", summary="检查项目是否被当前用户锁定", description="检查指定项目是否被当前客户端 (IP) 锁定。")
async def is_project_locked_by_me_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
检查项目是否被当前用户锁定
- **network**: 管网名称(或数据库名称)
"""
client_host = req.client.host
return lockedPrjs.get(network) == client_host
# 0 successfully locked
# 1 already locked by you
# 2 locked by others
@router.post("/lockproject/", summary="锁定项目", description="锁定指定项目以防止并发修改。")
async def lock_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
锁定项目
返回值:
- **0**: 锁定成功
- **1**: 已被当前用户锁定
- **2**: 已被其他用户锁定
"""
client_host = req.client.host
if not network in lockedPrjs.keys():
lockedPrjs[network] = client_host
return 0
else:
if lockedPrjs.get(network) == client_host:
return 1
else:
return 2
@router.post("/unlockproject/", summary="解锁项目", description="释放对项目的锁定。")
def unlock_project_endpoint(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
):
"""
解锁项目
只有锁定者才能解锁。
"""
client_host = req.client.host
if lockedPrjs.get(network) == client_host:
print("delete key")
del lockedPrjs[network]
return True
return False
# inp file operations
@router.post("/uploadinp/", status_code=status.HTTP_200_OK, summary="上传 INP 文件", description="上传 INP 文件到服务器数据目录。")
async def fastapi_upload_inp(
afile: bytes = Body(..., description="文件二进制内容"),
name: str = Query(..., description="保存的文件名")
):
"""
上传 INP 文件
- **afile**: 文件内容
- **name**: 文件名
"""
if not os.path.exists(inpDir):
os.makedirs(inpDir, exist_ok=True)
filePath = inpDir + str(name)
with open(filePath, "wb") as f:
f.write(afile)
return True
@router.get("/downloadinp/", status_code=status.HTTP_200_OK, summary="下载 INP 文件", description="从服务器数据目录下载指定的 INP 文件。")
async def fastapi_download_inp(
name: str = Query(..., description="文件名"),
response: Response = None
):
"""
下载 INP 文件
- **name**: 文件名
"""
filePath = inpDir + name
if os.path.exists(filePath):
return FileResponse(
filePath, media_type="application/octet-stream", filename="inp.inp"
)
else:
response.status_code = status.HTTP_400_BAD_REQUEST
return True
# DingZQ, 2024-12-28, convert v3 to v2
@router.get("/convertv3tov2/", response_model=None, summary="转换 INP V3 为 V2", description="将 EPANET 3.0 格式的 INP 内容转换为 2.x 格式。")
async def fastapi_convert_v3_to_v2(
req: Request
) -> ChangeSet:
"""
转换 INP V3 为 V2
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
"""
network = "v3Tov2"
jo_root = await req.json()
inp = jo_root["inp"]
@@ -1,30 +1,34 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from psycopg import AsyncConnection
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
import app.native.wndb as wndb
from app.infra.db.postgresql.scheme import SchemeRepository
from app.auth.project_dependencies import get_project_pg_connection
from app.services import project_info
router = APIRouter()
# 动态项目 PostgreSQL 连接依赖
async def get_database_connection(
conn: AsyncConnection = Depends(get_project_pg_connection),
):
"""获取数据库连接"""
yield conn
@router.get("/scada-info")
@router.get("/scada-info", summary="获取SCADA信息", description="使用连接池查询所有SCADA信息")
async def get_scada_info_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有SCADA信息
获取所有SCADA信息
返回项目中所有的SCADA设备信息
"""
try:
# 使用ScadaRepository查询SCADA信息
scada_data = await ScadaRepository.get_scadas(conn)
_ = conn
network_name = project_info.name
scada_data = wndb.get_all_scada_info(network_name) if network_name else []
return {"success": True, "data": scada_data, "count": len(scada_data)}
except Exception as e:
raise HTTPException(
@@ -32,30 +36,32 @@ async def get_scada_info_with_connection(
)
@router.get("/scheme-list")
@router.get("/scheme-list", summary="获取方案列表", description="使用连接池查询所有方案信息")
async def get_scheme_list_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有方案信息
获取所有方案信息
返回项目中所有方案的详细信息
"""
try:
# 使用SchemeRepository查询方案信息
scheme_data = await SchemeRepository.get_schemes(conn)
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
except Exception as e:
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
@router.get("/burst-locate-result")
@router.get("/burst-locate-result", summary="获取爆管定位结果", description="使用连接池查询所有爆管定位结果")
async def get_burst_locate_result_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有爆管定位结果
获取所有爆管定位结果
返回项目中所有的爆管定位分析结果
"""
try:
# 使用SchemeRepository查询爆管定位结果
burst_data = await SchemeRepository.get_burst_locate_results(conn)
return {"success": True, "data": burst_data, "count": len(burst_data)}
except Exception as e:
@@ -64,16 +70,18 @@ async def get_burst_locate_result_with_connection(
)
@router.get("/burst-locate-result/{burst_incident}")
@router.get("/burst-locate-result/{burst_incident}", summary="按事件查询爆管定位结果", description="根据爆管事件ID查询对应的爆管定位结果")
async def get_burst_locate_result_by_incident(
burst_incident: str,
burst_incident: str = Path(..., description="爆管事件ID"),
conn: AsyncConnection = Depends(get_database_connection),
):
"""
根据 burst_incident 查询爆管定位结果
根据爆管事件ID查询爆管定位结果
参数:
burst_incident: 爆管事件的唯一标识符
"""
try:
# 使用SchemeRepository查询爆管定位结果
return await SchemeRepository.get_burst_locate_result_by_incident(
conn, burst_incident
)
+94 -11
View File
@@ -1,5 +1,5 @@
from typing import Any, List, Dict
from fastapi import APIRouter
from fastapi import APIRouter, Query, Path
from app.services.tjnetwork import (
get_pipe_risk_probability_now,
get_pipe_risk_probability,
@@ -10,35 +10,118 @@ from app.services.tjnetwork import (
router = APIRouter()
@router.get("/getpiperiskprobabilitynow/")
@router.get(
"/getpiperiskprobabilitynow/",
summary="获取管道当前风险概率",
description="获取指定管道当前时刻的风险概率值"
)
async def fastapi_get_pipe_risk_probability_now(
network: str, pipe_id: str
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe_id: str = Query(..., description="管道ID")
) -> dict[str, Any]:
"""
获取管道当前风险概率。
查询指定管道在当前时刻的风险概率值。
Args:
network: 管网名称(或数据库名称)
pipe_id: 管道ID
Returns:
包含风险概率信息的字典
"""
return get_pipe_risk_probability_now(network, pipe_id)
@router.get("/getpiperiskprobability/")
@router.get(
"/getpiperiskprobability/",
summary="获取管道风险概率历史",
description="获取指定管道的风险概率历史数据"
)
async def fastapi_get_pipe_risk_probability(
network: str, pipe_id: str
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe_id: str = Query(..., description="管道ID")
) -> dict[str, Any]:
"""
获取管道风险概率历史。
查询指定管道的历史风险概率数据。
Args:
network: 管网名称(或数据库名称)
pipe_id: 管道ID
Returns:
包含风险概率历史的字典
"""
return get_pipe_risk_probability(network, pipe_id)
@router.get("/getpipesriskprobability/")
@router.get(
"/getpipesriskprobability/",
summary="批量获取多条管道风险概率",
description="批量获取多条管道的风险概率值"
)
async def fastapi_get_pipes_risk_probability(
network: str, pipe_ids: str
network: str = Query(..., description="管网名称(或数据库名称)"),
pipe_ids: str = Query(..., description="逗号分隔的管道ID列表")
) -> list[dict[str, Any]]:
"""
批量获取多条管道风险概率。
查询多条指定管道的风险概率值。
Args:
network: 管网名称(或数据库名称)
pipe_ids: 逗号分隔的管道ID列表(例如:pipe1,pipe2,pipe3
Returns:
包含多条管道风险概率的列表
"""
pipeids = pipe_ids.split(",")
return get_pipes_risk_probability(network, pipeids)
@router.get("/getnetworkpiperiskprobabilitynow/")
@router.get(
"/getnetworkpiperiskprobabilitynow/",
summary="获取整个网络的管道风险概率",
description="获取指定网络中所有管道的当前风险概率值"
)
async def fastapi_get_network_pipe_risk_probability_now(
network: str,
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> list[dict[str, Any]]:
"""
获取整个网络的管道风险概率。
查询指定网络中所有管道在当前时刻的风险概率值。
Args:
network: 管网名称(或数据库名称)
Returns:
包含网络内所有管道风险概率的列表
"""
return get_network_pipe_risk_probability_now(network)
@router.get("/getpiperiskprobabilitygeometries/")
async def fastapi_get_pipe_risk_probability_geometries(network: str) -> dict[str, Any]:
@router.get(
"/getpiperiskprobabilitygeometries/",
summary="获取管道风险几何信息",
description="获取指定网络中管道的风险相关几何数据"
)
async def fastapi_get_pipe_risk_probability_geometries(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, Any]:
"""
获取管道风险几何信息。
查询指定网络中管道的地理和风险相关的几何数据。
Args:
network: 管网名称(或数据库名称)
Returns:
包含几何信息和风险数据的字典
"""
return get_pipe_risk_probability_geometries(network)
+416 -58
View File
@@ -1,7 +1,7 @@
from typing import Any
from fastapi import APIRouter, Request
from app.native.api import ChangeSet
from fastapi import APIRouter, Request, Query
from app.services.tjnetwork import (
ChangeSet,
get_scada_info,
get_all_scada_info,
get_scada_device_schema,
@@ -31,139 +31,497 @@ from app.services.tjnetwork import (
router = APIRouter()
@router.get("/getscadaproperties/")
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
@router.get("/getscadaproperties/", summary="获取SCADA属性", tags=["SCADA基础"])
async def fast_get_scada_properties(
network: str = Query(..., description="管网名称(或数据库名称)"),
scada: str = Query(..., description="SCADA设备ID")
) -> dict[str, Any]:
"""
获取单个SCADA设备的属性信息
根据管网名称和SCADA设备ID获取该设备的完整属性。
Args:
network: 管网名称(或数据库名称)
scada: SCADA设备ID
Returns:
SCADA设备的属性字典
"""
return get_scada_info(network, scada)
@router.get("/getallscadaproperties/")
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
@router.get("/getallscadaproperties/", summary="获取所有SCADA属性", tags=["SCADA基础"])
async def fast_get_all_scada_properties(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取指定管网所有SCADA设备的属性信息
查询该管网下所有已配置的SCADA设备的属性列表。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA设备属性列表
"""
return get_all_scada_info(network)
############################################################
# scada_device 29
# scada_device 设备管理
############################################################
@router.get("/getscadadeviceschema/")
async def fastapi_get_scada_device_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getscadadeviceschema/", summary="获取SCADA设备架构", tags=["SCADA设备"])
async def fastapi_get_scada_device_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取SCADA设备的数据架构
返回SCADA设备表的字段定义和类型信息。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA设备的字段架构信息
"""
return get_scada_device_schema(network)
@router.get("/getscadadevice/")
async def fastapi_get_scada_device(network: str, id: str) -> dict[str, Any]:
@router.get("/getscadadevice/", summary="获取SCADA设备", tags=["SCADA设备"])
async def fastapi_get_scada_device(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="SCADA设备ID")
) -> dict[str, Any]:
"""
获取单个SCADA设备的信息
根据设备ID查询该设备的详细信息。
Args:
network: 管网名称(或数据库名称)
id: SCADA设备ID
Returns:
SCADA设备信息
"""
return get_scada_device(network, id)
@router.post("/setscadadevice/", response_model=None)
async def fastapi_set_scada_device(network: str, req: Request) -> ChangeSet:
@router.post("/setscadadevice/", response_model=None, summary="更新SCADA设备", tags=["SCADA设备"])
async def fastapi_set_scada_device(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
更新SCADA设备信息
修改指定SCADA设备的属性。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要更新的设备属性
Returns:
变更集合信息
"""
props = await req.json()
return set_scada_device(network, ChangeSet(props))
@router.post("/addscadadevice/", response_model=None)
async def fastapi_add_scada_device(network: str, req: Request) -> ChangeSet:
@router.post("/addscadadevice/", response_model=None, summary="添加SCADA设备", tags=["SCADA设备"])
async def fastapi_add_scada_device(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
添加新的SCADA设备
在指定管网中添加一个新的SCADA设备。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含新设备的属性
Returns:
变更集合信息
"""
props = await req.json()
return add_scada_device(network, ChangeSet(props))
@router.post("/deletescadadevice/", response_model=None)
async def fastapi_delete_scada_device(network: str, req: Request) -> ChangeSet:
@router.post("/deletescadadevice/", response_model=None, summary="删除SCADA设备", tags=["SCADA设备"])
async def fastapi_delete_scada_device(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
删除SCADA设备
从指定管网中删除一个SCADA设备。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要删除的设备ID
Returns:
变更集合信息
"""
props = await req.json()
return delete_scada_device(network, ChangeSet(props))
@router.post("/cleanscadadevice/", response_model=None)
async def fastapi_clean_scada_device(network: str) -> ChangeSet:
@router.post("/cleanscadadevice/", response_model=None, summary="清空SCADA设备表", tags=["SCADA设备"])
async def fastapi_clean_scada_device(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> ChangeSet:
"""
清空SCADA设备表
删除指定管网中所有的SCADA设备。
Args:
network: 管网名称(或数据库名称)
Returns:
变更集合信息
"""
return clean_scada_device(network)
@router.get("/getallscadadeviceids/")
async def fastapi_get_all_scada_device_ids(network: str) -> list[str]:
@router.get("/getallscadadeviceids/", summary="获取所有SCADA设备ID", tags=["SCADA设备"])
async def fastapi_get_all_scada_device_ids(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[str]:
"""
获取指定管网所有SCADA设备的ID列表
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA设备ID列表
"""
return get_all_scada_device_ids(network)
@router.get("/getallscadadevices/")
async def fastapi_get_all_scada_devices(network: str) -> list[dict[str, Any]]:
@router.get("/getallscadadevices/", summary="获取所有SCADA设备", tags=["SCADA设备"])
async def fastapi_get_all_scada_devices(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取指定管网所有SCADA设备的完整信息
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA设备信息列表
"""
return get_all_scada_devices(network)
############################################################
# scada_device_data 30
# scada_device_data 设备数据管理
############################################################
@router.get("/getscadadevicedataschema/")
@router.get("/getscadadevicedataschema/", summary="获取SCADA设备数据架构", tags=["SCADA设备数据"])
async def fastapi_get_scada_device_data_schema(
network: str,
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> dict[str, dict[str, Any]]:
"""
获取SCADA设备数据的表结构
返回SCADA设备数据表的字段定义和类型信息。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA设备数据的字段架构信息
"""
return get_scada_device_data_schema(network)
@router.get("/getscadadevicedata/")
async def fastapi_get_scada_device_data(network: str, device_id: str) -> dict[str, Any]:
@router.get("/getscadadevicedata/", summary="获取SCADA设备数据", tags=["SCADA设备数据"])
async def fastapi_get_scada_device_data(
network: str = Query(..., description="管网名称(或数据库名称)"),
device_id: str = Query(..., description="SCADA设备ID")
) -> dict[str, Any]:
"""
获取单个SCADA设备的数据
查询指定设备的监测数据或配置数据。
Args:
network: 管网名称(或数据库名称)
device_id: SCADA设备ID
Returns:
SCADA设备数据
"""
return get_scada_device_data(network, device_id)
@router.post("/setscadadevicedata/", response_model=None)
async def fastapi_set_scada_device_data(network: str, req: Request) -> ChangeSet:
@router.post("/setscadadevicedata/", response_model=None, summary="更新SCADA设备数据", tags=["SCADA设备数据"])
async def fastapi_set_scada_device_data(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
更新SCADA设备数据
修改指定SCADA设备的数据。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要更新的数据
Returns:
变更集合信息
"""
props = await req.json()
return set_scada_device_data(network, ChangeSet(props))
@router.post("/addscadadevicedata/", response_model=None)
async def fastapi_add_scada_device_data(network: str, req: Request) -> ChangeSet:
@router.post("/addscadadevicedata/", response_model=None, summary="添加SCADA设备数据", tags=["SCADA设备数据"])
async def fastapi_add_scada_device_data(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
添加新的SCADA设备数据
为指定SCADA设备添加新的数据记录。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含新数据的内容
Returns:
变更集合信息
"""
props = await req.json()
return add_scada_device_data(network, ChangeSet(props))
@router.post("/deletescadadevicedata/", response_model=None)
async def fastapi_delete_scada_device_data(network: str, req: Request) -> ChangeSet:
@router.post("/deletescadadevicedata/", response_model=None, summary="删除SCADA设备数据", tags=["SCADA设备数据"])
async def fastapi_delete_scada_device_data(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
删除SCADA设备数据
删除指定SCADA设备的数据记录。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要删除的数据ID
Returns:
变更集合信息
"""
props = await req.json()
return delete_scada_device_data(network, ChangeSet(props))
@router.post("/cleanscadadevicedata/", response_model=None)
async def fastapi_clean_scada_device_data(network: str) -> ChangeSet:
@router.post("/cleanscadadevicedata/", response_model=None, summary="清空SCADA设备数据表", tags=["SCADA设备数据"])
async def fastapi_clean_scada_device_data(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> ChangeSet:
"""
清空SCADA设备数据表
删除指定管网中所有SCADA设备的数据。
Args:
network: 管网名称(或数据库名称)
Returns:
变更集合信息
"""
return clean_scada_device_data(network)
############################################################
# scada_element 31
# scada_element SCADA元素映射
############################################################
@router.get("/getscadaelementschema/")
@router.get("/getscadaelementschema/", summary="获取SCADA元素架构", tags=["SCADA元素映射"])
async def fastapi_get_scada_element_schema(
network: str,
network: str = Query(..., description="管网名称(或数据库名称)"),
) -> dict[str, dict[str, Any]]:
"""
获取SCADA元素映射的表结构
返回SCADA元素映射表的字段定义和类型信息。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA元素映射的字段架构信息
"""
return get_scada_element_schema(network)
@router.get("/getscadaelements/")
async def fastapi_get_scada_elements(network: str) -> list[dict[str, Any]]:
@router.get("/getscadaelements/", summary="获取所有SCADA元素映射", tags=["SCADA元素映射"])
async def fastapi_get_scada_elements(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取指定管网所有SCADA元素映射
查询所有SCADA设备与管网元素(节点/管道)的映射关系。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA元素映射列表
"""
return get_all_scada_elements(network)
@router.get("/getscadaelement/")
async def fastapi_get_scada_element(network: str, id: str) -> dict[str, Any]:
@router.get("/getscadaelement/", summary="获取单个SCADA元素映射", tags=["SCADA元素映射"])
async def fastapi_get_scada_element(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="SCADA元素映射ID")
) -> dict[str, Any]:
"""
获取单个SCADA元素映射的信息
根据ID查询特定的SCADA设备与管网元素的映射关系。
Args:
network: 管网名称(或数据库名称)
id: SCADA元素映射ID
Returns:
SCADA元素映射信息
"""
return get_scada_element(network, id)
@router.post("/setscadaelement/", response_model=None)
async def fastapi_set_scada_element(network: str, req: Request) -> ChangeSet:
@router.post("/setscadaelement/", response_model=None, summary="更新SCADA元素映射", tags=["SCADA元素映射"])
async def fastapi_set_scada_element(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
更新SCADA元素映射
修改SCADA设备与管网元素的映射关系。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要更新的映射信息
Returns:
变更集合信息
"""
props = await req.json()
return set_scada_element(network, ChangeSet(props))
@router.post("/addscadaelement/", response_model=None)
async def fastapi_add_scada_element(network: str, req: Request) -> ChangeSet:
@router.post("/addscadaelement/", response_model=None, summary="添加SCADA元素映射", tags=["SCADA元素映射"])
async def fastapi_add_scada_element(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
添加新的SCADA元素映射
创建SCADA设备与管网元素的新映射关系。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含新映射的信息
Returns:
变更集合信息
"""
props = await req.json()
return add_scada_element(network, ChangeSet(props))
@router.post("/deletescadaelement/", response_model=None)
async def fastapi_delete_scada_element(network: str, req: Request) -> ChangeSet:
@router.post("/deletescadaelement/", response_model=None, summary="删除SCADA元素映射", tags=["SCADA元素映射"])
async def fastapi_delete_scada_element(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
删除SCADA元素映射
移除SCADA设备与管网元素的映射关系。
Args:
network: 管网名称(或数据库名称)
req: 请求体,包含要删除的映射ID
Returns:
变更集合信息
"""
props = await req.json()
return delete_scada_element(network, ChangeSet(props))
@router.post("/cleanscadaelement/", response_model=None)
async def fastapi_clean_scada_element(network: str) -> ChangeSet:
@router.post("/cleanscadaelement/", response_model=None, summary="清空SCADA元素映射表", tags=["SCADA元素映射"])
async def fastapi_clean_scada_element(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> ChangeSet:
"""
清空SCADA元素映射表
删除指定管网中所有的SCADA元素映射。
Args:
network: 管网名称(或数据库名称)
Returns:
变更集合信息
"""
return clean_scada_element(network)
############################################################
# scada_info 38
# scada_info SCADA信息
############################################################
@router.get("/getscadainfoschema/")
async def fastapi_get_scada_info_schema(network: str) -> dict[str, dict[str, Any]]:
@router.get("/getscadainfoschema/", summary="获取SCADA信息架构", tags=["SCADA信息"])
async def fastapi_get_scada_info_schema(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> dict[str, dict[str, Any]]:
"""
获取SCADA信息表的结构
返回SCADA信息表的字段定义和类型信息。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA信息的字段架构信息
"""
return get_scada_info_schema(network)
@router.get("/getscadainfo/")
async def fastapi_get_scada_info(network: str, id: str) -> dict[str, Any]:
@router.get("/getscadainfo/", summary="获取SCADA信息", tags=["SCADA信息"])
async def fastapi_get_scada_info(
network: str = Query(..., description="管网名称(或数据库名称)"),
id: str = Query(..., description="SCADA信息ID")
) -> dict[str, Any]:
"""
获取单个SCADA信息
根据ID查询SCADA的详细配置信息。
Args:
network: 管网名称(或数据库名称)
id: SCADA信息ID
Returns:
SCADA信息详情
"""
return get_scada_info(network, id)
@router.get("/getallscadainfo/")
async def fastapi_get_all_scada_info(network: str) -> list[dict[str, Any]]:
@router.get("/getallscadainfo/", summary="获取所有SCADA信息", tags=["SCADA信息"])
async def fastapi_get_all_scada_info(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""
获取指定管网所有SCADA的信息
查询该管网下所有已配置的SCADA的完整信息。
Args:
network: 管网名称(或数据库名称)
Returns:
SCADA信息列表
"""
return get_all_scada_info(network)
+22 -7
View File
@@ -1,17 +1,32 @@
from fastapi import APIRouter
from fastapi import APIRouter, Query
from typing import Any, List, Dict
from app.services.tjnetwork import get_scheme_schema, get_scheme, get_all_schemes
router = APIRouter()
@router.get("/getschemeschema/")
async def fastapi_get_scheme_schema(network: str) -> dict[str, dict[Any, Any]]:
@router.get("/getschemeschema/", summary="获取方案模式", description="获取指定网络的方案模式定义")
async def fastapi_get_scheme_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[Any, Any]]:
"""
获取方案模式定义
返回指定网络的方案模式结构定义
"""
return get_scheme_schema(network)
@router.get("/getscheme/")
async def fastapi_get_scheme(network: str, schema_name: str) -> dict[Any, Any]:
@router.get("/getscheme/", summary="获取单个方案", description="根据名称获取指定的方案信息")
async def fastapi_get_scheme(network: str = Query(..., description="管网名称(或数据库名称)"), schema_name: str = Query(..., description="方案名称")) -> dict[Any, Any]:
"""
获取单个方案详情
返回指定网络中指定名称的方案详细信息
"""
return get_scheme(network, schema_name)
@router.get("/getallschemes/")
async def fastapi_get_all_schemes(network: str) -> list[dict[Any, Any]]:
@router.get("/getallschemes/", summary="获取所有方案", description="获取指定网络的所有方案信息")
async def fastapi_get_all_schemes(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
"""
获取所有方案列表
返回指定网络中所有可用的方案
"""
return get_all_schemes(network)
+433 -289
View File
@@ -4,20 +4,17 @@ import json
import os
import shutil
import threading
import pandas as pd
from fastapi import APIRouter, HTTPException, File, UploadFile, Query
from fastapi import APIRouter, HTTPException, File, UploadFile, Query, Path, Body
from fastapi.responses import PlainTextResponse
import app.infra.db.influxdb.api as influxdb_api
import app.services.simulation as simulation
import app.services.globals as globals
from app.infra.cache.redis_client import redis_client
from app.services.tjnetwork import (
run_project,
run_project_return_dict,
run_inp,
dump_output,
)
from app.algorithms.simulations import (
from app.algorithms.simulation.scenarios import (
burst_analysis,
valve_close_analysis,
flushing_analysis,
@@ -26,12 +23,11 @@ from app.algorithms.simulations import (
# scheduling_analysis,
pressure_regulation,
)
from app.algorithms.sensors import (
from app.algorithms.sensor import (
pressure_sensor_placement_sensitivity,
pressure_sensor_placement_kmeans,
)
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
from app.services.network_import import network_update
from app.services.simulation_ops import (
project_management,
@@ -39,176 +35,180 @@ from app.services.simulation_ops import (
daily_scheduling_simulation,
)
from app.services.valve_isolation import analyze_valve_isolation
from pydantic import BaseModel
from app.services.time_api import parse_aware_time, parse_utc_time
from pydantic import BaseModel, Field, field_validator
router = APIRouter()
class RunSimulationManuallyByDate(BaseModel):
name: str
simulation_date: str
start_time: str
duration: int
name: str = Field(..., description="管网名称(或数据库名称)")
start_time: str = Field(..., description="开始时间 (ISO 8601 / RFC3339,必须显式带时区)")
duration: int = Field(..., gt=0, description="持续时间 (分钟)")
@field_validator("start_time")
@classmethod
def validate_start_time_timezone(cls, value: str) -> str:
parse_aware_time(value, field_name="start_time")
return value
class BurstAnalysis(BaseModel):
name: str
modify_pattern_start_time: str
burst_ID: List[str] | str | None = None
burst_size: List[float] | float | int | None = None
modify_total_duration: int = 900
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
modify_variable_pump_pattern: Optional[dict[str, list]] = None
modify_valve_opening: Optional[dict[str, float]] = None
scheme_name: Optional[str] = None
name: str = Field(..., description="管网名称(或数据库名称)")
modify_pattern_start_time: str = Field(..., description="模式修改开始时间 (ISO 8601)")
burst_ID: List[str] | str | None = Field(None, description="爆管节点/管段ID列表")
burst_size: List[float] | float | int | None = Field(None, description="爆管流量大小")
modify_total_duration: int = Field(900, description="模拟总时长 (秒)")
modify_fixed_pump_pattern: Optional[dict[str, list]] = Field(None, description="定速泵模式修改")
modify_variable_pump_pattern: Optional[dict[str, list]] = Field(None, description="变速泵模式修改")
modify_valve_opening: Optional[dict[str, float]] = Field(None, description="阀门开度修改")
scheme_name: Optional[str] = Field(None, description="方案名称")
class SchedulingAnalysis(BaseModel):
network: str
start_time: str
pump_control: dict
tank_id: str
water_plant_output_id: str
time_delta: Optional[int] = 300
network: str = Field(..., description="管网名称(或数据库名称)")
start_time: str = Field(..., description="开始时间")
pump_control: dict = Field(..., description="泵控制策略")
tank_id: str = Field(..., description="水箱ID")
water_plant_output_id: str = Field(..., description="水厂出水ID")
time_delta: Optional[int] = Field(300, description="时间步长 (秒)")
class PressureRegulation(BaseModel):
network: str
start_time: str
pump_control: dict
tank_init_level: Optional[dict] = None
duration: Optional[int] = 900
scheme_name: Optional[str] = None
network: str = Field(..., description="管网名称(或数据库名称)")
start_time: str = Field(..., description="开始时间")
pump_control: dict = Field(..., description="泵控制策略")
tank_init_level: Optional[dict] = Field(None, description="水箱初始水位")
duration: Optional[int] = Field(900, description="持续时间 (秒)")
scheme_name: Optional[str] = Field(None, description="方案名称")
class ProjectManagement(BaseModel):
network: str
start_time: str
pump_control: dict
tank_init_level: Optional[dict] = None
region_demand: Optional[dict] = None
network: str = Field(..., description="管网名称(或数据库名称)")
start_time: str = Field(..., description="开始时间")
pump_control: dict = Field(..., description="泵控制策略")
tank_init_level: Optional[dict] = Field(None, description="水箱初始水位")
region_demand: Optional[dict] = Field(None, description="区域需水量控制")
class DailySchedulingAnalysis(BaseModel):
network: str
start_time: str
pump_control: dict
reservoir_id: str
tank_id: str
water_plant_output_id: str
time_delta: Optional[int] = 300
network: str = Field(..., description="管网名称(或数据库名称)")
start_time: str = Field(..., description="开始时间")
pump_control: dict = Field(..., description="泵控制策略")
reservoir_id: str = Field(..., description="水库ID")
tank_id: str = Field(..., description="水箱ID")
water_plant_output_id: str = Field(..., description="水厂出水ID")
time_delta: Optional[int] = Field(300, description="时间步长 (秒)")
class PumpFailureState(BaseModel):
time: str
pump_status: dict
time: str = Field(..., description="故障发生时间")
pump_status: dict = Field(..., description="泵状态字典")
class PressureSensorPlacement(BaseModel):
name: str
scheme_name: str
sensor_number: int
min_diameter: int = 0
username: str
name: str = Field(..., description="管网名称(或数据库名称)")
scheme_name: str = Field(..., description="方案名称")
sensor_number: int = Field(..., description="传感器数量")
min_diameter: int = Field(0, description="最小管径限制")
username: str = Field(..., description="用户名")
def run_simulation_manually_by_date(
network_name: str, base_date: datetime, start_time: str, duration: int
network_name: str, start_time: datetime, duration: int
) -> None:
time_parts = list(map(int, start_time.split(":")))
if len(time_parts) == 2:
start_hour, start_minute = time_parts
start_second = 0
elif len(time_parts) == 3:
start_hour, start_minute, start_second = time_parts
else:
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
start_datetime = base_date.replace(
hour=start_hour, minute=start_minute, second=start_second
)
end_datetime = start_datetime + timedelta(minutes=duration)
current_time = start_datetime
end_datetime = start_time + timedelta(minutes=duration)
current_time = start_time
while current_time < end_datetime:
iso_time = current_time.strftime("%Y-%m-%dT%H:%M:%S") + "+08:00"
simulation.run_simulation(
name=network_name,
simulation_type="realtime",
modify_pattern_start_time=iso_time,
modify_pattern_start_time=current_time.isoformat(timespec="seconds"),
)
current_time += timedelta(minutes=15)
# 必须用这个PlainTextResponse,不然每个key都有引号
@router.get("/runproject/", response_class=PlainTextResponse)
async def run_project_endpoint(network: str) -> str:
lock_key = "exclusive_api_lock"
timeout = 120 # 锁自动过期时间(秒)
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
if not acquired:
raise HTTPException(status_code=409, detail="is in simulation")
else:
try:
return run_project(network)
finally:
# 手动释放锁(可选,依赖过期时间自动释放更安全)
redis_client.delete(lock_key)
@router.get("/runproject/", response_class=PlainTextResponse, summary="运行项目模拟", description="基于指定的管网项目运行标准水力模拟,返回纯文本格式的模拟报告。")
async def run_project_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> str:
"""
运行项目模拟
- **network**: 管网名称(或数据库名称)
运行指定管网项目的标准水力模拟并返回文本报告。
"""
return run_project(network)
# DingZQ, 2025-02-04, 返回dict[str, Any]
# output 和 report
# output 是 json
# report 是 text
@router.get("/runprojectreturndict/")
async def run_project_return_dict_endpoint(network: str) -> dict[str, Any]:
lock_key = "exclusive_api_lock"
timeout = 120 # 锁自动过期时间(秒)
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
if not acquired:
raise HTTPException(status_code=409, detail="is in simulation")
else:
try:
return run_project_return_dict(network)
finally:
# 手动释放锁(可选,依赖过期时间自动释放更安全)
redis_client.delete(lock_key)
@router.get("/runprojectreturndict/", summary="运行项目模拟(返回字典)", description="基于指定的管网项目运行标准水力模拟,返回JSON格式的字典,包含输出数据和报告文本。")
async def run_project_return_dict_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
"""
运行项目模拟(返回字典)
- **network**: 管网名称(或数据库名称)
返回字典包含:
- output: JSON格式的模拟输出数据
- report: 文本格式的模拟报告
运行指定管网项目的标准水力模拟并返回字典结果。
"""
return run_project_return_dict(network)
# put in inp folder, name without extension
@router.get("/runinp/")
async def run_inp_endpoint(network: str) -> str:
@router.get("/runinp/", summary="运行INP文件", description="运行指定INP文件格式的管网模型进行水力模拟。INP文件应该放在inp文件夹中,参数为文件名不含扩展名。")
async def run_inp_endpoint(network: str = Query(..., description="inp文件名(不含扩展名)")) -> str:
"""
运行INP文件
- **network**: inp文件名(不含扩展名)
从inp文件夹中读取指定的INP文件并运行模拟。
"""
return run_inp(network)
# path is absolute path
@router.get("/dumpoutput/")
async def dump_output_endpoint(output: str) -> str:
@router.get("/dumpoutput/", summary="导出模拟输出", description="导出指定路径的模拟输出文件内容。参数应为绝对路径。")
async def dump_output_endpoint(output: str = Query(..., description="模拟输出文件的绝对路径")) -> str:
"""
导出模拟输出
- **output**: 模拟输出文件的绝对路径
读取并返回指定路径的模拟输出内容。
"""
return dump_output(output)
# Analysis Endpoints
@router.get("/burstanalysis/")
async def burst_analysis_endpoint(
network: str, pipe_id: str, start_time: str, end_time: str, burst_flow: float
):
return burst_analysis(network, pipe_id, start_time, end_time, burst_flow)
@router.get("/burst_analysis/")
@router.get("/burst_analysis/", summary="爆管分析(高级)", description="高级版本的爆管分析,支持在指定时间点修改泵控制模式和阀门开度,以分析这些改变对爆管影响的作用。支持固定泵和变速泵的独立控制。")
async def fastapi_burst_analysis(
network: str = Query(...),
modify_pattern_start_time: str = Query(...),
burst_ID: list[str] = Query(...),
burst_size: list[float] = Query(...),
modify_total_duration: int = Query(...),
scheme_name: str = Query(...),
network: str = Query(..., description="管网名称(或数据库名称)"),
modify_pattern_start_time: str = Query(..., description="模式修改开始时间(ISO 8601格式)"),
burst_ID: list[str] = Query(..., description="爆管节点/管段ID列表"),
burst_size: list[float] = Query(..., description="对应各爆管点的爆管流量大小列表(L/s)"),
modify_total_duration: int = Query(..., description="模拟总时长(秒)"),
scheme_name: str = Query(..., description="分析方案名称"),
) -> str:
"""
爆管分析(高级版本)
- **network**: 管网名称(或数据库名称)
- **modify_pattern_start_time**: 模式修改开始时间
- **burst_ID**: 爆管节点/管段ID列表
- **burst_size**: 爆管流量大小列表(与burst_ID对应)
- **modify_total_duration**: 模拟总时长(秒)
- **scheme_name**: 分析方案名称
支持在指定时间修改泵控制模式和阀门开度。
"""
burst_analysis(
name=network,
modify_pattern_start_time=modify_pattern_start_time,
@@ -220,74 +220,100 @@ async def fastapi_burst_analysis(
return "success"
@router.get("/valvecloseanalysis/")
async def valve_close_analysis_endpoint(
network: str, valve_id: str, start_time: str, end_time: str
):
return valve_close_analysis(network, valve_id, start_time, end_time)
@router.get("/valve_close_analysis/", response_class=PlainTextResponse)
@router.get("/valve_close_analysis/", response_class=PlainTextResponse, summary="阀门关闭分析(高级)", description="高级版本的阀门关闭分析,支持同时关闭多个阀门,并在指定持续时间内进行模拟。返回纯文本格式的分析结果。")
async def fastapi_valve_close_analysis(
network: str,
start_time: str,
valves: List[str] = Query(...),
duration: int | None = None,
network: str = Query(..., description="管网名称(或数据库名称)"),
start_time: str = Query(..., description="阀门关闭开始时间(ISO 8601格式)"),
valves: List[str] = Query(..., description="要关闭的阀门ID列表"),
duration: int | None = Query(None, description="模拟持续时间(秒),默认900秒"),
scheme_name: str = Query(..., description="阀门关闭方案名称"),
) -> str:
"""
阀门关闭分析(高级版本)
- **network**: 管网名称(或数据库名称)
- **start_time**: 阀门关闭开始时间
- **valves**: 要关闭的阀门ID列表
- **duration**: 模拟持续时间(秒,可选,默认900)
- **scheme_name**: 阀门关闭方案名称
支持同时关闭多个阀门进行分析。
"""
result = valve_close_analysis(
name=network,
modify_pattern_start_time=start_time,
modify_total_duration=duration or 900,
modify_valve_opening={valve_id: 0.0 for valve_id in valves},
scheme_name=scheme_name,
)
return result or "success"
@router.get("/valve_isolation_analysis/")
@router.get("/valve_isolation_analysis/", summary="阀门隔离分析", description="分析当发生突发事件时,通过关闭指定阀门进行隔离,确定哪些阀门必须关闭、哪些可选关闭,以及隔离的可行性。")
async def valve_isolation_endpoint(
network: str,
accident_element: List[str] = Query(...),
disabled_valves: List[str] = Query(None),
network: str = Query(..., description="管网名称(或数据库名称)"),
accident_element: List[str] = Query(..., description="发生事故的管段/节点ID列表"),
disabled_valves: List[str] = Query(None, description="已故障的阀门ID列表(可选)"),
):
result = {
"accident_element": "P461309",
"accident_elements": ["P461309"],
"affected_nodes": [
"J316629_A",
"J317037_B",
"J317060_B",
"J408189_B",
"J499996",
"J524940",
"J535933",
"J58841",
],
"isolatable": True,
"must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
"optional_valves": [],
}
"""
阀门隔离分析
- **network**: 管网名称(或数据库名称)
- **accident_element**: 发生事故的管段/节点ID列表
- **disabled_valves**: 已故障的阀门ID列表(可选)
返回隔离方案,包括:
- must_close_valves: 必须关闭的阀门列表
- optional_valves: 可选关闭的阀门列表
- affected_nodes: 受影响的节点列表
- isolatable: 是否可以有效隔离
"""
# result = {
# "accident_element": "P461309",
# "accident_elements": ["P461309"],
# "affected_nodes": [
# "J316629_A",
# "J317037_B",
# "J317060_B",
# "J408189_B",
# "J499996",
# "J524940",
# "J535933",
# "J58841",
# ],
# "isolatable": True,
# "must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
# "optional_valves": [],
# }
result = analyze_valve_isolation(network, accident_element, disabled_valves)
return result
@router.get("/flushinganalysis/")
async def flushing_analysis_endpoint(
network: str, pipe_id: str, start_time: str, duration: float, flow: float
):
return flushing_analysis(network, pipe_id, start_time, duration, flow)
@router.get("/flushing_analysis/", response_class=PlainTextResponse)
@router.get("/flushing_analysis/", response_class=PlainTextResponse, summary="冲洗分析(高级)", description="高级版本的冲洗分析,支持同时开启多个阀门进行冲洗,指定排污节点,并设置固定的冲洗流量。返回纯文本格式的分析结果。")
async def fastapi_flushing_analysis(
network: str,
start_time: str,
valves: List[str] = Query(...),
valves_k: List[float] = Query(...),
drainage_node_ID: str = Query(...),
flush_flow: float = 0,
duration: int | None = None,
scheme_name: str | None = None,
network: str = Query(..., description="管网名称(或数据库名称)"),
start_time: str = Query(..., description="冲洗开始时间(ISO 8601格式)"),
valves: List[str] = Query(..., description="要开启的阀门ID列表"),
valves_k: List[float] = Query(..., description="对应各阀门的开度列表(0-1"),
drainage_node_ID: str = Query(..., description="排污节点ID"),
flush_flow: float = Query(0, description="冲洗流量(L/s),0表示自动计算"),
duration: int | None = Query(None, description="模拟持续时间(秒),默认900秒"),
scheme_name: str = Query(..., description="冲洗方案名称"),
) -> str:
"""
冲洗分析(高级版本)
- **network**: 管网名称(或数据库名称)
- **start_time**: 冲洗开始时间
- **valves**: 要开启的阀门ID列表
- **valves_k**: 各阀门的开度列表(0-1,与valves对应)
- **drainage_node_ID**: 排污节点ID
- **flush_flow**: 冲洗流量(L/s
- **duration**: 模拟持续时间(秒,可选,默认900)
- **scheme_name**: 冲洗方案名称
支持多阀联合冲洗操作。
"""
valve_opening = {
valve_id: float(valves_k[idx]) for idx, valve_id in enumerate(valves)
}
@@ -303,16 +329,29 @@ async def fastapi_flushing_analysis(
return result or "success"
@router.get("/contaminant_simulation/", response_class=PlainTextResponse)
@router.get("/contaminant_simulation/", response_class=PlainTextResponse, summary="污染物模拟", description="对管网中的污染物扩散进行模拟,评估污染源对管网的影响范围和浓度分布。支持指定污染源位置、污染浓度和扩散模式。")
async def fastapi_contaminant_simulation(
network: str,
start_time: str,
source: str,
concentration: float,
duration: int,
scheme_name: str | None = None,
pattern: str | None = None,
network: str = Query(..., description="管网名称(或数据库名称)"),
start_time: str = Query(..., description="污染开始时间(ISO 8601格式)"),
source: str = Query(..., description="污染源节点ID"),
concentration: float = Query(..., description="污染浓度(mg/L"),
duration: int = Query(..., description="模拟持续时间(秒)"),
scheme_name: str = Query(..., description="模拟方案名称"),
pattern: str | None = Query(None, description="污染源模式ID(可选)"),
) -> str:
"""
污染物模拟
- **network**: 管网名称(或数据库名称)
- **start_time**: 污染开始时间
- **source**: 污染源节点ID
- **concentration**: 污染浓度(mg/L
- **duration**: 模拟持续时间(秒)
- **scheme_name**: 模拟方案名称
- **pattern**: 污染源模式ID(可选)
用于评估管网中污染物的传播和影响范围。
"""
result = contaminant_simulation(
name=network,
modify_pattern_start_time=start_time,
@@ -325,15 +364,21 @@ async def fastapi_contaminant_simulation(
return result or "success"
@router.get("/ageanalysis/")
async def age_analysis_endpoint(network: str):
return age_analysis(network)
@router.get("/age_analysis/", response_class=PlainTextResponse)
@router.get("/age_analysis/", response_class=PlainTextResponse, summary="水龄分析(高级)", description="高级版本的水龄分析,在指定时间点进行分析,支持自定义模拟持续时间。返回纯文本格式的分析结果。")
async def fastapi_age_analysis(
network: str, start_time: str, end_time: str, duration: int
network: str = Query(..., description="管网名称(或数据库名称)"),
start_time: str = Query(..., description="分析开始时间(ISO 8601格式)"),
duration: int = Query(..., description="模拟持续时间(秒)"),
) -> str:
"""
水龄分析(高级版本)
- **network**: 管网名称(或数据库名称)
- **start_time**: 分析开始时间
- **duration**: 模拟持续时间(秒)
分析指定时间段内管网中各节点的水体停留时间。
"""
result = age_analysis(network, start_time, duration)
return result or "success"
@@ -343,15 +388,39 @@ async def fastapi_age_analysis(
# return scheduling_analysis(network)
@router.get("/pressureregulation/")
@router.get("/pressureregulation/", summary="压力调节(基础)", description="对管网的压力进行调节分析,通过控制泵的运行来维持目标节点的目标压力。此为基础版本。")
async def pressure_regulation_endpoint(
network: str, target_node: str, target_pressure: float
network: str = Query(..., description="管网名称(或数据库名称)"),
target_node: str = Query(..., description="目标节点ID"),
target_pressure: float = Query(..., description="目标压力值(kPa"),
):
"""
压力调节(基础版本)
- **network**: 管网名称(或数据库名称)
- **target_node**: 目标节点ID
- **target_pressure**: 目标压力值(kPa
通过泵控制维持目标节点的压力。
"""
return pressure_regulation(network, target_node, target_pressure)
@router.post("/pressure_regulation/")
async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
@router.post("/pressure_regulation/", summary="压力调节(高级)", description="高级版本的压力调节分析,通过JSON请求体提供详细的控制参数,包括固定泵和变速泵的独立控制、水箱初始水位等。")
async def fastapi_pressure_regulation(data: PressureRegulation = Body(..., description="压力调节控制参数")) -> str:
"""
压力调节(高级版本)
请求体参数:
- **network**: 管网名称(或数据库名称)
- **start_time**: 控制开始时间
- **pump_control**: 泵控制策略字典
- **tank_init_level**: 水箱初始水位字典(可选)
- **duration**: 模拟持续时间(秒,可选,默认900)
- **scheme_name**: 控制方案名称(可选)
支持固定泵和变速泵的独立控制。
"""
item = data.dict()
simulation.query_corresponding_element_id_and_query_id(item["network"])
fixed_pumps = set(globals.fixed_pumps_id.keys())
@@ -375,13 +444,20 @@ async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
return "success"
@router.get("/projectmanagement/")
async def project_management_endpoint(network: str):
return project_management(network)
@router.post("/project_management/")
async def fastapi_project_management(data: ProjectManagement) -> str:
@router.post("/project_management/", summary="项目管理(高级)", description="高级版本的项目管理,通过JSON请求体提供详细的控制参数,包括泵控制策略、水箱初始水位和区域需水量控制。")
async def fastapi_project_management(data: ProjectManagement = Body(..., description="项目管理控制参数")) -> str:
"""
项目管理(高级版本)
请求体参数:
- **network**: 管网名称(或数据库名称)
- **start_time**: 管理开始时间
- **pump_control**: 泵控制策略字典
- **tank_init_level**: 水箱初始水位字典(可选)
- **region_demand**: 区域需水量控制字典(可选)
支持多维度的项目管理。
"""
item = data.dict()
return project_management(
prj_name=item["network"],
@@ -397,8 +473,21 @@ async def fastapi_project_management(data: ProjectManagement) -> str:
# return daily_scheduling_analysis(network)
@router.post("/scheduling_analysis/")
async def fastapi_scheduling_analysis(data: SchedulingAnalysis) -> str:
@router.post("/scheduling_analysis/", summary="排程分析", description="对管网的供水排程进行分析,优化泵的运行时间和出水流量,平衡水厂出水、水箱进出水,满足用户需求。")
async def fastapi_scheduling_analysis(data: SchedulingAnalysis = Body(..., description="排程分析参数")) -> str:
"""
排程分析
请求体参数:
- **network**: 管网名称(或数据库名称)
- **start_time**: 分析开始时间
- **pump_control**: 泵控制策略字典
- **tank_id**: 水箱ID
- **water_plant_output_id**: 水厂出水ID
- **time_delta**: 时间步长(秒,可选,默认300)
用于优化供水排程。
"""
item = data.dict()
return scheduling_simulation(
item["network"],
@@ -410,8 +499,22 @@ async def fastapi_scheduling_analysis(data: SchedulingAnalysis) -> str:
)
@router.post("/daily_scheduling_analysis/")
async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis) -> str:
@router.post("/daily_scheduling_analysis/", summary="日排程分析", description="对管网的每日供水排程进行分析,优化水库、水厂、水箱和用户需求的协调,制定合理的每日排程方案。")
async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis = Body(..., description="日排程分析参数")) -> str:
"""
日排程分析
请求体参数:
- **network**: 管网名称(或数据库名称)
- **start_time**: 分析开始时间
- **pump_control**: 泵控制策略字典
- **reservoir_id**: 水库ID
- **tank_id**: 水箱ID
- **water_plant_output_id**: 水厂出水ID
- **time_delta**: 时间步长(秒,可选,默认300)
用于制定每日供水排程方案。
"""
item = data.dict()
return daily_scheduling_simulation(
item["network"],
@@ -423,8 +526,15 @@ async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis) -> st
)
@router.post("/network_project/")
async def fastapi_network_project(file: UploadFile = File()) -> str:
@router.post("/network_project/", summary="导入网络项目", description="通过上传INP格式的管网文件导入新的网络项目。系统将自动处理文件并执行模拟。")
async def fastapi_network_project(file: UploadFile = File(..., description="INP格式的管网文件")) -> str:
"""
导入网络项目
- **file**: 上传的INP格式管网文件
系统将上传的文件保存到inp文件夹并执行模拟。
"""
temp_file_dir = "./inp/"
if not os.path.exists(temp_file_dir):
os.mkdir(temp_file_dir)
@@ -435,13 +545,15 @@ async def fastapi_network_project(file: UploadFile = File()) -> str:
return run_inp(temp_file_name)
@router.get("/networkupdate/")
async def network_update_endpoint(network: str):
return network_update(network)
@router.post("/network_update/")
async def fastapi_network_update(file: UploadFile = File()) -> str:
@router.post("/network_update/", summary="管网更新(高级)", description="通过上传更新文件对管网进行高级的更新操作。系统将处理更新文件并应用到数据库。")
async def fastapi_network_update(file: UploadFile = File(..., description="包含管网更新信息的文件")) -> str:
"""
管网更新(高级版本)
- **file**: 包含管网更新信息的文件
系统将处理上传的文件并应用管网更新。
"""
default_folder = "./"
temp_file_name = f'network_update_{datetime.now().strftime("%Y%m%d")}'
temp_file_path = os.path.join(default_folder, temp_file_name)
@@ -459,8 +571,17 @@ async def fastapi_network_update(file: UploadFile = File()) -> str:
# return pump_failure(network, pump_id, time)
@router.post("/pump_failure/")
async def fastapi_pump_failure(data: PumpFailureState) -> str:
@router.post("/pump_failure/", summary="泵故障管理", description="记录和管理泵的故障状态,包括故障发生时间和受影响的泵列表。系统将记录故障日志并更新泵状态。")
async def fastapi_pump_failure(data: PumpFailureState = Body(..., description="泵故障状态信息")) -> str:
"""
泵故障管理
请求体参数:
- **time**: 故障发生时间
- **pump_status**: 泵状态字典,包含第一阶段和第二阶段泵的故障状态
系统将验证泵信息的有效性并更新故障状态文件。
"""
item = data.dict()
with open("./pump_failure_message.txt", "a", encoding="utf-8-sig") as f1:
f1.write("[{}] {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), item))
@@ -494,19 +615,46 @@ async def fastapi_pump_failure(data: PumpFailureState) -> str:
return json.dumps("SUCCESS")
@router.get("/pressuresensorplacementsensitivity/")
@router.get("/pressuresensorplacementsensitivity/", summary="压力传感器放置-灵敏度分析(基础)", description="基于灵敏度分析方法,为指定管网项目确定最优的压力传感器放置位置。此为基础版本。")
async def pressure_sensor_placement_sensitivity_endpoint(
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
name: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Query(..., description="放置方案名称"),
sensor_number: int = Query(..., description="传感器数量"),
min_diameter: int = Query(..., description="最小管径限制(毫米)"),
username: str = Query(..., description="用户名"),
):
"""
压力传感器放置-灵敏度分析(基础版本)
- **name**: 管网名称(或数据库名称)
- **scheme_name**: 放置方案名称
- **sensor_number**: 传感器数量
- **min_diameter**: 最小管径限制(毫米)
- **username**: 用户名
基于灵敏度分析方法确定传感器放置位置。
"""
return pressure_sensor_placement_sensitivity(
name, scheme_name, sensor_number, min_diameter, username
)
@router.post("/pressure_sensor_placement_sensitivity/")
@router.post("/pressure_sensor_placement_sensitivity/", summary="压力传感器放置-灵敏度分析(高级)", description="高级版本的压力传感器放置分析,通过JSON请求体提供详细参数。基于灵敏度分析方法确定最优放置位置。")
async def fastapi_pressure_sensor_placement_sensitivity(
data: PressureSensorPlacement,
data: PressureSensorPlacement = Body(..., description="传感器放置分析参数"),
) -> None:
"""
压力传感器放置-灵敏度分析(高级版本)
请求体参数:
- **name**: 管网名称(或数据库名称)
- **scheme_name**: 放置方案名称
- **sensor_number**: 传感器数量
- **min_diameter**: 最小管径限制(毫米)
- **username**: 用户名
基于灵敏度分析方法确定压力传感器的最优放置位置。
"""
item = data.dict()
pressure_sensor_placement_sensitivity(
name=item["name"],
@@ -517,19 +665,46 @@ async def fastapi_pressure_sensor_placement_sensitivity(
)
@router.get("/pressuresensorplacementkmeans/")
@router.get("/pressuresensorplacementkmeans/", summary="压力传感器放置-KMeans聚类分析(基础)", description="基于KMeans聚类算法,为指定管网项目确定压力传感器的最优放置位置。此为基础版本。")
async def pressure_sensor_placement_kmeans_endpoint(
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
name: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Query(..., description="放置方案名称"),
sensor_number: int = Query(..., description="传感器数量"),
min_diameter: int = Query(..., description="最小管径限制(毫米)"),
username: str = Query(..., description="用户名"),
):
"""
压力传感器放置-KMeans聚类分析(基础版本)
- **name**: 管网名称(或数据库名称)
- **scheme_name**: 放置方案名称
- **sensor_number**: 传感器数量
- **min_diameter**: 最小管径限制(毫米)
- **username**: 用户名
基于KMeans聚类算法确定传感器放置位置。
"""
return pressure_sensor_placement_kmeans(
name, scheme_name, sensor_number, min_diameter, username
)
@router.post("/pressure_sensor_placement_kmeans/")
@router.post("/pressure_sensor_placement_kmeans/", summary="压力传感器放置-KMeans聚类分析(高级)", description="高级版本的压力传感器放置分析,通过JSON请求体提供详细参数。基于KMeans聚类算法确定最优放置位置。")
async def fastapi_pressure_sensor_placement_kmeans(
data: PressureSensorPlacement,
data: PressureSensorPlacement = Body(..., description="传感器放置分析参数"),
) -> None:
"""
压力传感器放置-KMeans聚类分析(高级版本)
请求体参数:
- **name**: 管网名称(或数据库名称)
- **scheme_name**: 放置方案名称
- **sensor_number**: 传感器数量
- **min_diameter**: 最小管径限制(毫米)
- **username**: 用户名
基于KMeans聚类算法确定压力传感器的最优放置位置。
"""
item = data.dict()
pressure_sensor_placement_kmeans(
name=item["name"],
@@ -540,16 +715,31 @@ async def fastapi_pressure_sensor_placement_kmeans(
)
@router.post("/sensorplacementscheme/create")
@router.post("/sensorplacementscheme/create", summary="传感器放置方案创建", description="创建新的传感器放置方案,支持灵敏度分析和KMeans聚类两种方法。根据指定的方法自动计算最优的传感器放置位置。")
async def fastapi_pressure_sensor_placement(
network: str = Query(...),
scheme_name: str = Query(...),
sensor_type: str = Query(...),
method: str = Query(...),
sensor_count: int = Query(...),
min_diameter: int = Query(0),
user_name: str = Query(...),
network: str = Query(..., description="管网名称(或数据库名称)"),
scheme_name: str = Query(..., description="放置方案名称"),
sensor_type: str = Query(..., description="传感器类型"),
method: str = Query(..., description="放置方法('sensitivity''kmeans'"),
sensor_count: int = Query(..., description="传感器数量"),
min_diameter: int = Query(0, description="最小管径限制(毫米),默认0"),
user_name: str = Query(..., description="用户名"),
) -> str:
"""
传感器放置方案创建
- **network**: 管网名称(或数据库名称)
- **scheme_name**: 放置方案名称
- **sensor_type**: 传感器类型
- **method**: 放置方法('sensitivity''kmeans'
- **sensor_count**: 传感器数量
- **min_diameter**: 最小管径限制(毫米,默认0)
- **user_name**: 用户名
支持两种放置方法:
- sensitivity: 基于灵敏度分析
- kmeans: 基于KMeans聚类
"""
if method not in ["sensitivity", "kmeans"]:
raise HTTPException(
status_code=400, detail="Invalid method. Must be 'sensitivity' or 'kmeans'"
@@ -573,68 +763,22 @@ async def fastapi_pressure_sensor_placement(
return "success"
@router.post("/scadadevicedatacleaning/")
async def fastapi_scada_device_data_cleaning(
network: str = Query(...),
ids_list: List[str] = Query(...),
start_time: str = Query(...),
end_time: str = Query(...),
user_name: str = Query(...),
) -> str:
item = {
"network": network,
"ids": ids_list,
"start_time": start_time,
"end_time": end_time,
"user_name": user_name,
}
query_ids_list = item["ids"][0].split(",")
scada_data = influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids_list,
start_time=item["start_time"],
end_time=item["end_time"],
)
scada_device_info = influxdb_api.query_pg_scada_info(item["network"])
scada_device_info_dict = {info["id"]: info for info in scada_device_info}
type_groups: dict[str, list[str]] = {}
for device_id in query_ids_list:
device_info = scada_device_info_dict.get(device_id, {})
device_type = device_info.get("type", "unknown")
type_groups.setdefault(device_type, []).append(device_id)
for device_type, device_ids in type_groups.items():
if device_type not in ["pressure", "pipe_flow"]:
continue
type_scada_data = {
device_id: scada_data[device_id]
for device_id in device_ids
if device_id in scada_data
}
if not type_scada_data:
continue
time_list = [record["time"] for record in next(iter(type_scada_data.values()))]
df = pd.DataFrame({"time": time_list})
for device_id in device_ids:
if device_id in type_scada_data:
values = [record["value"] for record in type_scada_data[device_id]]
df[device_id] = values
if device_type == "pressure":
cleaned_value_df = pressure_data_clean.clean_pressure_data_df_km(df)
elif device_type == "pipe_flow":
cleaned_value_df = flow_data_clean.clean_flow_data_df_kf(df)
cleaned_value_df = pd.DataFrame(cleaned_value_df)
cleaned_df = pd.concat([df["time"], cleaned_value_df], axis=1)
influxdb_api.import_multicolumn_data_from_dict(
data_dict=cleaned_df.to_dict("list"),
raw=False,
)
return "success"
@router.post("/runsimulationmanuallybydate/")
@router.post("/runsimulationmanuallybydate/", summary="手动运行日期指定模拟", description="根据指定的开始时间和持续时间,手动运行水力模拟。开始时间必须是显式带时区的 ISO 8601 / RFC3339 时间。")
async def fastapi_run_simulation_manually_by_date(
data: RunSimulationManuallyByDate,
data: RunSimulationManuallyByDate = Body(..., description="模拟运行参数"),
) -> dict[str, str]:
item = data.dict()
"""
手动运行日期指定模拟
请求体参数:
- **name**: 管网名称(或数据库名称)
- **start_time**: 开始时间(ISO 8601 / RFC3339,必须显式带时区)
- **duration**: 模拟持续时间(分钟)
系统将从指定时间开始,按15分钟间隔多次运行模拟。
每次模拟间隔15分钟,直至达到指定的总持续时间。
"""
item = data.model_dump()
try:
simulation.query_corresponding_element_id_and_query_id(item["name"])
simulation.query_corresponding_pattern_id_and_query_id(item["name"])
@@ -661,10 +805,10 @@ async def fastapi_run_simulation_manually_by_date(
globals.source_outflow_region_id,
globals.realtime_region_pipe_flow_and_demand_id,
)
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
start_time = parse_utc_time(item["start_time"], field_name="start_time")
run_simulation_manually_by_date(
item["name"], base_date, item["start_time"], item["duration"]
item["name"], start_time, item["duration"]
)
return {"status": "success"}
except Exception as exc:
return {"status": "error", "message": str(exc)}
raise HTTPException(status_code=500, detail=str(exc)) from exc
+133 -40
View File
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Request
from app.native.api import ChangeSet
from fastapi import APIRouter, Request, Query
from app.services.tjnetwork import (
ChangeSet,
get_current_operation,
execute_undo,
execute_redo,
@@ -22,90 +22,183 @@ from app.services.tjnetwork import (
router = APIRouter()
@router.get("/getcurrentoperationid/")
async def get_current_operation_id_endpoint(network: str) -> int:
@router.get("/getcurrentoperationid/", summary="获取当前操作ID", description="获取网络当前的操作ID")
async def get_current_operation_id_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> int:
"""
获取当前操作ID
返回网络当前正在执行的操作ID
"""
return get_current_operation(network)
@router.post("/undo/")
async def undo_endpoint(network: str):
@router.post("/undo/", summary="撤销操作", description="撤销网络上最后的一个操作")
async def undo_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")):
"""
撤销操作
撤销网络上最近执行的一个操作
"""
return execute_undo(network)
@router.post("/redo/")
async def redo_endpoint(network: str):
@router.post("/redo/", summary="重做操作", description="重做网络上被撤销的操作")
async def redo_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")):
"""
重做操作
重做网络上被撤销的操作
"""
return execute_redo(network)
@router.get("/getsnapshots/")
async def list_snapshot_endpoint(network: str) -> list[tuple[int, str]]:
@router.get("/getsnapshots/", summary="获取快照列表", description="获取网络中的所有快照")
async def list_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[tuple[int, str]]:
"""
获取快照列表
返回网络中所有可用的快照及其信息
"""
return list_snapshot(network)
@router.get("/havesnapshot/")
async def have_snapshot_endpoint(network: str, tag: str) -> bool:
@router.get("/havesnapshot/", summary="检查快照是否存在", description="检查指定标签的快照是否存在")
async def have_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> bool:
"""
检查快照是否存在
返回指定标签的快照是否存在
"""
return have_snapshot(network, tag)
@router.get("/havesnapshotforoperation/")
async def have_snapshot_for_operation_endpoint(network: str, operation: int) -> bool:
@router.get("/havesnapshotforoperation/", summary="检查操作快照是否存在", description="检查指定操作ID的快照是否存在")
async def have_snapshot_for_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="操作ID")) -> bool:
"""
检查操作快照是否存在
返回指定操作ID的快照是否存在
"""
return have_snapshot_for_operation(network, operation)
@router.get("/havesnapshotforcurrentoperation/")
async def have_snapshot_for_current_operation_endpoint(network: str) -> bool:
@router.get("/havesnapshotforcurrentoperation/", summary="检查当前操作快照是否存在", description="检查当前操作的快照是否存在")
async def have_snapshot_for_current_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> bool:
"""
检查当前操作快照是否存在
返回当前操作的快照是否存在
"""
return have_snapshot_for_current_operation(network)
@router.post("/takesnapshotforoperation/")
@router.post("/takesnapshotforoperation/", summary="为操作创建快照", description="为指定的操作创建快照")
async def take_snapshot_for_operation_endpoint(
network: str, operation: int, tag: str
network: str = Query(..., description="管网名称(或数据库名称)"),
operation: int = Query(..., description="操作ID"),
tag: str = Query(..., description="快照标签")
) -> None:
"""
为操作创建快照
为指定操作创建一个带标签的快照
"""
return take_snapshot_for_operation(network, operation, tag)
@router.post("/takesnapshotforcurrentoperation")
async def take_snapshot_for_current_operation_endpoint(network: str, tag: str) -> None:
@router.post("/takesnapshotforcurrentoperation", summary="为当前操作创建快照", description="为当前操作创建快照")
async def take_snapshot_for_current_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
"""
为当前操作创建快照
为网络当前操作创建一个快照
"""
return take_snapshot_for_current_operation(network, tag)
# 兼容旧拼写: takenapshotforcurrentoperation
@router.post("/takenapshotforcurrentoperation")
async def take_snapshot_for_current_operation_legacy_endpoint(
network: str, tag: str
) -> None:
@router.post("/takenapshotforcurrentoperation", summary="为当前操作创建快照(兼容模式)", description="为当前操作创建快照(兼容旧的API路径)")
async def take_snapshot_for_current_operation_legacy_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
"""
为当前操作创建快照(兼容模式)
兼容旧的API路径,为网络当前操作创建一个快照
"""
return take_snapshot_for_current_operation(network, tag)
@router.post("/takesnapshot/")
async def take_snapshot_endpoint(network: str, tag: str) -> None:
@router.post("/takesnapshot/", summary="创建快照", description="为网络创建一个快照")
async def take_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
"""
创建快照
为网络创建一个带标签的快照
"""
return take_snapshot(network, tag)
@router.post("/picksnapshot/", response_model=None)
async def pick_snapshot_endpoint(network: str, tag: str, discard: bool = False) -> ChangeSet:
@router.post("/picksnapshot/", summary="选择快照", description="选择并恢复到指定的快照", response_model=None)
async def pick_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签"), discard: bool = Query(False, description="是否丢弃当前更改")) -> ChangeSet:
"""
选择快照
选择并恢复到指定的快照
"""
return pick_snapshot(network, tag, discard)
@router.post("/pickoperation/", response_model=None)
@router.post("/pickoperation/", summary="选择操作", description="选择并恢复到指定的操作", response_model=None)
async def pick_operation_endpoint(
network: str, operation: int, discard: bool = False
network: str = Query(..., description="管网名称(或数据库名称)"),
operation: int = Query(..., description="操作ID"),
discard: bool = Query(False, description="是否丢弃当前更改")
) -> ChangeSet:
"""
选择操作
选择并恢复到指定的操作
"""
return pick_operation(network, operation, discard)
@router.get("/syncwithserver/", response_model=None)
async def sync_with_server_endpoint(network: str, operation: int) -> ChangeSet:
@router.get("/syncwithserver/", summary="与服务器同步", description="将网络与服务器同步到指定操作", response_model=None)
async def sync_with_server_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="目标操作ID")) -> ChangeSet:
"""
与服务器同步
将网络与服务器同步到指定的操作
"""
return sync_with_server(network, operation)
@router.post("/batch/", response_model=None)
async def execute_batch_commands_endpoint(network: str, req: Request) -> ChangeSet:
@router.post("/batch/", summary="执行批量命令", description="执行多个网络操作命令", response_model=None)
async def execute_batch_commands_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), req: Request = None) -> ChangeSet:
"""
执行批量命令
在网络上执行多个操作命令
"""
jo_root = await req.json()
cs: ChangeSet = ChangeSet()
cs.operations = jo_root["operations"]
rcs = execute_batch_commands(network, cs)
return rcs
@router.post("/compressedbatch/", response_model=None)
@router.post("/compressedbatch/", summary="执行压缩批量命令", description="执行压缩的批量命令", response_model=None)
async def execute_compressed_batch_commands_endpoint(
network: str, req: Request
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> ChangeSet:
"""
执行压缩批量命令
执行压缩格式的批量命令
"""
jo_root = await req.json()
cs: ChangeSet = ChangeSet()
cs.operations = jo_root["operations"]
return execute_batch_command(network, cs)
@router.get("/getrestoreoperation/")
async def get_restore_operation_endpoint(network: str) -> int:
@router.get("/getrestoreoperation/", summary="获取恢复操作ID", description="获取网络的恢复操作ID")
async def get_restore_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> int:
"""
获取恢复操作ID
返回网络的恢复操作ID
"""
return get_restore_operation(network)
@router.post("/setrestoreoperation/")
async def set_restore_operation_endpoint(network: str, operation: int) -> None:
@router.post("/setrestoreoperation/", summary="设置恢复操作ID", description="设置网络的恢复操作ID")
async def set_restore_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="操作ID")) -> None:
"""
设置恢复操作ID
设置网络的恢复操作ID
"""
return set_restore_operation(network, operation)
@@ -0,0 +1,263 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from datetime import datetime
from psycopg import AsyncConnection
from app.infra.db.timescaledb.composite_queries import CompositeQueries
from .dependencies import get_timescale_connection, get_postgres_connection
router = APIRouter()
@router.get("/composite/scada-simulation", summary="获取SCADA关联的模拟数据")
async def get_scada_associated_simulation_data(
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
device_ids: str = Query(..., description="SCADA设备ID列表,逗号分隔"),
scheme_type: str = Query(None, description="方案类型,若为空则查询实时数据"),
scheme_name: str = Query(None, description="方案名称,若为空则查询实时数据"),
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
获取SCADA关联的link/node模拟值
根据传入的SCADA device_ids,找到关联的link/node
并根据对应的type,查询对应的模拟数据。支持查询实时或方案数据。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
device_ids: SCADA设备ID列表,用逗号分隔
scheme_type: 方案类型,若为空则查询实时数据
scheme_name: 方案名称,若为空则查询实时数据
timescale_conn: TimescaleDB连接
postgres_conn: PostgreSQL连接
Returns:
SCADA关联的模拟数据
Raises:
HTTPException: 当查询参数无效时返回400错误,未找到数据时返回404错误
"""
try:
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
if scheme_type and scheme_name:
result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
timescale_conn,
postgres_conn,
device_ids_list,
start_time,
end_time,
scheme_type,
scheme_name,
)
else:
result = (
await CompositeQueries.get_scada_associated_realtime_simulation_data(
timescale_conn,
postgres_conn,
device_ids_list,
start_time,
end_time,
)
)
if result is None:
raise HTTPException(status_code=404, detail="No simulation data found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-simulation", summary="获取管网元素的模拟数据")
async def get_feature_simulation_data(
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
feature_infos: str = Query(
..., description="特征信息,格式: id1:type1,id2:type2type为pipe(管道)或junction(节点)"
),
scheme_type: str = Query(None, description="方案类型,若为空则查询实时数据"),
scheme_name: str = Query(None, description="方案名称,若为空则查询实时数据"),
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
获取link/node模拟值
根据传入的featureInfos,找到关联的link/node
并根据对应的type,查询对应的模拟数据。支持查询实时或方案数据。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
feature_infos: 格式为 "element_id1:type1,element_id2:type2"
例如: "P1:pipe,J1:junction"
scheme_type: 方案类型,若为空则查询实时数据
scheme_name: 方案名称,若为空则查询实时数据
timescale_conn: TimescaleDB连接
Returns:
管网元素的模拟数据
Raises:
HTTPException: 当feature_infos为空返回400错误,未找到数据返回404错误,其他错误返回400错误
"""
try:
feature_infos_list = []
if feature_infos:
for item in feature_infos.split(","):
item = item.strip()
if ":" in item:
element_id, element_type = item.split(":", 1)
feature_infos_list.append(
(element_id.strip(), element_type.strip())
)
if not feature_infos_list:
raise HTTPException(status_code=400, detail="feature_infos cannot be empty")
if scheme_type and scheme_name:
result = await CompositeQueries.get_scheme_simulation_data(
timescale_conn,
feature_infos_list,
start_time,
end_time,
scheme_type,
scheme_name,
)
else:
result = await CompositeQueries.get_realtime_simulation_data(
timescale_conn,
feature_infos_list,
start_time,
end_time,
)
if result is None:
raise HTTPException(status_code=404, detail="No simulation data found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-scada", summary="获取管网元素关联的SCADA监测数据")
async def get_element_associated_scada_data(
element_id: str = Query(..., description="管网元素ID(管道或节点)"),
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
获取link/node关联的SCADA监测值
根据传入的link/node id,匹配SCADA信息,
如果存在关联的SCADA device_id,获取实际的监测数据。
Args:
element_id: 管网元素ID
start_time: 查询开始时间
end_time: 查询结束时间
use_cleaned: 是否使用清洗后的数据,默认为False使用原始数据
timescale_conn: TimescaleDB连接
postgres_conn: PostgreSQL连接
Returns:
管网元素关联的SCADA监测数据
Raises:
HTTPException: 当查询参数无效时返回400错误,未找到关联数据返回404错误
"""
try:
result = await CompositeQueries.get_element_associated_scada_data(
timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
)
if result is None:
raise HTTPException(
status_code=404, detail="No associated SCADA data found"
)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/composite/clean-scada", summary="清洗SCADA监测数据")
async def clean_scada_data(
device_ids: str = Query(..., description="设备ID列表或 'all' 表示清洗所有设备"),
start_time: datetime = Query(..., description="清洗数据的开始时间"),
end_time: datetime = Query(..., description="清洗数据的结束时间"),
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
清洗SCADA监测数据
根据device_ids查询monitored_value,清洗后更新cleaned_value。
支持清洗指定设备或所有设备的数据。
Args:
device_ids: 设备ID列表,用逗号分隔,或 'all' 表示清洗所有设备
start_time: 清洗数据的开始时间
end_time: 清洗数据的结束时间
timescale_conn: TimescaleDB连接
postgres_conn: PostgreSQL连接
Returns:
清洗结果信息
Raises:
HTTPException: 当清洗过程出现错误时返回400错误
"""
try:
if device_ids == "all":
device_ids_list = []
else:
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
return await CompositeQueries.clean_scada_data(
timescale_conn, postgres_conn, device_ids_list, start_time, end_time
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/pipeline-health-prediction", summary="预测管道健康状况")
async def predict_pipeline_health(
query_time: datetime = Query(..., description="查询时间"),
network_name: str = Query(..., description="管网名称(或数据库名称)"),
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
预测管道健康状况
根据管网名称和当前时间,查询管道信息和实时数据,
使用随机生存森林模型预测管道的生存概率。
Args:
query_time: 查询时间
network_name: 管网名称(或数据库名称)
timescale_conn: TimescaleDB连接
Returns:
预测结果列表,每个元素包含 link_id 和对应的生存函数
Raises:
HTTPException: 当模型文件不存在返回404错误,其他错误返回400或500错误
"""
try:
return await CompositeQueries.predict_pipeline_health(
timescale_conn, network_name, query_time
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
@@ -0,0 +1,19 @@
from fastapi import Depends
from psycopg import AsyncConnection
from app.auth.project_dependencies import (
get_project_pg_connection,
get_project_timescale_connection,
)
async def get_timescale_connection(
conn: AsyncConnection = Depends(get_project_timescale_connection),
):
yield conn
async def get_postgres_connection(
conn: AsyncConnection = Depends(get_project_pg_connection),
):
yield conn
+289
View File
@@ -0,0 +1,289 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from app.infra.db.timescaledb.repositories.realtime import RealtimeRepository
from .dependencies import get_timescale_connection
router = APIRouter()
TIME_WITH_TZ_DESC = "ISO 8601 / RFC 3339 时间,必须显式带时区;可直接传 UTC+8,服务端会先转换为 UTC 再处理。"
TIME_RANGE_START_DESC = f"时间范围开始时间。{TIME_WITH_TZ_DESC}"
TIME_RANGE_END_DESC = f"时间范围结束时间。{TIME_WITH_TZ_DESC}"
@router.post("/realtime/links/batch", status_code=201, summary="批量插入实时管道数据")
async def insert_realtime_links(
data: List[dict] = Body(..., description="管道数据列表,每项包含管道ID、时间戳等信息"),
conn: AsyncConnection = Depends(get_timescale_connection)
):
"""
批量插入实时管道数据
将管道的实时监测数据批量插入时间序列数据库。
Args:
data: 管道数据列表
Returns:
插入成功的记录数
"""
await RealtimeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get(
"/realtime/links",
summary="查询实时管道数据",
description="按时间范围查询实时管道数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
)
async def get_realtime_links(
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
查询指定时间范围内的实时管道数据
根据时间范围查询所有实时管道的监测值。传入时间必须显式包含时区,
可以直接使用 UTC+8,服务端会先统一转换为 UTC 再参与数据库查询。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
Returns:
实时管道数据列表
"""
return await RealtimeRepository.get_links_by_time_range(conn, start_time, end_time)
@router.delete(
"/realtime/links",
summary="删除实时管道数据",
description="按时间范围删除实时管道数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端按请求中的绝对时间删除对应 UTC 数据。",
)
async def delete_realtime_links(
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
删除指定时间范围内的实时管道数据
删除在指定时间范围内的所有实时管道监测数据。
Args:
start_time: 删除开始时间
end_time: 删除结束时间
Returns:
删除结果信息
"""
await RealtimeRepository.delete_links_by_time_range(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.patch("/realtime/links/{link_id}/field", summary="更新实时管道字段")
async def update_realtime_link_field(
link_id: str = Path(..., description="管道ID"),
time: datetime = Query(..., description=f"要更新记录的时间戳。{TIME_WITH_TZ_DESC}"),
field: str = Query(..., description="要更新的字段名称"),
value: float = Query(..., description="更新的字段值"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
更新指定管道的字段值
更新实时管道在特定时间的某个字段数据。
Args:
link_id: 管道ID
time: 数据时间戳
field: 字段名称
value: 字段新值
Returns:
更新结果信息
Raises:
HTTPException: 当字段不存在或更新失败时返回400错误
"""
try:
await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/realtime/nodes/batch", status_code=201, summary="批量插入实时节点数据")
async def insert_realtime_nodes(
data: List[dict] = Body(..., description="节点数据列表,每项包含节点ID、时间戳等信息"),
conn: AsyncConnection = Depends(get_timescale_connection)
):
"""
批量插入实时节点数据
将节点的实时监测数据批量插入时间序列数据库。
Args:
data: 节点数据列表
Returns:
插入成功的记录数
"""
await RealtimeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get(
"/realtime/nodes",
summary="查询实时节点数据",
description="按时间范围查询实时节点数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
)
async def get_realtime_nodes(
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
查询指定时间范围内的实时节点数据
根据时间范围查询所有实时节点的监测值。传入时间必须显式包含时区,
可以直接使用 UTC+8,服务端会先统一转换为 UTC 再参与数据库查询。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
Returns:
实时节点数据列表
"""
return await RealtimeRepository.get_nodes_by_time_range(conn, start_time, end_time)
@router.delete(
"/realtime/nodes",
summary="删除实时节点数据",
description="按时间范围删除实时节点数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端按请求中的绝对时间删除对应 UTC 数据。",
)
async def delete_realtime_nodes(
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
删除指定时间范围内的实时节点数据
删除在指定时间范围内的所有实时节点监测数据。
Args:
start_time: 删除开始时间
end_time: 删除结束时间
Returns:
删除结果信息
"""
await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.post("/realtime/simulation/store", status_code=201, summary="存储实时模拟结果")
async def store_realtime_simulation_result(
node_result_list: List[dict] = Body(..., description="节点模拟结果列表"),
link_result_list: List[dict] = Body(..., description="管道模拟结果列表"),
result_start_time: str = Query(..., description=f"模拟结果开始时间。{TIME_WITH_TZ_DESC}"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
存储实时模拟结果到时间序列数据库
将节点和管道的实时模拟计算结果批量存储到TimescaleDB数据库。
Args:
node_result_list: 节点模拟结果列表
link_result_list: 管道模拟结果列表
result_start_time: 模拟结果对应的起始时间
Returns:
存储结果信息
"""
await RealtimeRepository.store_realtime_simulation_result(
conn, node_result_list, link_result_list, result_start_time
)
return {"message": "Simulation results stored successfully"}
@router.get(
"/realtime/query/by-time-property",
summary="按时间和属性查询实时数据",
description="查询指定时间点的实时属性值。query_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
)
async def query_realtime_records_by_time_property(
query_time: str = Query(..., description=f"查询时间。{TIME_WITH_TZ_DESC}"),
type: str = Query(..., description="数据类型,pipe(管道)或 junction(节点)"),
property: str = Query(..., description="要查询的属性名称"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按指定时间和属性查询所有实时监测数据
查询在特定时间点,所有指定类型元素的特定属性值。
Args:
query_time: 查询时间
type: 元素类型(pipe或junction
property: 属性名称
Returns:
查询结果列表
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
results = await RealtimeRepository.query_all_record_by_time_property(
conn, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get(
"/realtime/query/by-id-time",
summary="按ID和时间查询实时模拟数据",
description="查询指定元素在某一时间点的实时模拟结果。query_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
)
async def query_realtime_simulation_by_id_time(
id: str = Query(..., description="元素ID(管道ID或节点ID"),
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
query_time: str = Query(..., description=f"查询时间。{TIME_WITH_TZ_DESC}"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按指定ID和时间查询实时模拟结果
查询特定元素在某一时间点的实时模拟数据。
Args:
id: 元素ID
type: 元素类型(pipe或junction
query_time: 查询时间
Returns:
模拟结果数据
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
results = await RealtimeRepository.query_simulation_result_by_id_time(
conn, id, type, query_time
)
return {"result": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
+159
View File
@@ -0,0 +1,159 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from app.infra.db.timescaledb.repositories.scada import ScadaRepository
from .dependencies import get_timescale_connection
router = APIRouter()
@router.post("/scada/batch", status_code=201, summary="批量插入SCADA监测数据")
async def insert_scada_data(
data: List[dict] = Body(..., description="SCADA设备监测数据列表"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
批量插入SCADA监测数据
将多个设备的实时监测数据批量插入时间序列数据库。
Args:
data: SCADA设备监测数据列表,每项包含device_id、时间戳和监测值等信息
Returns:
插入成功的记录数
"""
await ScadaRepository.insert_scada_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scada/by-ids-time-range", summary="按设备ID和时间范围查询SCADA数据")
async def get_scada_by_ids_time_range(
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
device_ids: str = Query(
..., description="设备ID列表,逗号分隔,如 'device1,device2,device3'"
),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按设备ID和时间范围查询SCADA监测数据
查询多个设备在指定时间范围内的所有监测数据。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
device_ids: 设备ID列表,用逗号分隔
Returns:
SCADA监测数据列表
"""
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()] if device_ids else []
)
return await ScadaRepository.get_scada_by_ids_time_range(
conn, device_ids_list, start_time, end_time
)
@router.get(
"/scada/by-ids-field-time-range", summary="按设备ID、字段和时间范围查询SCADA数据"
)
async def get_scada_field_by_ids_time_range(
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
field: str = Query(..., description="要查询的字段名称"),
device_ids: str = Query(
..., description="设备ID列表,逗号分隔,如 'device1,device2,device3'"
),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按设备ID、字段和时间范围查询特定SCADA数据
查询多个设备在指定时间范围内的特定字段监测数据。
Args:
start_time: 查询开始时间
end_time: 查询结束时间
field: 字段名称
device_ids: 设备ID列表,用逗号分隔
Returns:
SCADA字段数据列表
Raises:
HTTPException: 当字段不存在或查询参数无效时返回400错误
"""
try:
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
return await ScadaRepository.get_scada_field_by_id_time_range(
conn, device_ids_list, start_time, end_time, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scada/{device_id}/field", summary="更新SCADA设备字段")
async def update_scada_field(
device_id: str = Path(..., description="设备ID"),
time: datetime = Query(..., description="更新数据的时间戳"),
field: str = Query(..., description="要更新的字段名称"),
value: float = Query(..., description="更新的字段值"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
更新指定设备的字段值
更新SCADA设备在特定时间的某个字段监测数据。
Args:
device_id: 设备ID
time: 数据时间戳
field: 字段名称
value: 字段新值
Returns:
更新结果信息
Raises:
HTTPException: 当字段不存在或更新失败时返回400错误
"""
try:
await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scada/by-id-time-range", summary="按设备ID和时间范围删除SCADA数据")
async def delete_scada_data(
device_id: str = Query(..., description="设备ID"),
start_time: datetime = Query(..., description="删除开始时间"),
end_time: datetime = Query(..., description="删除结束时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
删除指定设备和时间范围内的SCADA数据
删除在指定时间范围内的特定设备监测数据。
Args:
device_id: 设备ID
start_time: 删除开始时间
end_time: 删除结束时间
Returns:
删除结果信息
"""
await ScadaRepository.delete_scada_by_id_time_range(
conn, device_id, start_time, end_time
)
return {"message": "Deleted successfully"}
+391
View File
@@ -0,0 +1,391 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
from typing import List
from datetime import datetime
from psycopg import AsyncConnection
from app.infra.db.timescaledb.repositories.scheme import SchemeRepository
from .dependencies import get_timescale_connection
router = APIRouter()
@router.post("/scheme/links/batch", status_code=201, summary="批量插入方案管道数据")
async def insert_scheme_links(
data: List[dict] = Body(..., description="方案管道数据列表"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
批量插入方案管道数据
将特定方案的管道模拟数据批量插入时间序列数据库。
Args:
data: 方案管道数据列表
Returns:
插入成功的记录数
"""
await SchemeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/links", summary="查询方案管道数据")
async def get_scheme_links(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
查询指定方案和时间范围内的管道数据
根据方案和时间范围查询管道的模拟值。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
start_time: 查询开始时间
end_time: 查询结束时间
Returns:
方案管道数据列表
"""
return await SchemeRepository.get_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
@router.get("/scheme/links/{link_id}/field", summary="查询方案管道字段数据")
async def get_scheme_link_field(
link_id: str = Path(..., description="管道ID"),
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
field: str = Query(..., description="要查询的字段名称"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
查询指定方案管道的特定字段数据
查询特定方案中指定管道在时间范围内的特定字段值。
Args:
link_id: 管道ID
scheme_type: 方案类型
scheme_name: 方案名称
start_time: 查询开始时间
end_time: 查询结束时间
field: 字段名称
Returns:
字段数据列表
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
return await SchemeRepository.get_link_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, link_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/links/{link_id}/field", summary="更新方案管道字段")
async def update_scheme_link_field(
link_id: str = Path(..., description="管道ID"),
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
time: datetime = Query(..., description="更新数据的时间戳"),
field: str = Query(..., description="要更新的字段名称"),
value: float = Query(..., description="更新的字段值"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
更新指定方案管道的字段值
更新特定方案中指定管道在某个时间的字段数据。
Args:
link_id: 管道ID
scheme_type: 方案类型
scheme_name: 方案名称
time: 数据时间戳
field: 字段名称
value: 字段新值
Returns:
更新结果信息
Raises:
HTTPException: 当字段不存在或更新失败时返回400错误
"""
try:
await SchemeRepository.update_link_field(
conn, time, scheme_type, scheme_name, link_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/links", summary="删除方案管道数据")
async def delete_scheme_links(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
start_time: datetime = Query(..., description="删除开始时间"),
end_time: datetime = Query(..., description="删除结束时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
删除指定方案和时间范围内的管道数据
删除在指定方案和时间范围内的所有管道模拟数据。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
start_time: 删除开始时间
end_time: 删除结束时间
Returns:
删除结果信息
"""
await SchemeRepository.delete_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/nodes/batch", status_code=201, summary="批量插入方案节点数据")
async def insert_scheme_nodes(
data: List[dict] = Body(..., description="方案节点数据列表"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
批量插入方案节点数据
将特定方案的节点模拟数据批量插入时间序列数据库。
Args:
data: 方案节点数据列表
Returns:
插入成功的记录数
"""
await SchemeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/nodes/{node_id}/field", summary="查询方案节点字段数据")
async def get_scheme_node_field(
node_id: str = Path(..., description="节点ID"),
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
start_time: datetime = Query(..., description="查询开始时间"),
end_time: datetime = Query(..., description="查询结束时间"),
field: str = Query(..., description="要查询的字段名称"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
查询指定方案节点的特定字段数据
查询特定方案中指定节点在时间范围内的特定字段值。
Args:
node_id: 节点ID
scheme_type: 方案类型
scheme_name: 方案名称
start_time: 查询开始时间
end_time: 查询结束时间
field: 字段名称
Returns:
字段数据列表
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
return await SchemeRepository.get_node_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, node_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/nodes/{node_id}/field", summary="更新方案节点字段")
async def update_scheme_node_field(
node_id: str = Path(..., description="节点ID"),
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
time: datetime = Query(..., description="更新数据的时间戳"),
field: str = Query(..., description="要更新的字段名称"),
value: float = Query(..., description="更新的字段值"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
更新指定方案节点的字段值
更新特定方案中指定节点在某个时间的字段数据。
Args:
node_id: 节点ID
scheme_type: 方案类型
scheme_name: 方案名称
time: 数据时间戳
field: 字段名称
value: 字段新值
Returns:
更新结果信息
Raises:
HTTPException: 当字段不存在或更新失败时返回400错误
"""
try:
await SchemeRepository.update_node_field(
conn, time, scheme_type, scheme_name, node_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/nodes", summary="删除方案节点数据")
async def delete_scheme_nodes(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
start_time: datetime = Query(..., description="删除开始时间"),
end_time: datetime = Query(..., description="删除结束时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
删除指定方案和时间范围内的节点数据
删除在指定方案和时间范围内的所有节点模拟数据。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
start_time: 删除开始时间
end_time: 删除结束时间
Returns:
删除结果信息
"""
await SchemeRepository.delete_nodes_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/simulation/store", status_code=201, summary="存储方案模拟结果")
async def store_scheme_simulation_result(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
node_result_list: List[dict] = Body(..., description="节点模拟结果列表"),
link_result_list: List[dict] = Body(..., description="管道模拟结果列表"),
result_start_time: str = Query(..., description="模拟结果开始时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
存储方案模拟结果到时间序列数据库
将特定方案的节点和管道模拟计算结果批量存储到TimescaleDB数据库。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
node_result_list: 节点模拟结果列表
link_result_list: 管道模拟结果列表
result_start_time: 模拟结果对应的起始时间
Returns:
存储结果信息
"""
await SchemeRepository.store_scheme_simulation_result(
conn,
scheme_type,
scheme_name,
node_result_list,
link_result_list,
result_start_time,
)
return {"message": "Scheme simulation results stored successfully"}
@router.get(
"/scheme/query/by-scheme-time-property", summary="按方案、时间和属性查询数据"
)
async def query_scheme_records_by_scheme_time_property(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
query_time: str = Query(..., description="查询时间"),
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
property: str = Query(..., description="要查询的属性名称"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按指定方案、时间和属性查询所有方案数据
查询在特定方案和时间点,所有指定类型元素的特定属性值。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
query_time: 查询时间
type: 元素类型(pipe或junction
property: 属性名称
Returns:
查询结果列表
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
results = await SchemeRepository.query_all_record_by_scheme_time_property(
conn, scheme_type, scheme_name, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/scheme/query/by-id-time", summary="按ID和时间查询方案模拟数据")
async def query_scheme_simulation_by_id_time(
scheme_type: str = Query(..., description="方案类型"),
scheme_name: str = Query(..., description="方案名称"),
id: str = Query(..., description="元素ID(管道ID或节点ID"),
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
query_time: str = Query(..., description="查询时间"),
conn: AsyncConnection = Depends(get_timescale_connection),
):
"""
按指定ID和时间查询方案模拟结果
查询特定方案中的元素在某一时间点的模拟数据。
Args:
scheme_type: 方案类型
scheme_name: 方案名称
id: 元素ID
type: 元素类型(pipe或junction
query_time: 查询时间
Returns:
模拟结果数据
Raises:
HTTPException: 当查询参数无效时返回400错误
"""
try:
result = await SchemeRepository.query_scheme_simulation_result_by_id_time(
conn, scheme_type, scheme_name, id, type, query_time
)
return {"result": result}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
+94 -59
View File
@@ -3,178 +3,213 @@
演示权限控制的使用
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from app.domain.schemas.user import UserResponse, UserUpdate, UserCreate
from app.domain.models.role import UserRole
from app.domain.schemas.user import UserInDB
from app.infra.repositories.user_repository import UserRepository
from app.infra.db.metadb.repositories.user_repository import UserRepository
from app.auth.dependencies import get_user_repository, get_current_active_user
from app.auth.permissions import get_current_admin, require_role, check_resource_owner
router = APIRouter()
@router.get("/", response_model=List[UserResponse])
@router.get(
"/",
summary="列出所有用户",
description="获取用户列表(仅管理员)",
response_model=List[UserResponse],
)
async def list_users(
skip: int = 0,
limit: int = 100,
skip: int = Query(0, ge=0, description="跳过的用户数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大用户数"),
current_user: UserInDB = Depends(require_role(UserRole.ADMIN)),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> List[UserResponse]:
"""
获取用户列表(仅管理员)
获取用户列表
获取系统中所有的用户信息(需要管理员权限)
"""
users = await user_repo.get_all_users(skip=skip, limit=limit)
return [UserResponse.model_validate(user) for user in users]
@router.get("/{user_id}", response_model=UserResponse)
@router.get(
"/{user_id}",
summary="获取用户详情",
description="获取指定用户的详细信息",
response_model=UserResponse,
)
async def get_user(
user_id: int,
user_id: int = Path(..., gt=0, description="用户ID"),
current_user: UserInDB = Depends(get_current_active_user),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> UserResponse:
"""
获取用户详情
管理员可查看所有用户,普通用户只能查看自己
"""
# 检查权限
if not check_resource_owner(user_id, current_user):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to view this user"
detail="You don't have permission to view this user",
)
user = await user_repo.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return UserResponse.model_validate(user)
@router.put("/{user_id}", response_model=UserResponse)
@router.put(
"/{user_id}",
summary="更新用户信息",
description="更新指定用户的信息",
response_model=UserResponse,
)
async def update_user(
user_id: int,
user_update: UserUpdate,
user_id: int = Path(..., gt=0, description="用户ID"),
user_update: UserUpdate = None,
current_user: UserInDB = Depends(get_current_active_user),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> UserResponse:
"""
更新用户信息
管理员可更新所有用户,普通用户只能更新自己(且不能修改角色)
"""
# 检查用户是否存在
target_user = await user_repo.get_user_by_id(user_id)
if not target_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# 权限检查
is_owner = current_user.id == user_id
is_admin = UserRole(current_user.role).has_permission(UserRole.ADMIN)
if not is_owner and not is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to update this user"
detail="You don't have permission to update this user",
)
# 非管理员不能修改角色和激活状态
if not is_admin:
if user_update.role is not None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admins can change user roles"
detail="Only admins can change user roles",
)
if user_update.is_active is not None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only admins can change user active status"
detail="Only admins can change user active status",
)
# 更新用户
updated_user = await user_repo.update_user(user_id, user_update)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user"
detail="Failed to update user",
)
return UserResponse.model_validate(updated_user)
@router.delete("/{user_id}")
@router.delete("/{user_id}", summary="删除用户", description="删除指定用户(仅管理员)")
async def delete_user(
user_id: int,
user_id: int = Path(..., gt=0, description="用户ID"),
current_user: UserInDB = Depends(get_current_admin),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> dict:
"""
删除用户(仅管理员)
删除用户
删除指定用户(需要管理员权限,不能删除自己)
"""
# 不能删除自己
if current_user.id == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="You cannot delete your own account"
detail="You cannot delete your own account",
)
success = await user_repo.delete_user(user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return {"message": "User deleted successfully"}
@router.post("/{user_id}/activate")
@router.post(
"/{user_id}/activate",
summary="激活用户",
description="激活指定用户账户(仅管理员)",
response_model=UserResponse,
)
async def activate_user(
user_id: int,
user_id: int = Path(..., gt=0, description="用户ID"),
current_user: UserInDB = Depends(get_current_admin),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> UserResponse:
"""
激活用户(仅管理员)
激活用户
激活指定用户的账户(需要管理员权限)
"""
user_update = UserUpdate(is_active=True)
updated_user = await user_repo.update_user(user_id, user_update)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return UserResponse.model_validate(updated_user)
@router.post("/{user_id}/deactivate")
@router.post(
"/{user_id}/deactivate",
summary="停用用户",
description="停用指定用户账户(仅管理员)",
response_model=UserResponse,
)
async def deactivate_user(
user_id: int,
user_id: int = Path(..., gt=0, description="用户ID"),
current_user: UserInDB = Depends(get_current_admin),
user_repo: UserRepository = Depends(get_user_repository)
user_repo: UserRepository = Depends(get_user_repository),
) -> UserResponse:
"""
停用用户(仅管理员)
停用用户
停用指定用户的账户(需要管理员权限,不能停用自己)
"""
# 不能停用自己
if current_user.id == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="You cannot deactivate your own account"
detail="You cannot deactivate your own account",
)
user_update = UserUpdate(is_active=False)
updated_user = await user_repo.update_user(user_id, user_update)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return UserResponse.model_validate(updated_user)
+23 -8
View File
@@ -1,6 +1,6 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Query
from typing import Any, List, Dict, Union
from app.services.tjnetwork import *
from app.services.tjnetwork import Any, get_all_users, get_user, get_user_schema
router = APIRouter()
@@ -8,14 +8,29 @@ router = APIRouter()
# user 39
###########################################################
@router.get("/getuserschema/")
async def fastapi_get_user_schema(network: str) -> dict[str, dict[Any, Any]]:
@router.get("/getuserschema/", summary="获取用户模式", description="获取指定网络的用户模式定义")
async def fastapi_get_user_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[Any, Any]]:
"""
获取用户模式定义
返回指定网络的用户模式结构定义
"""
return get_user_schema(network)
@router.get("/getuser/")
async def fastapi_get_user(network: str, user_name: str) -> dict[Any, Any]:
@router.get("/getuser/", summary="获取单个用户", description="获取指定网络中的单个用户信息")
async def fastapi_get_user(network: str = Query(..., description="管网名称(或数据库名称)"), user_name: str = Query(..., description="用户名")) -> dict[Any, Any]:
"""
获取用户信息
返回指定网络中指定用户名的详细信息
"""
return get_user(network, user_name)
@router.get("/getallusers/")
async def fastapi_get_all_users(network: str) -> list[dict[Any, Any]]:
@router.get("/getallusers/", summary="获取所有用户", description="获取指定网络的所有用户列表")
async def fastapi_get_all_users(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
"""
获取所有用户列表
返回指定网络中所有用户的信息
"""
return get_all_users(network)
+31 -9
View File
@@ -6,12 +6,15 @@ from app.api.v1.endpoints import (
scada,
extension,
snapshots,
data_query,
# data_query,
users,
schemes,
misc,
risk,
cache,
leakage,
burst_detection,
burst_location,
user_management, # 新增:用户管理
audit, # 新增:审计日志
meta,
@@ -38,14 +41,21 @@ from app.api.v1.endpoints.components import (
visuals,
)
from app.infra.db.postgresql import router as postgresql_router
from app.infra.db.timescaledb import router as timescaledb_router
from app.api.v1.endpoints import project_data
from app.api.v1.endpoints.timeseries import (
realtime as ts_realtime,
scheme as ts_scheme,
scada as ts_scada,
composite as ts_composite,
)
api_router = APIRouter()
# Core Services
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
api_router.include_router(
user_management.router, prefix="/users", tags=["User Management"]
) # 新增
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
api_router.include_router(meta.router, tags=["Metadata"])
api_router.include_router(project.router, tags=["Project"])
@@ -75,18 +85,30 @@ api_router.include_router(visuals.router, tags=["Visuals"])
# Simulation & Data
api_router.include_router(simulation.router, tags=["Simulation Control"])
api_router.include_router(data_query.router, tags=["Data Query & InfluxDB"])
api_router.include_router(scada.router, tags=["SCADA"])
# api_router.include_router(data_query.router, tags=["Data Query & InfluxDB"])
api_router.include_router(scada.router)
api_router.include_router(snapshots.router, tags=["Snapshots"])
api_router.include_router(users.router, tags=["Users"])
api_router.include_router(schemes.router, tags=["Schemes"])
api_router.include_router(misc.router, tags=["Misc"])
api_router.include_router(risk.router, tags=["Risk"])
api_router.include_router(cache.router, tags=["Cache"])
api_router.include_router(leakage.router, prefix="/leakage", tags=["Leakage"])
api_router.include_router(
burst_detection.router, prefix="/burst-detection", tags=["Burst Detection"]
)
api_router.include_router(
burst_location.router, prefix="/burst-location", tags=["Burst Location"]
)
# Database Routers
api_router.include_router(timescaledb_router, tags=["TimescaleDB"])
api_router.include_router(postgresql_router, tags=["PostgreSQL"])
# TimescaleDB Data Access
api_router.include_router(ts_realtime.router, tags=["TimescaleDB - Realtime"])
api_router.include_router(ts_scheme.router, tags=["TimescaleDB - Scheme"])
api_router.include_router(ts_scada.router, tags=["TimescaleDB - SCADA"])
api_router.include_router(ts_composite.router, tags=["TimescaleDB - Composite"])
# Project Data (PostgreSQL)
api_router.include_router(project_data.router, tags=["Project Data"])
# Extension
api_router.include_router(extension.router, tags=["Extension"])
+9
View File
@@ -0,0 +1,9 @@
"""
This module is reserved for future implementation of audit logging and compliance features.
Current implementation of audit logging can be found in `app.core.audit`.
Future expansion may include:
- Comprehensive audit trails for all system actions
- Compliance reporting (e.g., for industrial control systems)
- Anomaly detection in user behavior
"""
+1 -1
View File
@@ -4,7 +4,7 @@ from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from app.core.config import settings
from app.domain.schemas.user import UserInDB, TokenPayload
from app.infra.repositories.user_repository import UserRepository
from app.infra.db.metadb.repositories.user_repository import UserRepository
from app.infra.db.postgresql.database import Database
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
+40
View File
@@ -61,3 +61,43 @@ async def get_current_keycloak_sub(
detail="Invalid subject claim",
headers={"WWW-Authenticate": "Bearer"},
) from exc
async def get_current_keycloak_username(
token: str | None = Depends(oauth2_optional),
) -> str:
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
if settings.KEYCLOAK_PUBLIC_KEY:
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
algorithms = [settings.KEYCLOAK_ALGORITHM]
else:
key = settings.SECRET_KEY
algorithms = [settings.ALGORITHM]
try:
payload = jwt.decode(
token,
key,
algorithms=algorithms,
audience=settings.KEYCLOAK_AUDIENCE or None,
)
except JWTError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
) from exc
username = payload.get("preferred_username") or payload.get("username")
if not username:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing username claim",
headers={"WWW-Authenticate": "Bearer"},
)
return str(username)
+2 -2
View File
@@ -8,8 +8,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
from app.infra.db.metadb.database import get_metadata_session
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
+2 -2
View File
@@ -11,8 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.keycloak_dependencies import get_current_keycloak_sub
from app.core.config import settings
from app.infra.db.dynamic_manager import project_connection_manager
from app.infra.db.metadata.database import get_metadata_session
from app.infra.repositories.metadata_repository import MetadataRepository
from app.infra.db.metadb.database import get_metadata_session
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
DB_ROLE_BIZ_DATA = "biz_data"
DB_ROLE_IOT_DATA = "iot_data"
+2 -2
View File
@@ -66,8 +66,8 @@ async def log_audit_event(
response_status: 响应状态码
session: 元数据库会话(可选)
"""
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.audit_repository import AuditRepository
from app.infra.db.metadb.database import SessionLocal
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
if request_data:
request_data = sanitize_sensitive_data(request_data)
+72
View File
@@ -1,12 +1,16 @@
from pathlib import Path
from typing import Optional
from urllib.parse import quote_plus
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
PROJECT_NAME: str = "TJWater Server"
ENVIRONMENT: str = "production"
API_V1_STR: str = "/api/v1"
NETWORK_NAME: str = "default_network"
# JWT 配置
SECRET_KEY: str = (
"your-secret-key-here-change-in-production-use-openssl-rand-hex-32"
@@ -80,3 +84,71 @@ class Settings(BaseSettings):
settings = Settings()
def get_pgconn_string(
db_name: Optional[str] = None,
db_host: Optional[str] = None,
db_port: Optional[str] = None,
db_user: Optional[str] = None,
db_password: Optional[str] = None,
) -> str:
"""Return PostgreSQL connection string in psycopg conninfo format."""
resolved_db_name = db_name or settings.DB_NAME
resolved_db_host = db_host or settings.DB_HOST
resolved_db_port = db_port or settings.DB_PORT
resolved_db_user = db_user or settings.DB_USER
resolved_db_password = db_password or settings.DB_PASSWORD
return (
f"dbname={resolved_db_name} host={resolved_db_host} port={resolved_db_port} "
f"user={resolved_db_user} password={resolved_db_password}"
)
def get_pg_config() -> dict:
"""Return PostgreSQL configuration except password."""
return {
"name": settings.DB_NAME,
"host": settings.DB_HOST,
"port": settings.DB_PORT,
"user": settings.DB_USER,
}
def get_pg_password() -> str:
"""Return PostgreSQL password (use with care)."""
return settings.DB_PASSWORD
def get_timescaledb_pgconn_string(
db_name: Optional[str] = None,
db_host: Optional[str] = None,
db_port: Optional[str] = None,
db_user: Optional[str] = None,
db_password: Optional[str] = None,
) -> str:
"""Return TimescaleDB connection string in psycopg conninfo format."""
resolved_db_name = db_name or settings.TIMESCALEDB_DB_NAME
resolved_db_host = db_host or settings.TIMESCALEDB_DB_HOST
resolved_db_port = db_port or settings.TIMESCALEDB_DB_PORT
resolved_db_user = db_user or settings.TIMESCALEDB_DB_USER
resolved_db_password = db_password or settings.TIMESCALEDB_DB_PASSWORD
return (
f"dbname={resolved_db_name} host={resolved_db_host} port={resolved_db_port} "
f"user={resolved_db_user} password={resolved_db_password}"
)
def get_timescaledb_pg_config() -> dict:
"""Return TimescaleDB configuration except password."""
return {
"name": settings.TIMESCALEDB_DB_NAME,
"host": settings.TIMESCALEDB_DB_HOST,
"port": settings.TIMESCALEDB_DB_PORT,
"user": settings.TIMESCALEDB_DB_USER,
}
def get_timescaledb_pg_password() -> str:
"""Return TimescaleDB password (use with care)."""
return settings.TIMESCALEDB_DB_PASSWORD
+10 -6
View File
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Union, Any
from jose import jwt
@@ -8,6 +8,10 @@ from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def create_access_token(
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
@@ -22,9 +26,9 @@ def create_access_token(
JWT token 字符串
"""
if expires_delta:
expire = datetime.now() + expires_delta
expire = _utc_now() + expires_delta
else:
expire = datetime.now() + timedelta(
expire = _utc_now() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
@@ -32,7 +36,7 @@ def create_access_token(
"exp": expire,
"sub": str(subject),
"type": "access",
"iat": datetime.now(),
"iat": _utc_now(),
}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
@@ -50,13 +54,13 @@ def create_refresh_token(subject: Union[str, Any]) -> str:
Returns:
JWT refresh token 字符串
"""
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = _utc_now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {
"exp": expire,
"sub": str(subject),
"type": "refresh",
"iat": datetime.now(),
"iat": _utc_now(),
}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
+10
View File
@@ -0,0 +1,10 @@
"""
This module is reserved for future implementation of advanced cryptographic operations.
Current basic encryption (Fernet) and password hashing are implemented in `app.core.encryption` and `app.core.security`.
Future expansion may include:
- Asymmetric encryption (RSA/ECC) for secure communication
- Key management and rotation services
- Integration with Hardware Security Modules (HSM)
- Digital signatures for data integrity verification
"""
+7 -6
View File
@@ -5,10 +5,10 @@ from pydantic import BaseModel
class GeoServerConfigResponse(BaseModel):
gs_base_url: Optional[str]
gs_admin_user: Optional[str]
gs_base_url: Optional[str] = None
gs_admin_user: Optional[str] = None
gs_datastore_name: str
default_extent: Optional[dict]
default_extent: Optional[dict] = None
srid: int
@@ -16,18 +16,19 @@ class ProjectMetaResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
description: Optional[str] = None
gs_workspace: str
map_extent: Optional[dict] = None
status: str
project_role: str
geoserver: Optional[GeoServerConfigResponse]
geoserver: Optional[GeoServerConfigResponse] = None
class ProjectSummaryResponse(BaseModel):
project_id: UUID
name: str
code: str
description: Optional[str]
description: Optional[str] = None
gs_workspace: str
status: str
project_role: str
+72 -9
View File
@@ -15,8 +15,8 @@ import logging
from jose import JWTError, jwt
from app.core.config import settings
from app.infra.db.metadata.database import SessionLocal
from app.infra.repositories.metadata_repository import MetadataRepository
from app.infra.db.metadb.database import SessionLocal
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
logger = logging.getLogger(__name__)
@@ -55,11 +55,26 @@ class AuditMiddleware(BaseHTTPMiddleware):
# 需要审计的HTTP方法
AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]
EXCLUDED_PATHS = {
"/api/v1/meta/projects",
"/meta/projects",
"/api/v1/openproject/",
"/openproject/",
}
EXCLUDED_PATH_PREFIXES = (
)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 提取开始时间
start_time = time.time()
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
if self._is_excluded_path(request.url.path):
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 1. 预判是否需要读取Body (针对写操作)
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
@@ -68,13 +83,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
if should_capture_body:
try:
# 注意:读取 body 后需要重新设置,避免影响后续处理
original_receive = request._receive
body = await request.body()
if body:
request_data = json.loads(body.decode())
# 重新构造请求以供后续使用
# 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
body_sent = False
async def receive():
return {"type": "http.request", "body": body}
nonlocal body_sent
if not body_sent:
body_sent = True
return {
"type": "http.request",
"body": body,
"more_body": False,
}
return await original_receive()
request._receive = receive
except Exception as e:
@@ -84,6 +110,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
# 3. 决定是否审计
if self._is_excluded_path(request.url.path):
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# 检查方法
is_audit_method = request.method in self.AUDIT_METHODS
# 检查路径
@@ -139,6 +170,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
return response
def _is_excluded_path(self, path: str) -> bool:
if path in self.EXCLUDED_PATHS:
return True
return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)
def _resolve_project_id(self, request: Request) -> UUID | None:
project_header = request.headers.get("X-Project-Id")
if not project_header:
@@ -155,6 +191,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
token = auth_header.split(" ", 1)[1].strip()
if not token:
return None
sub = None
try:
key = (
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
@@ -166,17 +203,25 @@ class AuditMiddleware(BaseHTTPMiddleware):
if settings.KEYCLOAK_PUBLIC_KEY
else [settings.ALGORITHM]
)
payload = jwt.decode(token, key, algorithms=algorithms)
payload = jwt.decode(
token,
key,
algorithms=algorithms,
audience=settings.KEYCLOAK_AUDIENCE or None,
)
sub = payload.get("sub")
if not sub:
return None
keycloak_id = UUID(sub)
except (JWTError, ValueError):
except JWTError:
return None
async with SessionLocal() as session:
repo = MetadataRepository(session)
user = await repo.get_user_by_keycloak_id(keycloak_id)
try:
keycloak_id = UUID(sub)
user = await repo.get_user_by_keycloak_id(keycloak_id)
except ValueError:
user = await repo.get_user_by_username(sub)
if user and user.is_active:
return user.id
return None
@@ -218,7 +263,25 @@ class AuditMiddleware(BaseHTTPMiddleware):
if len(path_parts) >= 4:
resource_type = path_parts[3].rstrip("s") # 移除复数s
if len(path_parts) >= 5 and path_parts[4].isdigit():
if len(path_parts) >= 5 and path_parts[4]:
resource_id = path_parts[4]
# 无路径ID时,尝试从查询参数提取业务ID
if not resource_id:
for key in (
"id",
"resource_id",
"device_id",
"device_ids",
"element_id",
"user_id",
"project_id",
"network",
"name",
):
value = request.query_params.get(key)
if value:
resource_id = value
break
return resource_type, resource_id
+96 -30
View File
@@ -24,6 +24,17 @@ logger = logging.getLogger(__name__)
class PgEngineEntry:
engine: AsyncEngine
sessionmaker: async_sessionmaker[AsyncSession]
connection_url: str
pool_min_size: int
pool_max_size: int
@dataclass(frozen=True)
class PoolEntry:
pool: AsyncConnectionPool
connection_url: str
pool_min_size: int
pool_max_size: int
@dataclass(frozen=True)
@@ -35,8 +46,8 @@ class CacheKey:
class ProjectConnectionManager:
def __init__(self) -> None:
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
self._ts_cache: Dict[CacheKey, PoolEntry] = OrderedDict()
self._pg_raw_cache: Dict[CacheKey, PoolEntry] = OrderedDict()
self._pg_lock = asyncio.Lock()
self._ts_lock = asyncio.Lock()
self._pg_raw_lock = asyncio.Lock()
@@ -56,15 +67,29 @@ class ProjectConnectionManager:
pool_max_size: int,
) -> async_sessionmaker[AsyncSession]:
async with self._pg_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._pg_cache.get(key)
if entry:
self._pg_cache.move_to_end(key)
return entry.sessionmaker
normalized_url = self._normalize_pg_url(connection_url)
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._pg_cache.get(key)
if entry:
if (
entry.connection_url == normalized_url
and entry.pool_min_size == pool_min_size
and entry.pool_max_size == pool_max_size
):
self._pg_cache.move_to_end(key)
return entry.sessionmaker
await entry.engine.dispose()
logger.info(
"Rebuilding PostgreSQL engine for project %s (%s) due to config change",
project_id,
db_role,
)
self._pg_cache.pop(key, None)
engine = create_async_engine(
normalized_url,
pool_size=pool_min_size,
@@ -75,6 +100,9 @@ class ProjectConnectionManager:
self._pg_cache[key] = PgEngineEntry(
engine=engine,
sessionmaker=sessionmaker,
connection_url=normalized_url,
pool_min_size=pool_min_size,
pool_max_size=pool_max_size,
)
await self._evict_pg_if_needed()
logger.info(
@@ -91,14 +119,28 @@ class ProjectConnectionManager:
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._ts_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._ts_cache.get(key)
if pool:
self._ts_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._ts_cache.get(key)
if entry:
if (
entry.connection_url == connection_url
and entry.pool_min_size == pool_min_size
and entry.pool_max_size == pool_max_size
):
self._ts_cache.move_to_end(key)
return entry.pool
await entry.pool.close()
logger.info(
"Rebuilding TimescaleDB pool for project %s (%s) due to config change",
project_id,
db_role,
)
self._ts_cache.pop(key, None)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
@@ -107,7 +149,12 @@ class ProjectConnectionManager:
kwargs={"row_factory": dict_row},
)
await pool.open()
self._ts_cache[key] = pool
self._ts_cache[key] = PoolEntry(
pool=pool,
connection_url=connection_url,
pool_min_size=pool_min_size,
pool_max_size=pool_max_size,
)
await self._evict_ts_if_needed()
logger.info(
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
@@ -123,14 +170,28 @@ class ProjectConnectionManager:
pool_max_size: int,
) -> AsyncConnectionPool:
async with self._pg_raw_lock:
key = CacheKey(project_id=project_id, db_role=db_role)
pool = self._pg_raw_cache.get(key)
if pool:
self._pg_raw_cache.move_to_end(key)
return pool
pool_min_size = max(1, pool_min_size)
pool_max_size = max(pool_min_size, pool_max_size)
key = CacheKey(project_id=project_id, db_role=db_role)
entry = self._pg_raw_cache.get(key)
if entry:
if (
entry.connection_url == connection_url
and entry.pool_min_size == pool_min_size
and entry.pool_max_size == pool_max_size
):
self._pg_raw_cache.move_to_end(key)
return entry.pool
await entry.pool.close()
logger.info(
"Rebuilding PostgreSQL pool for project %s (%s) due to config change",
project_id,
db_role,
)
self._pg_raw_cache.pop(key, None)
pool = AsyncConnectionPool(
conninfo=connection_url,
min_size=pool_min_size,
@@ -139,7 +200,12 @@ class ProjectConnectionManager:
kwargs={"row_factory": dict_row},
)
await pool.open()
self._pg_raw_cache[key] = pool
self._pg_raw_cache[key] = PoolEntry(
pool=pool,
connection_url=connection_url,
pool_min_size=pool_min_size,
pool_max_size=pool_max_size,
)
await self._evict_pg_raw_if_needed()
logger.info(
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
@@ -158,8 +224,8 @@ class ProjectConnectionManager:
async def _evict_ts_if_needed(self) -> None:
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
key, pool = self._ts_cache.popitem(last=False)
await pool.close()
key, entry = self._ts_cache.popitem(last=False)
await entry.pool.close()
logger.info(
"Evicted TimescaleDB pool for project %s (%s)",
key.project_id,
@@ -168,8 +234,8 @@ class ProjectConnectionManager:
async def _evict_pg_raw_if_needed(self) -> None:
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
key, pool = self._pg_raw_cache.popitem(last=False)
await pool.close()
key, entry = self._pg_raw_cache.popitem(last=False)
await entry.pool.close()
logger.info(
"Evicted PostgreSQL pool for project %s (%s)",
key.project_id,
@@ -188,8 +254,8 @@ class ProjectConnectionManager:
self._pg_cache.clear()
async with self._ts_lock:
for key, pool in list(self._ts_cache.items()):
await pool.close()
for key, entry in list(self._ts_cache.items()):
await entry.pool.close()
logger.info(
"Closed TimescaleDB pool for project %s (%s)",
key.project_id,
@@ -198,8 +264,8 @@ class ProjectConnectionManager:
self._ts_cache.clear()
async with self._pg_raw_lock:
for key, pool in list(self._pg_raw_cache.items()):
await pool.close()
for key, entry in list(self._pg_raw_cache.items()):
await entry.pool.close()
logger.info(
"Closed PostgreSQL pool for project %s (%s)",
key.project_id,
+2 -2
View File
@@ -18,7 +18,7 @@ from dateutil import parser
import psycopg
import time
import app.services.simulation as simulation
from app.services.tjnetwork import *
from app.services.tjnetwork import close_project, get_time, is_project_open, open_project
import schedule
import threading
import app.services.globals as globals
@@ -29,7 +29,7 @@ import pytz
import app.infra.db.influxdb.info as influxdb_info
import app.services.project_info as project_info
import app.services.time_api as time_api
from app.native.api.postgresql_info import get_pgconn_string
from app.core.config import get_pgconn_string
# influxdb数据库连接信息
url = influxdb_info.url
@@ -1,5 +1,5 @@
from datetime import datetime
from uuid import UUID
from uuid import UUID, uuid4
from sqlalchemy import Boolean, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
@@ -42,6 +42,7 @@ class Project(Base):
code: Mapped[str] = mapped_column(String(50), unique=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
map_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
status: Mapped[str] = mapped_column(String(20), default="active")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=datetime.utcnow
@@ -95,7 +96,9 @@ class UserProjectMembership(Base):
class AuditLog(Base):
__tablename__ = "audit_logs"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
user_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), nullable=True, index=True
)

Some files were not shown because too many files have changed in this diff Show More