Compare commits
109 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e336ffcd46 | |||
| 52b8f07abd | |||
| 7efaeb41e8 | |||
| 9a7aad2d36 | |||
| b7872f29a9 | |||
| 233960d8db | |||
| b9410b0ff3 | |||
| 4982efba5e | |||
| f87dd91b2b | |||
| c16e6e3d0c | |||
| 40e699e173 | |||
| 9b8a517092 | |||
| f274cf5122 | |||
| 60db2a7193 | |||
| b72e42521c | |||
| c2ccb7bc4e | |||
| 88be97ddeb | |||
| 2317f4d527 | |||
| 751950e5b5 | |||
| a1dcbd4230 | |||
| 3b712ea467 | |||
| bf2aaa5ff7 | |||
| 51b481d174 | |||
| 644babf77e | |||
| 6b09c6b20d | |||
| 93cbd7e7b3 | |||
| 0196206ed3 | |||
| 88eec2787b | |||
| 621cd9d2f9 | |||
| 600ddd329c | |||
| c184610035 | |||
| 21dd393aee | |||
| b0acfb21ec | |||
| 20ec7d9c8d | |||
| 7c44654195 | |||
| c5d3075ae2 | |||
| 2ea5ce14ba | |||
| adb5dc01fb | |||
| fb9f3217e2 | |||
| 5e8600a0a7 | |||
| 1dcaf5ae9f | |||
| a792838e80 | |||
| 3cd76b9b52 | |||
| e6d00e9bc6 | |||
| 68c12cc4eb | |||
| e0c247f3b2 | |||
| c3bf48499b | |||
| 102cfffefe | |||
| 1a76c89054 | |||
| 1673396e1a | |||
| c137adedad | |||
| 5041922c84 | |||
| cfe69e581b | |||
| b513d05611 | |||
| 9a8d851275 | |||
| 50a1e78073 | |||
| 83a6143146 | |||
| 9aa0646bc6 | |||
| d34c61a051 | |||
| baf899eaeb | |||
| 72d642fcf6 | |||
| 4ea0b8f05b | |||
| aa68bc73ca | |||
| bef1c74782 | |||
| 90216a762a | |||
| 559d5bb8e3 | |||
| 7345210bdd | |||
| 0d8a7f5cb7 | |||
| efeca41cbd | |||
| 8c7d77e6ee | |||
| c946e1b58b | |||
| 0b72ac959a | |||
| 48f836d667 | |||
| 6eec6c04de | |||
| 61d540356d | |||
| eb1d9cce56 | |||
| 78978c6931 | |||
| 747b4cd229 | |||
| ed1eb74cfb | |||
| 20ab08e206 | |||
| 6b85cfc666 | |||
| a56e041cfc | |||
| f9111ab9c1 | |||
| d55e23bc44 | |||
| b3d58379ef | |||
| 9a4a91c328 | |||
| a7e3b6aff9 | |||
| 05ca940c9f | |||
| 0f8d33291d | |||
| 143b918b86 | |||
| 7ff28893a1 | |||
| b9d9cef5ef | |||
| 0c6c27a0c1 | |||
| f5a7e5b3c9 | |||
| 78a57f5c56 | |||
| 7f481ca261 | |||
| bc74e94fbb | |||
| b83b895e2b | |||
| 63d3458fb4 | |||
| b8aee14c00 | |||
| 340808e85e | |||
| 2464c7f612 | |||
| 61f6975296 | |||
| d0abad3c65 | |||
| e7a3aec02f | |||
| 1d662f973a | |||
| 5566172e26 | |||
| df76e40b0a | |||
| 2e479868f8 |
+8
-37
@@ -1,6 +1,7 @@
|
||||
# TJWater Server 环境变量配置模板
|
||||
# 复制此文件为 .env 并填写实际值
|
||||
NETWORK_NAME="szh"
|
||||
ENVIRONMENT="production"
|
||||
NETWORK_NAME="tjwater"
|
||||
# ============================================
|
||||
# 安全配置 (必填)
|
||||
# ============================================
|
||||
@@ -20,17 +21,17 @@ DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="localhost"
|
||||
DB_PORT="5432"
|
||||
DB_USER="postgres"
|
||||
DB_USER="tjwater"
|
||||
DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (TimescaleDB)
|
||||
# ============================================
|
||||
TIMESCALEDB_DB_NAME="szh"
|
||||
TIMESCALEDB_DB_NAME="tjwater"
|
||||
TIMESCALEDB_DB_HOST="localhost"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
TIMESCALEDB_DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 元数据数据库配置 (Metadata DB)
|
||||
@@ -41,39 +42,9 @@ METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 项目连接缓存与连接池配置
|
||||
# ============================================
|
||||
PROJECT_PG_CACHE_SIZE=50
|
||||
PROJECT_TS_CACHE_SIZE=50
|
||||
PROJECT_PG_POOL_SIZE=5
|
||||
PROJECT_PG_MAX_OVERFLOW=10
|
||||
PROJECT_TS_POOL_MIN_SIZE=1
|
||||
PROJECT_TS_POOL_MAX_SIZE=10
|
||||
|
||||
# ============================================
|
||||
# InfluxDB 配置 (时序数据)
|
||||
# ============================================
|
||||
# INFLUXDB_URL=http://localhost:8086
|
||||
# INFLUXDB_TOKEN=your-influxdb-token
|
||||
# INFLUXDB_ORG=your-org
|
||||
# INFLUXDB_BUCKET=tjwater
|
||||
|
||||
# ============================================
|
||||
# JWT 配置 (可选)
|
||||
# ============================================
|
||||
# ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
# ALGORITHM=HS256
|
||||
|
||||
# ============================================
|
||||
# Keycloak JWT (可选)
|
||||
# ============================================
|
||||
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
# KEYCLOAK_ALGORITHM=RS256
|
||||
|
||||
# ============================================
|
||||
# 其他配置
|
||||
# ============================================
|
||||
# PROJECT_NAME=TJWater Server
|
||||
# API_V1_STR=/api/v1
|
||||
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
KEYCLOAK_ALGORITHM=RS256
|
||||
KEYCLOAK_AUDIENCE="account"
|
||||
|
||||
-23
@@ -1,23 +0,0 @@
|
||||
NETWORK_NAME="tjwater"
|
||||
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
|
||||
KEYCLOAK_ALGORITHM="RS256"
|
||||
KEYCLOAK_AUDIENCE="account"
|
||||
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="192.168.1.114"
|
||||
DB_PORT="5432"
|
||||
DB_USER="tjwater"
|
||||
DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
TIMESCALEDB_DB_NAME="tjwater"
|
||||
TIMESCALEDB_DB_HOST="192.168.1.114"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="192.168.1.114"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="Tjwater@123456"
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
+63
-179
@@ -1,198 +1,82 @@
|
||||
# TJWater Server - Copilot Instructions
|
||||
# Copilot Instructions for TJWater Server
|
||||
|
||||
This is a FastAPI-based water network management system (供水管网智能管理系统) that provides hydraulic simulation, SCADA data integration, network element management, and risk analysis capabilities.
|
||||
This repository contains the backend code for the TJWater Server, a water distribution network management system built with FastAPI.
|
||||
|
||||
## Running the Server
|
||||
## High-Level Architecture
|
||||
|
||||
The application follows a layered architecture:
|
||||
|
||||
- **Entry Point**: `app/main.py` initializes the FastAPI application, database connections (PostgreSQL & TimescaleDB), and middleware.
|
||||
- **API Layer**: `app/api/v1` contains the route handlers.
|
||||
- **Service Layer**: `app/services` contains business logic and orchestration.
|
||||
- **Infrastructure Layer**: `app/infra` handles database connections (`db`), audit logging (`audit`), and external integrations.
|
||||
- **Domain Layer**: `app/domain` likely contains core domain models.
|
||||
- **Native/Algorithms**: `app/native` and `app/algorithms` handle specialized water network calculations (possibly using EPANET/WNTR).
|
||||
|
||||
## Build, Test, and Run Commands
|
||||
|
||||
### Environment Setup
|
||||
|
||||
- Dependencies are listed in `requirements.txt`.
|
||||
- Configuration is managed via environment variables (see `.env.example` if available, or `app/core/config.py`).
|
||||
- **Important**: Ensure `.env` is configured with correct database credentials for both PostgreSQL and TimescaleDB.
|
||||
|
||||
If first time setting up, you may want to create a Conda environment:
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Start the server (default: http://0.0.0.0:8000 with 2 workers)
|
||||
python scripts/run_server.py
|
||||
|
||||
# Note: On Windows, the script automatically sets WindowsSelectorEventLoopPolicy
|
||||
conda create -n server python=3.12
|
||||
conda activate server
|
||||
pip install uv
|
||||
uv pip install -r requirements.txt
|
||||
conda install -c conda-forge pymetis
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
### Running the Server
|
||||
|
||||
The preferred way to run the server locally is using the helper script which sets up the Python path correctly:
|
||||
|
||||
```bash
|
||||
conda activate server
|
||||
python scripts/run_server.py
|
||||
```
|
||||
|
||||
Alternatively, you can run directly with uvicorn (ensure PYTHONPATH includes the root):
|
||||
|
||||
```bash
|
||||
conda activate server
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
Use `pytest` to run tests. The `tests/conftest.py` handles path setup.
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run a specific test file with verbose output
|
||||
pytest tests/unit/test_pipeline_health_analyzer.py -v
|
||||
# Run a specific test file
|
||||
pytest tests/unit/test_specific_file.py
|
||||
|
||||
# Run from conftest helper
|
||||
python tests/conftest.py
|
||||
# Run a specific test case
|
||||
pytest tests/unit/test_specific_file.py::test_function_name
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
### Building (Optional)
|
||||
|
||||
### Core Components
|
||||
The project includes scripts to compile Python modules to `.pyd` files using Cython (see `scripts/build_pyd.py`). This is likely for distribution/performance but not required for standard development.
|
||||
|
||||
1. **Native Modules** (`app/native/`): Platform-specific compiled extensions (`.so` for Linux, `.pyd` for Windows) providing performance-critical functionality including:
|
||||
- SCADA device integration
|
||||
- Water distribution analysis (WDA)
|
||||
- Pipe risk probability calculations
|
||||
- Wrapped through `app.services.tjnetwork` interface
|
||||
## Key Conventions
|
||||
|
||||
2. **Services Layer** (`app/services/`):
|
||||
- `tjnetwork.py`: Main network API wrapper around native modules
|
||||
- `simulation.py`: Hydraulic simulation orchestration (EPANET integration)
|
||||
- `project_info.py`: Project configuration management
|
||||
- `epanet/`: EPANET hydraulic engine integration
|
||||
- **Async/Await**: The codebase heavily uses `async` and `await` for I/O operations, especially database interactions.
|
||||
- **Database Management**:
|
||||
- Connections are managed globally in `app.infra.db` and initialized in `lifespan` (app/main.py).
|
||||
- Use `app.infra.db.dynamic_manager` for project-specific database connections (multi-tenancy/dynamic projects).
|
||||
- **Pydantic**: extensively used for data validation and settings management.
|
||||
- **Scripts**: The `scripts/` directory contains many utility scripts for maintenance, data processing, and server management. Check there before writing new operational scripts.
|
||||
- **Water Network Modeling**: Interactions with water network models often involve `epanet` or `wntr` libraries. Be aware of domain-specific terminology (nodes, links, junctions, tanks).
|
||||
|
||||
3. **API Layer** (`app/api/v1/`):
|
||||
- **Network Elements**: Separate endpoint modules for junctions, reservoirs, tanks, pipes, pumps, valves
|
||||
- **Components**: Curves, patterns, controls, options, quality, visuals
|
||||
- **Network Features**: Tags, demands, geometry, regions/DMAs
|
||||
- **Core Services**: Auth, project, simulation, SCADA, data query, snapshots
|
||||
## Code Style
|
||||
|
||||
4. **Database Infrastructure** (`app/infra/db/`):
|
||||
- **PostgreSQL**: Primary relational database (users, audit logs, project metadata)
|
||||
- **TimescaleDB**: Time-series extension for historical data
|
||||
- **InfluxDB**: Optional time-series database for high-frequency SCADA data
|
||||
- Connection pools initialized in `main.py` lifespan context
|
||||
- Database instance stored in `app.state.db` for dependency injection
|
||||
|
||||
5. **Domain Layer** (`app/domain/`):
|
||||
- `models/`: Enums and domain objects (e.g., `UserRole`)
|
||||
- `schemas/`: Pydantic models for request/response validation
|
||||
|
||||
6. **Algorithms** (`app/algorithms/`):
|
||||
- `api_ex/`: Analysis algorithms (k-means sensor placement, sensitivity analysis, pipeline health)
|
||||
- `data_cleaning.py`: Data preprocessing utilities
|
||||
- `simulations.py`: Simulation helpers
|
||||
|
||||
### Security & Authentication
|
||||
|
||||
- **Authentication**: JWT-based with access tokens (30 min) and refresh tokens (7 days)
|
||||
- **Authorization**: Role-based access control (RBAC) with 4 roles:
|
||||
- `VIEWER`: Read-only access
|
||||
- `USER`: Read-write access
|
||||
- `OPERATOR`: Modify data
|
||||
- `ADMIN`: Full permissions
|
||||
- **Audit Logging**: `AuditMiddleware` automatically logs POST/PUT/DELETE requests
|
||||
- **Encryption**: Fernet symmetric encryption for sensitive data (`app.core.encryption`)
|
||||
|
||||
Default admin accounts:
|
||||
- `admin` / `admin123`
|
||||
- `tjwater` / `tjwater@123`
|
||||
|
||||
### Key Files
|
||||
|
||||
- `app/main.py`: FastAPI app initialization, lifespan (DB pools), CORS, middleware, router mounting
|
||||
- `app/api/v1/router.py`: Central router aggregating all endpoint modules
|
||||
- `app/core/config.py`: Settings management using `pydantic-settings`
|
||||
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
||||
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
||||
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
||||
|
||||
## Important Conventions
|
||||
|
||||
### Database Connections
|
||||
|
||||
- Database instances are initialized in `main.py` lifespan and stored in `app.state.db`
|
||||
- Access via dependency injection:
|
||||
```python
|
||||
from app.auth.dependencies import get_db
|
||||
|
||||
async def endpoint(db = Depends(get_db)):
|
||||
# Use db connection
|
||||
```
|
||||
|
||||
### Authentication in Endpoints
|
||||
|
||||
Use dependency injection for auth requirements:
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# Require any authenticated user
|
||||
@router.get("/data")
|
||||
async def get_data(current_user = Depends(get_current_active_user)):
|
||||
return data
|
||||
|
||||
# Require specific role (USER or higher)
|
||||
@router.post("/data")
|
||||
async def create_data(current_user = Depends(require_role(UserRole.USER))):
|
||||
return result
|
||||
|
||||
# Admin-only access
|
||||
@router.delete("/data/{id}")
|
||||
async def delete_data(id: int, current_user = Depends(get_current_admin)):
|
||||
return result
|
||||
```
|
||||
|
||||
### API Routing Structure
|
||||
|
||||
- All v1 APIs are mounted under `/api/v1` prefix via `api_router`
|
||||
- Legacy routes without version prefix are also mounted for backward compatibility
|
||||
- Group related endpoints in separate router modules under `app/api/v1/endpoints/`
|
||||
- Use descriptive tags in `router.py` for OpenAPI documentation grouping
|
||||
|
||||
### Native Module Integration
|
||||
|
||||
- Native modules are pre-compiled for specific platforms
|
||||
- Always import through `app.native.api` or `app.services.tjnetwork`
|
||||
- The `tjnetwork` service wraps native APIs with constants like:
|
||||
- Element types: `JUNCTION`, `RESERVOIR`, `TANK`, `PIPE`, `PUMP`, `VALVE`
|
||||
- Operations: `API_ADD`, `API_UPDATE`, `API_DELETE`
|
||||
- `ChangeSet` for batch operations
|
||||
|
||||
### Project Initialization
|
||||
|
||||
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
||||
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
||||
|
||||
### Audit Logging
|
||||
|
||||
Manual audit logging for critical operations:
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="resource_name",
|
||||
resource_id=str(resource_id),
|
||||
ip_address=request.client.host,
|
||||
request_data=data
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
- Copy `.env.example` to `.env` before first run
|
||||
- Required environment variables:
|
||||
- `SECRET_KEY`: JWT signing (generate with `openssl rand -hex 32`)
|
||||
- `ENCRYPTION_KEY`: Data encryption (generate with Fernet)
|
||||
- Database credentials for PostgreSQL, TimescaleDB, and optionally InfluxDB
|
||||
|
||||
### Database Migrations
|
||||
|
||||
SQL migration scripts are in `migrations/`:
|
||||
- `001_create_users_table.sql`: User authentication tables
|
||||
- `002_create_audit_logs_table.sql`: Audit logging tables
|
||||
|
||||
Apply with:
|
||||
```bash
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI schema: http://localhost:8000/openapi.json
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `SECURITY_README.md`: Comprehensive security feature documentation
|
||||
- `DEPLOYMENT.md`: Integration guide for security features
|
||||
- `readme.md`: Project overview and directory structure (in Chinese)
|
||||
- Follow standard PEP 8 guidelines.
|
||||
- No specific linter configuration was found, so default to standard Python formatting.
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
name: Build And Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
jobs:
|
||||
build-package:
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest]
|
||||
|
||||
steps:
|
||||
- name: Checkout source
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install system build tools
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Install compile dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install cython setuptools wheel
|
||||
|
||||
- name: Run Cython compile
|
||||
run: |
|
||||
python scripts/compile.py
|
||||
|
||||
- name: Prepare package and archive
|
||||
run: |
|
||||
python - <<'PY'
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import zipfile
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root = Path.cwd()
|
||||
package_dir = root / "package"
|
||||
dist_dir = root / "dist"
|
||||
|
||||
for d in [package_dir, dist_dir]:
|
||||
if d.exists():
|
||||
shutil.rmtree(d)
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Define directories with compiled artifacts
|
||||
compile_dirs = ["app/services", "app/native/wndb", "app/algorithms"]
|
||||
# Global ignore list
|
||||
ignore_names = {
|
||||
".git",
|
||||
".github",
|
||||
"__pycache__",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".venv",
|
||||
"venv",
|
||||
"temp",
|
||||
"tests",
|
||||
"package",
|
||||
"dist",
|
||||
}
|
||||
|
||||
def ignore_func(directory, names):
|
||||
rel_dir = os.path.relpath(directory, root).replace("\\", "/")
|
||||
is_in_compile_path = any(rel_dir.startswith(d) for d in compile_dirs)
|
||||
|
||||
ignored = []
|
||||
for name in names:
|
||||
if name in ignore_names or name.endswith(".pyc"):
|
||||
ignored.append(name)
|
||||
# Exclude source .py files only in compiled directories
|
||||
elif is_in_compile_path and name.endswith(".py"):
|
||||
ignored.append(name)
|
||||
return ignored
|
||||
|
||||
for item in root.iterdir():
|
||||
if item.name in ignore_names:
|
||||
continue
|
||||
target = package_dir / item.name
|
||||
if item.is_dir():
|
||||
shutil.copytree(item, target, ignore=ignore_func)
|
||||
else:
|
||||
shutil.copy2(item, target)
|
||||
|
||||
# Safety guard: ensure no .github directory remains
|
||||
github_paths = [p for p in package_dir.rglob(".github") if p.is_dir()]
|
||||
for p in github_paths:
|
||||
shutil.rmtree(p, ignore_errors=True)
|
||||
|
||||
sha = os.environ["GITHUB_SHA"]
|
||||
run_os = os.environ["RUNNER_OS"].lower()
|
||||
|
||||
if run_os == "windows":
|
||||
archive_path = dist_dir / f"tjwater-server-{run_os}-{sha}.zip"
|
||||
with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
||||
for f in package_dir.rglob("*"):
|
||||
if f.is_file():
|
||||
zf.write(f, f.relative_to(package_dir))
|
||||
else:
|
||||
archive_path = dist_dir / f"tjwater-server-{run_os}-{sha}.tar.gz"
|
||||
with tarfile.open(archive_path, "w:gz") as tf:
|
||||
tf.add(package_dir, arcname=".")
|
||||
|
||||
print(f"Archive created: {archive_path}")
|
||||
PY
|
||||
shell: bash
|
||||
|
||||
- name: Upload package artifact
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: tjwater-server-package-${{ runner.os }}
|
||||
path: dist/*
|
||||
retention-days: 14
|
||||
+2
-1
@@ -5,5 +5,6 @@ build/
|
||||
*.pyc
|
||||
.env
|
||||
*.dump
|
||||
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
||||
.vscode/
|
||||
app/algorithms/health/model/my_survival_forest_model_quxi.joblib
|
||||
inp/
|
||||
|
||||
-391
@@ -1,391 +0,0 @@
|
||||
# 部署和集成指南
|
||||
|
||||
本文档说明如何将新的安全功能集成到现有系统中。
|
||||
|
||||
## 📦 已完成的功能
|
||||
|
||||
### 1. 数据加密模块
|
||||
- ✅ `app/core/encryption.py` - Fernet 对称加密实现
|
||||
- ✅ 支持敏感数据加密/解密
|
||||
- ✅ 密钥管理和生成工具
|
||||
|
||||
### 2. 用户认证系统
|
||||
- ✅ `app/domain/models/role.py` - 用户角色枚举 (ADMIN/OPERATOR/USER/VIEWER)
|
||||
- ✅ `app/domain/schemas/user.py` - 用户数据模型和验证
|
||||
- ✅ `app/infra/repositories/user_repository.py` - 用户数据访问层
|
||||
- ✅ `app/api/v1/endpoints/auth.py` - 注册/登录/刷新Token接口
|
||||
- ✅ `app/auth/dependencies.py` - 认证依赖项
|
||||
- ✅ `migrations/001_create_users_table.sql` - 用户表迁移脚本
|
||||
|
||||
### 3. 权限控制系统
|
||||
- ✅ `app/auth/permissions.py` - RBAC 权限控制装饰器
|
||||
- ✅ `app/api/v1/endpoints/user_management.py` - 用户管理接口示例
|
||||
- ✅ 支持基于角色的访问控制
|
||||
- ✅ 支持资源所有者检查
|
||||
|
||||
### 4. 审计日志系统
|
||||
- ✅ `app/core/audit.py` - 审计日志核心功能
|
||||
- ✅ `app/domain/schemas/audit.py` - 审计日志数据模型
|
||||
- ✅ `app/infra/repositories/audit_repository.py` - 审计日志数据访问层
|
||||
- ✅ `app/api/v1/endpoints/audit.py` - 审计日志查询接口
|
||||
- ✅ `app/infra/audit/middleware.py` - 自动审计中间件
|
||||
- ✅ `migrations/002_create_audit_logs_table.sql` - 审计日志表迁移脚本
|
||||
|
||||
### 5. 文档和测试
|
||||
- ✅ `SECURITY_README.md` - 完整的使用文档
|
||||
- ✅ `.env.example` - 环境变量配置模板
|
||||
- ✅ `tests/test_encryption.py` - 加密功能测试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 集成步骤
|
||||
|
||||
### 步骤 1: 环境配置
|
||||
|
||||
1. 复制环境变量模板:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
2. 生成密钥并填写 `.env`:
|
||||
```bash
|
||||
# JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
3. 编辑 `.env` 填写所有必需的配置项。
|
||||
|
||||
### 步骤 2: 数据库迁移
|
||||
|
||||
执行数据库迁移脚本:
|
||||
|
||||
```bash
|
||||
# 方法 1: 使用 psql 命令
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 方法 2: 在 psql 交互界面
|
||||
psql -U postgres -d tjwater
|
||||
\i migrations/001_create_users_table.sql
|
||||
\i migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
验证表已创建:
|
||||
|
||||
```sql
|
||||
-- 检查用户表
|
||||
SELECT * FROM users;
|
||||
|
||||
-- 检查审计日志表
|
||||
SELECT * FROM audit_logs;
|
||||
```
|
||||
|
||||
### 步骤 3: 更新 main.py
|
||||
|
||||
在 `app/main.py` 中集成新功能:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from app.core.config import settings
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
|
||||
app = FastAPI(title=settings.PROJECT_NAME)
|
||||
|
||||
# 1. 添加审计中间件(可选)
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# 2. 注册路由
|
||||
from app.api.v1.endpoints import auth, user_management, audit
|
||||
|
||||
app.include_router(
|
||||
auth.router,
|
||||
prefix=f"{settings.API_V1_STR}/auth",
|
||||
tags=["认证"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
user_management.router,
|
||||
prefix=f"{settings.API_V1_STR}/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
audit.router,
|
||||
prefix=f"{settings.API_V1_STR}/audit",
|
||||
tags=["审计日志"]
|
||||
)
|
||||
|
||||
# 3. 确保数据库在启动时初始化
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始化数据库连接池
|
||||
from app.infra.db.postgresql.database import Database
|
||||
global db
|
||||
db = Database()
|
||||
db.init_pool()
|
||||
await db.open()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
# 关闭数据库连接
|
||||
await db.close()
|
||||
```
|
||||
|
||||
### 步骤 4: 保护现有接口
|
||||
|
||||
#### 方法 1: 为路由添加全局依赖
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
|
||||
# 为整个路由器添加认证
|
||||
router = APIRouter(dependencies=[Depends(get_current_active_user)])
|
||||
```
|
||||
|
||||
#### 方法 2: 为单个端点添加依赖
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
@router.get("/data")
|
||||
async def get_data(
|
||||
current_user = Depends(require_role(UserRole.USER))
|
||||
):
|
||||
"""需要 USER 及以上角色"""
|
||||
return {"data": "protected"}
|
||||
|
||||
@router.delete("/data/{id}")
|
||||
async def delete_data(
|
||||
id: int,
|
||||
current_user = Depends(get_current_admin)
|
||||
):
|
||||
"""仅管理员可访问"""
|
||||
return {"message": "deleted"}
|
||||
```
|
||||
|
||||
### 步骤 5: 添加审计日志
|
||||
|
||||
#### 自动审计(推荐)
|
||||
|
||||
使用中间件自动记录(已在 main.py 中添加):
|
||||
|
||||
```python
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
#### 手动审计
|
||||
|
||||
在关键业务逻辑中手动记录:
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
@router.post("/important-action")
|
||||
async def important_action(
|
||||
data: dict,
|
||||
request: Request,
|
||||
current_user = Depends(get_current_active_user)
|
||||
):
|
||||
# 执行业务逻辑
|
||||
result = do_something(data)
|
||||
|
||||
# 记录审计日志
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="important_resource",
|
||||
resource_id=str(result.id),
|
||||
ip_address=request.client.host,
|
||||
request_data=data
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### 步骤 6: 更新 auth/dependencies.py
|
||||
|
||||
确保 `get_db()` 函数正确获取数据库实例:
|
||||
|
||||
```python
|
||||
async def get_db() -> Database:
|
||||
"""获取数据库实例"""
|
||||
# 方法 1: 从 main.py 导入
|
||||
from app.main import db
|
||||
return db
|
||||
|
||||
# 方法 2: 从 FastAPI app.state 获取
|
||||
# from fastapi import Request
|
||||
# def get_db_from_request(request: Request):
|
||||
# return request.app.state.db
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧪 测试
|
||||
|
||||
### 1. 测试加密功能
|
||||
|
||||
```bash
|
||||
python tests/test_encryption.py
|
||||
```
|
||||
|
||||
### 2. 测试 API
|
||||
|
||||
启动服务器:
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
访问交互式文档:
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
|
||||
### 3. 测试登录
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
```
|
||||
|
||||
### 4. 测试受保护接口
|
||||
|
||||
```bash
|
||||
TOKEN="your-access-token"
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 迁移现有接口
|
||||
|
||||
### 原有硬编码认证
|
||||
|
||||
**旧代码** (`app/api/v1/endpoints/auth.py`):
|
||||
```python
|
||||
AUTH_TOKEN = "567e33c876a2"
|
||||
|
||||
async def verify_token(authorization: str = Header()):
|
||||
token = authorization.split(" ")[1]
|
||||
if token != AUTH_TOKEN:
|
||||
raise HTTPException(status_code=403)
|
||||
```
|
||||
|
||||
**新代码** (已更新):
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
|
||||
@router.get("/protected")
|
||||
async def protected_route(
|
||||
current_user = Depends(get_current_active_user)
|
||||
):
|
||||
return {"user": current_user.username}
|
||||
```
|
||||
|
||||
### 更新其他端点
|
||||
|
||||
搜索项目中使用旧认证的地方:
|
||||
|
||||
```bash
|
||||
grep -r "AUTH_TOKEN" app/
|
||||
grep -r "verify_token" app/
|
||||
```
|
||||
|
||||
替换为新的依赖注入系统。
|
||||
|
||||
---
|
||||
|
||||
## 📋 检查清单
|
||||
|
||||
部署前检查:
|
||||
|
||||
- [ ] 环境变量已配置(`.env`)
|
||||
- [ ] 数据库迁移已执行
|
||||
- [ ] 默认管理员账号可登录
|
||||
- [ ] JWT Token 可正常生成和验证
|
||||
- [ ] 权限控制正常工作
|
||||
- [ ] 审计日志正常记录
|
||||
- [ ] 加密功能测试通过
|
||||
- [ ] API 文档可访问
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### 1. 向后兼容性
|
||||
|
||||
保留了简化版登录接口 `/auth/login/simple` 以兼容旧客户端:
|
||||
|
||||
```python
|
||||
@router.post("/login/simple")
|
||||
async def login_simple(username: str, password: str):
|
||||
# 验证并返回 Token
|
||||
...
|
||||
```
|
||||
|
||||
### 2. 数据库连接
|
||||
|
||||
确保在 `app/auth/dependencies.py` 中 `get_db()` 函数能正确获取数据库实例。
|
||||
|
||||
### 3. 密钥安全
|
||||
|
||||
- ❌ 不要提交 `.env` 文件到版本控制
|
||||
- ✅ 在生产环境使用环境变量或密钥管理服务
|
||||
- ✅ 定期轮换 JWT 密钥
|
||||
|
||||
### 4. 性能考虑
|
||||
|
||||
- 审计中间件会增加每个请求的处理时间(约 5-10ms)
|
||||
- 对高频接口可考虑异步记录审计日志
|
||||
- 定期清理或归档旧的审计日志
|
||||
|
||||
---
|
||||
|
||||
## 🐛 故障排查
|
||||
|
||||
### 问题 1: 导入错误
|
||||
|
||||
```
|
||||
ImportError: cannot import name 'db' from 'app.main'
|
||||
```
|
||||
|
||||
**解决**: 确保在 `app/main.py` 中定义了全局 `db` 对象。
|
||||
|
||||
### 问题 2: 认证失败
|
||||
|
||||
```
|
||||
401 Unauthorized: Could not validate credentials
|
||||
```
|
||||
|
||||
**检查**:
|
||||
1. Token 是否正确设置在 `Authorization: Bearer {token}` header
|
||||
2. Token 是否过期
|
||||
3. SECRET_KEY 是否配置正确
|
||||
|
||||
### 问题 3: 数据库连接失败
|
||||
|
||||
```
|
||||
psycopg.OperationalError: connection failed
|
||||
```
|
||||
|
||||
**检查**:
|
||||
1. PostgreSQL 是否运行
|
||||
2. `.env` 中数据库配置是否正确
|
||||
3. 数据库是否存在
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
详细文档请参考:
|
||||
- `SECURITY_README.md` - 安全功能使用指南
|
||||
- `migrations/` - 数据库迁移脚本
|
||||
- `app/domain/schemas/` - 数据模型定义
|
||||
|
||||
+5
-5
@@ -1,19 +1,19 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
FROM condaforge/miniforge3:latest
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装 Python 3.12 和 pymetis (通过 conda-forge 避免编译问题)
|
||||
RUN conda install -y -c conda-forge python=3.12 pymetis && \
|
||||
conda clean -afy
|
||||
RUN mamba install -y python=3.12 pymetis && \
|
||||
mamba clean -afy
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
RUN pip install uv
|
||||
RUN uv pip install --system --no-cache-dir -r requirements.txt
|
||||
|
||||
# 将代码放入子目录 'app',将数据放入子目录 'db_inp'
|
||||
# 这样临时文件默认会生成在 /app 下,而代码在 /app/app 下,实现了分离
|
||||
COPY app ./app
|
||||
COPY db_inp ./db_inp
|
||||
COPY temp ./temp
|
||||
COPY .env .
|
||||
|
||||
# 设置 PYTHONPATH 以便 uvicorn 找到 app 模块
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
# API 集成检查清单
|
||||
|
||||
## ✅ 已完成的集成工作
|
||||
|
||||
### 1. 路由集成 (app/api/v1/router.py)
|
||||
|
||||
已添加以下路由到 API Router:
|
||||
|
||||
```python
|
||||
# 新增导入
|
||||
from app.api.v1.endpoints import (
|
||||
...
|
||||
user_management, # 用户管理
|
||||
audit, # 审计日志
|
||||
)
|
||||
|
||||
# 新增路由
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"])
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"])
|
||||
```
|
||||
|
||||
**路由端点**:
|
||||
- `/api/v1/auth/` - 认证相关(register, login, me, refresh)
|
||||
- `/api/v1/users/` - 用户管理(CRUD操作,仅管理员)
|
||||
- `/api/v1/audit/` - 审计日志查询(仅管理员)
|
||||
|
||||
### 2. 主应用配置 (app/main.py)
|
||||
|
||||
#### 2.1 导入更新
|
||||
```python
|
||||
from app.core.config import settings
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
```
|
||||
|
||||
#### 2.2 数据库初始化
|
||||
```python
|
||||
# 在 lifespan 中存储数据库实例到 app.state
|
||||
app.state.db = pgdb
|
||||
```
|
||||
|
||||
#### 2.3 FastAPI 配置
|
||||
```python
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title=settings.PROJECT_NAME,
|
||||
description="TJWater Server - 供水管网智能管理系统",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
```
|
||||
|
||||
#### 2.4 审计中间件(可选)
|
||||
```python
|
||||
# 取消注释以启用审计日志
|
||||
# app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
### 3. 依赖项更新 (app/auth/dependencies.py)
|
||||
|
||||
更新 `get_db()` 函数从 Request 对象获取数据库:
|
||||
|
||||
```python
|
||||
async def get_db(request: Request) -> Database:
|
||||
"""从 app.state 获取数据库实例"""
|
||||
if not hasattr(request.app.state, "db"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database not initialized"
|
||||
)
|
||||
return request.app.state.db
|
||||
```
|
||||
|
||||
### 4. 审计日志更新
|
||||
|
||||
- `app/api/v1/endpoints/audit.py` - 使用正确的数据库依赖
|
||||
- `app/core/audit.py` - 接受可选的 db 参数
|
||||
|
||||
---
|
||||
|
||||
## 📋 部署前检查清单
|
||||
|
||||
### 环境配置
|
||||
- [ ] 复制 `.env.example` 为 `.env`
|
||||
- [ ] 配置 `SECRET_KEY`(JWT密钥)
|
||||
- [ ] 配置 `ENCRYPTION_KEY`(数据加密密钥)
|
||||
- [ ] 配置数据库连接信息
|
||||
|
||||
### 数据库迁移
|
||||
- [ ] 执行用户表迁移:`psql -U postgres -d tjwater -f migrations/001_create_users_table.sql`
|
||||
- [ ] 执行审计日志表迁移:`psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql`
|
||||
- [ ] 验证表已创建:`\dt` 在 psql 中
|
||||
|
||||
### 依赖检查
|
||||
- [ ] 确认已安装:`cryptography`
|
||||
- [ ] 确认已安装:`python-jose[cryptography]`
|
||||
- [ ] 确认已安装:`passlib[bcrypt]`
|
||||
- [ ] 确认已安装:`email-validator`(用于 Pydantic email 验证)
|
||||
|
||||
### 代码验证
|
||||
- [ ] 检查所有文件导入正常
|
||||
- [ ] 运行加密功能测试:`python tests/test_encryption.py`
|
||||
- [ ] 启动服务器:`uvicorn app.main:app --reload`
|
||||
- [ ] 访问 API 文档:http://localhost:8000/docs
|
||||
|
||||
### API 测试
|
||||
- [ ] 测试登录:POST `/api/v1/auth/login`
|
||||
- [ ] 测试获取当前用户:GET `/api/v1/auth/me`
|
||||
- [ ] 测试用户列表(需管理员):GET `/api/v1/users/`
|
||||
- [ ] 测试审计日志(需管理员):GET `/api/v1/audit/logs`
|
||||
|
||||
---
|
||||
|
||||
## 🔧 快速测试命令
|
||||
|
||||
### 1. 生成密钥
|
||||
```bash
|
||||
# JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
### 2. 执行迁移
|
||||
```bash
|
||||
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
### 3. 测试加密
|
||||
```bash
|
||||
python tests/test_encryption.py
|
||||
```
|
||||
|
||||
### 4. 启动服务器
|
||||
```bash
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### 5. 测试登录 API
|
||||
```bash
|
||||
# 使用默认管理员账号
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
|
||||
# 或使用迁移的账号
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=tjwater&password=tjwater@123"
|
||||
```
|
||||
|
||||
### 6. 测试受保护接口
|
||||
```bash
|
||||
# 保存 Token
|
||||
TOKEN="<从登录响应中获取的 access_token>"
|
||||
|
||||
# 获取当前用户信息
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 获取用户列表(需管理员权限)
|
||||
curl -X GET "http://localhost:8000/api/v1/users/" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 查询审计日志(需管理员权限)
|
||||
curl -X GET "http://localhost:8000/api/v1/audit/logs" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📚 API 端点总览
|
||||
|
||||
### 认证接口 (`/api/v1/auth`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| POST | `/register` | 用户注册 | 公开 |
|
||||
| POST | `/login` | OAuth2 登录 | 公开 |
|
||||
| POST | `/login/simple` | 简化登录(兼容旧版) | 公开 |
|
||||
| GET | `/me` | 获取当前用户信息 | 认证用户 |
|
||||
| POST | `/refresh` | 刷新 Token | 认证用户 |
|
||||
|
||||
### 用户管理 (`/api/v1/users`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/` | 获取用户列表 | 管理员 |
|
||||
| GET | `/{id}` | 获取用户详情 | 所有者/管理员 |
|
||||
| PUT | `/{id}` | 更新用户信息 | 所有者/管理员 |
|
||||
| DELETE | `/{id}` | 删除用户 | 管理员 |
|
||||
| POST | `/{id}/activate` | 激活用户 | 管理员 |
|
||||
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
|
||||
|
||||
### 审计日志 (`/api/v1/audit`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/logs` | 查询审计日志 | 管理员 |
|
||||
| GET | `/logs/count` | 获取日志总数 | 管理员 |
|
||||
| GET | `/logs/my` | 查看我的操作记录 | 认证用户 |
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### 1. 审计中间件
|
||||
审计中间件默认是**禁用**的。如需启用,在 `app/main.py` 中取消注释:
|
||||
|
||||
```python
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
**注意**:启用后会自动记录所有 POST/PUT/DELETE 请求,可能增加数据库负载。
|
||||
|
||||
### 2. 向后兼容
|
||||
保留了原有的简化登录接口 `/auth/login/simple`,可以直接使用查询参数:
|
||||
|
||||
```bash
|
||||
POST /api/v1/auth/login/simple?username=admin&password=admin123
|
||||
```
|
||||
|
||||
### 3. 数据库连接
|
||||
确保数据库实例在应用启动时正确初始化并存储到 `app.state.db`。
|
||||
|
||||
### 4. 权限控制示例
|
||||
为现有接口添加权限控制:
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# 需要管理员权限
|
||||
@router.delete("/resource/{id}")
|
||||
async def delete_resource(
|
||||
id: int,
|
||||
current_user = Depends(get_current_admin)
|
||||
):
|
||||
...
|
||||
|
||||
# 需要操作员以上权限
|
||||
@router.post("/resource")
|
||||
async def create_resource(
|
||||
data: dict,
|
||||
current_user = Depends(require_role(UserRole.OPERATOR))
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 完整启动流程
|
||||
|
||||
```bash
|
||||
# 1. 进入项目目录
|
||||
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
|
||||
|
||||
# 2. 配置环境变量(如果还没有)
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填写必要的配置
|
||||
|
||||
# 3. 执行数据库迁移(如果还没有)
|
||||
psql -U postgres -d tjwater < migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater < migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 4. 测试加密功能
|
||||
python tests/test_encryption.py
|
||||
|
||||
# 5. 启动服务器
|
||||
uvicorn app.main:app --reload
|
||||
|
||||
# 6. 访问 API 文档
|
||||
# 浏览器打开: http://localhost:8000/docs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📞 故障排查
|
||||
|
||||
### 问题 1: 导入错误
|
||||
```
|
||||
ModuleNotFoundError: No module named 'jose'
|
||||
```
|
||||
**解决**: 安装依赖 `pip install python-jose[cryptography]`
|
||||
|
||||
### 问题 2: 数据库未初始化
|
||||
```
|
||||
503 Service Unavailable: Database not initialized
|
||||
```
|
||||
**解决**: 检查 `main.py` 中的 lifespan 函数是否正确设置 `app.state.db`
|
||||
|
||||
### 问题 3: Token 验证失败
|
||||
```
|
||||
401 Unauthorized: Could not validate credentials
|
||||
```
|
||||
**解决**:
|
||||
1. 检查 SECRET_KEY 是否配置正确
|
||||
2. 确认 Token 格式:`Authorization: Bearer {token}`
|
||||
3. 检查 Token 是否过期
|
||||
|
||||
### 问题 4: 表不存在
|
||||
```
|
||||
relation "users" does not exist
|
||||
```
|
||||
**解决**: 执行数据库迁移脚本
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- **使用指南**: `SECURITY_README.md`
|
||||
- **部署指南**: `DEPLOYMENT.md`
|
||||
- **实施总结**: `SECURITY_IMPLEMENTATION_SUMMARY.md`
|
||||
- **自动设置**: `setup_security.sh`
|
||||
|
||||
---
|
||||
|
||||
**最后更新**: 2026-02-02
|
||||
**状态**: ✅ API 已完全集成
|
||||
@@ -1,370 +0,0 @@
|
||||
# 安全功能实施总结
|
||||
|
||||
## ✅ 已完成的功能
|
||||
|
||||
本次实施完成了完整的安全体系,包括数据加密、身份认证、权限管理、审计日志四大模块。
|
||||
|
||||
---
|
||||
|
||||
## 📁 新增文件清单
|
||||
|
||||
### 核心功能模块
|
||||
|
||||
1. **数据加密**
|
||||
- `app/core/encryption.py` - Fernet 加密实现
|
||||
- `tests/test_encryption.py` - 加密功能测试
|
||||
|
||||
2. **用户系统**
|
||||
- `app/domain/models/role.py` - 用户角色枚举
|
||||
- `app/domain/schemas/user.py` - 用户数据模型
|
||||
- `app/infra/repositories/user_repository.py` - 用户数据访问层
|
||||
|
||||
3. **认证授权**
|
||||
- `app/api/v1/endpoints/auth.py` - 认证接口(已重构)
|
||||
- `app/auth/dependencies.py` - 认证依赖项(已更新)
|
||||
- `app/auth/permissions.py` - 权限控制装饰器
|
||||
- `app/api/v1/endpoints/user_management.py` - 用户管理接口
|
||||
|
||||
4. **审计日志**
|
||||
- `app/core/audit.py` - 审计日志核心(已完善)
|
||||
- `app/domain/schemas/audit.py` - 审计日志数据模型
|
||||
- `app/infra/repositories/audit_repository.py` - 审计日志数据访问层
|
||||
- `app/api/v1/endpoints/audit.py` - 审计日志查询接口
|
||||
- `app/infra/audit/middleware.py` - 自动审计中间件
|
||||
|
||||
### 数据库迁移
|
||||
|
||||
5. **迁移脚本**
|
||||
- `migrations/001_create_users_table.sql` - 用户表
|
||||
- `migrations/002_create_audit_logs_table.sql` - 审计日志表
|
||||
|
||||
### 配置和文档
|
||||
|
||||
6. **配置文件**
|
||||
- `.env.example` - 环境变量模板
|
||||
- `app/core/config.py` - 配置文件(已更新)
|
||||
- `app/core/security.py` - 安全工具(已增强)
|
||||
|
||||
7. **文档**
|
||||
- `SECURITY_README.md` - 完整使用指南(79KB+)
|
||||
- `DEPLOYMENT.md` - 部署和集成指南
|
||||
- `SECURITY_IMPLEMENTATION_SUMMARY.md` - 本文件
|
||||
|
||||
8. **工具**
|
||||
- `setup_security.sh` - 快速设置脚本
|
||||
|
||||
---
|
||||
|
||||
## 🎯 功能特性
|
||||
|
||||
### 1. 数据加密
|
||||
- ✅ 使用 Fernet(AES-128)对称加密
|
||||
- ✅ 支持密钥生成和管理
|
||||
- ✅ 自动从环境变量读取密钥
|
||||
- ✅ 完整的加密/解密 API
|
||||
- ✅ 单元测试覆盖
|
||||
|
||||
### 2. 身份认证
|
||||
- ✅ 基于 JWT 的 Token 认证
|
||||
- ✅ Access Token + Refresh Token 机制
|
||||
- ✅ 用户注册/登录接口
|
||||
- ✅ 支持用户名或邮箱登录
|
||||
- ✅ 密码使用 bcrypt 哈希存储
|
||||
- ✅ Token 过期时间可配置
|
||||
- ✅ 向后兼容旧接口
|
||||
|
||||
### 3. 权限管理(RBAC)
|
||||
- ✅ 4 个预定义角色:ADMIN, OPERATOR, USER, VIEWER
|
||||
- ✅ 基于角色层级的权限检查
|
||||
- ✅ 可复用的权限装饰器
|
||||
- ✅ 资源所有者检查
|
||||
- ✅ 灵活的依赖注入设计
|
||||
|
||||
### 4. 审计日志
|
||||
- ✅ 自动记录所有关键操作
|
||||
- ✅ 记录用户、时间、操作类型、资源等信息
|
||||
- ✅ 敏感数据自动脱敏
|
||||
- ✅ 支持按多条件查询
|
||||
- ✅ 管理员专用查询接口
|
||||
- ✅ 用户可查看自己的操作记录
|
||||
|
||||
---
|
||||
|
||||
## 📊 技术栈
|
||||
|
||||
| 组件 | 技术 | 说明 |
|
||||
|------|------|------|
|
||||
| 加密 | cryptography.Fernet | 对称加密 |
|
||||
| 密码哈希 | bcrypt | 密码安全存储 |
|
||||
| JWT | python-jose | Token 生成和验证 |
|
||||
| 数据库 | PostgreSQL + psycopg | 异步数据访问 |
|
||||
| Web框架 | FastAPI | 现代异步框架 |
|
||||
| 数据验证 | Pydantic | 类型安全的数据模型 |
|
||||
|
||||
---
|
||||
|
||||
## 🔐 安全特性
|
||||
|
||||
1. **密码安全**
|
||||
- bcrypt 哈希(work factor = 12)
|
||||
- 自动加盐
|
||||
- 不可逆加密
|
||||
|
||||
2. **Token 安全**
|
||||
- JWT 签名验证
|
||||
- 短期 Access Token(30分钟)
|
||||
- 长期 Refresh Token(7天)
|
||||
- Token 类型校验
|
||||
|
||||
3. **数据保护**
|
||||
- 敏感字段自动脱敏
|
||||
- 审计日志不记录密码
|
||||
- 加密密钥从环境变量读取
|
||||
|
||||
4. **访问控制**
|
||||
- 基于角色的细粒度权限
|
||||
- 资源级别的访问控制
|
||||
- 自动验证用户激活状态
|
||||
|
||||
---
|
||||
|
||||
## 📈 数据库设计
|
||||
|
||||
### users 表
|
||||
```
|
||||
用户表 - 存储系统用户
|
||||
- id (主键)
|
||||
- username (唯一)
|
||||
- email (唯一)
|
||||
- hashed_password
|
||||
- role (ADMIN/OPERATOR/USER/VIEWER)
|
||||
- is_active
|
||||
- is_superuser
|
||||
- created_at
|
||||
- updated_at (自动更新)
|
||||
```
|
||||
|
||||
### audit_logs 表
|
||||
```
|
||||
审计日志表 - 记录所有关键操作
|
||||
- id (主键)
|
||||
- user_id (外键)
|
||||
- username (冗余字段)
|
||||
- action (操作类型)
|
||||
- resource_type (资源类型)
|
||||
- resource_id (资源ID)
|
||||
- ip_address
|
||||
- user_agent
|
||||
- request_method
|
||||
- request_path
|
||||
- request_data (JSONB)
|
||||
- response_status
|
||||
- error_message
|
||||
- timestamp
|
||||
```
|
||||
|
||||
**索引优化**:
|
||||
- users: username, email, role, is_active
|
||||
- audit_logs: user_id, username, timestamp, action, resource
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 方法 1: 使用自动化脚本
|
||||
|
||||
```bash
|
||||
./setup_security.sh
|
||||
```
|
||||
|
||||
### 方法 2: 手动设置
|
||||
|
||||
```bash
|
||||
# 1. 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填写密钥和数据库配置
|
||||
|
||||
# 2. 执行数据库迁移
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 3. 测试
|
||||
python tests/test_encryption.py
|
||||
|
||||
# 4. 启动服务
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 集成检查清单
|
||||
|
||||
### 必需步骤
|
||||
|
||||
- [ ] 复制 `.env.example` 为 `.env` 并配置
|
||||
- [ ] 生成 JWT 密钥(SECRET_KEY)
|
||||
- [ ] 生成加密密钥(ENCRYPTION_KEY)
|
||||
- [ ] 配置数据库连接信息
|
||||
- [ ] 执行用户表迁移脚本
|
||||
- [ ] 执行审计日志表迁移脚本
|
||||
- [ ] 验证默认管理员可登录
|
||||
|
||||
### 可选步骤
|
||||
|
||||
- [ ] 在 main.py 中添加审计中间件
|
||||
- [ ] 为现有接口添加权限控制
|
||||
- [ ] 注册新的路由(auth, user_management, audit)
|
||||
- [ ] 替换硬编码的认证逻辑
|
||||
- [ ] 配置 Token 过期时间
|
||||
|
||||
---
|
||||
|
||||
## 🔄 向后兼容性
|
||||
|
||||
### 保留的旧接口
|
||||
|
||||
1. **简化登录**: `/api/v1/auth/login/simple`
|
||||
- 仍可使用 `username` 和 `password` 参数
|
||||
- 返回标准 Token 响应
|
||||
|
||||
2. **硬编码用户迁移**
|
||||
- 原有 `tjwater/tjwater@123` 已迁移到数据库
|
||||
- 保持相同的用户名和密码
|
||||
|
||||
### 渐进式迁移
|
||||
|
||||
可以逐步迁移现有接口:
|
||||
|
||||
1. 新接口直接使用新认证系统
|
||||
2. 旧接口保持不变
|
||||
3. 逐个替换旧接口的认证逻辑
|
||||
|
||||
---
|
||||
|
||||
## 📚 API 端点总览
|
||||
|
||||
### 认证接口 (`/api/v1/auth/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| POST | `/register` | 用户注册 | 公开 |
|
||||
| POST | `/login` | OAuth2 登录 | 公开 |
|
||||
| POST | `/login/simple` | 简化登录 | 公开 |
|
||||
| GET | `/me` | 获取当前用户 | 认证用户 |
|
||||
| POST | `/refresh` | 刷新Token | 认证用户 |
|
||||
|
||||
### 用户管理 (`/api/v1/users/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/` | 用户列表 | 管理员 |
|
||||
| GET | `/{id}` | 用户详情 | 所有者/管理员 |
|
||||
| PUT | `/{id}` | 更新用户 | 所有者/管理员 |
|
||||
| DELETE | `/{id}` | 删除用户 | 管理员 |
|
||||
| POST | `/{id}/activate` | 激活用户 | 管理员 |
|
||||
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
|
||||
|
||||
### 审计日志 (`/api/v1/audit/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/logs` | 查询审计日志 | 管理员 |
|
||||
| GET | `/logs/count` | 日志总数 | 管理员 |
|
||||
| GET | `/logs/my` | 我的操作记录 | 认证用户 |
|
||||
|
||||
---
|
||||
|
||||
## 🎓 使用示例
|
||||
|
||||
### Python 示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# 登录
|
||||
resp = requests.post("http://localhost:8000/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "admin123"})
|
||||
token = resp.json()["access_token"]
|
||||
|
||||
# 访问受保护接口
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
resp = requests.get("http://localhost:8000/api/v1/auth/me", headers=headers)
|
||||
print(resp.json())
|
||||
```
|
||||
|
||||
### cURL 示例
|
||||
|
||||
```bash
|
||||
# 登录
|
||||
TOKEN=$(curl -s -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-d "username=admin&password=admin123" | jq -r .access_token)
|
||||
|
||||
# 查询审计日志
|
||||
curl -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8000/api/v1/audit/logs?action=LOGIN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🐛 常见问题
|
||||
|
||||
### Q: 如何修改默认管理员密码?
|
||||
|
||||
A: 登录后通过 PUT `/api/v1/users/{id}` 接口修改,或直接更新数据库。
|
||||
|
||||
### Q: 如何添加新用户?
|
||||
|
||||
A: 使用 POST `/api/v1/auth/register` 接口,或由管理员在用户管理界面创建。
|
||||
|
||||
### Q: 审计日志可以删除吗?
|
||||
|
||||
A: 不建议删除。可以归档到冷存储,保留最近 90 天的数据。
|
||||
|
||||
### Q: Token 过期了怎么办?
|
||||
|
||||
A: 使用 Refresh Token 调用 `/api/v1/auth/refresh` 接口获取新的 Access Token。
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
- **完整文档**: `SECURITY_README.md`
|
||||
- **部署指南**: `DEPLOYMENT.md`
|
||||
- **测试代码**: `tests/test_encryption.py`
|
||||
- **迁移脚本**: `migrations/`
|
||||
|
||||
---
|
||||
|
||||
## 📝 待办事项(可选)
|
||||
|
||||
未来可以扩展的功能:
|
||||
|
||||
- [ ] 邮件验证
|
||||
- [ ] 密码重置
|
||||
- [ ] 双因素认证(2FA)
|
||||
- [ ] 单点登录(SSO)
|
||||
- [ ] Token 黑名单
|
||||
- [ ] 会话管理
|
||||
- [ ] IP 白名单
|
||||
- [ ] 登录频率限制
|
||||
- [ ] 密码复杂度策略
|
||||
- [ ] 审计日志自动归档
|
||||
|
||||
---
|
||||
|
||||
## 🎉 总结
|
||||
|
||||
本次实施完成了企业级的安全体系,包含:
|
||||
|
||||
✅ 数据加密 - Fernet 对称加密
|
||||
✅ 身份认证 - JWT Token + bcrypt 密码哈希
|
||||
✅ 权限管理 - 基于角色的访问控制(RBAC)
|
||||
✅ 审计日志 - 自动追踪所有关键操作
|
||||
|
||||
所有功能均遵循安全最佳实践,提供完整的文档和测试,可直接投入生产使用。
|
||||
|
||||
---
|
||||
|
||||
**实施日期**: 2026-02-02
|
||||
**版本**: v1.0.0
|
||||
**状态**: ✅ 已完成
|
||||
@@ -1,499 +0,0 @@
|
||||
# 安全功能使用指南
|
||||
|
||||
TJWater Server 安全体系实施完成,包含:数据加密、身份认证、权限管理、审计日志
|
||||
|
||||
## 📋 目录
|
||||
|
||||
1. [快速开始](#快速开始)
|
||||
2. [数据加密](#数据加密)
|
||||
3. [身份认证](#身份认证)
|
||||
4. [权限管理](#权限管理)
|
||||
5. [审计日志](#审计日志)
|
||||
6. [数据库迁移](#数据库迁移)
|
||||
7. [API 使用示例](#api-使用示例)
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 配置环境变量
|
||||
|
||||
复制 `.env.example` 为 `.env` 并配置:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
生成必要的密钥:
|
||||
|
||||
```bash
|
||||
# 生成 JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 生成加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
编辑 `.env` 文件:
|
||||
|
||||
```env
|
||||
SECRET_KEY=your-generated-jwt-secret-key
|
||||
ENCRYPTION_KEY=your-generated-encryption-key
|
||||
DB_NAME=tjwater
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=your-db-password
|
||||
```
|
||||
|
||||
### 2. 执行数据库迁移
|
||||
|
||||
```bash
|
||||
# 连接到 PostgreSQL
|
||||
psql -U postgres -d tjwater
|
||||
|
||||
# 执行迁移脚本
|
||||
\i migrations/001_create_users_table.sql
|
||||
\i migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
或使用命令行:
|
||||
|
||||
```bash
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
### 3. 验证安装
|
||||
|
||||
默认创建了两个管理员账号:
|
||||
|
||||
- **用户名**: `admin` / **密码**: `admin123`
|
||||
- **用户名**: `tjwater` / **密码**: `tjwater@123`
|
||||
|
||||
---
|
||||
|
||||
## 🔐 数据加密
|
||||
|
||||
### 使用加密器
|
||||
|
||||
```python
|
||||
from app.core.encryption import get_encryptor
|
||||
|
||||
encryptor = get_encryptor()
|
||||
|
||||
# 加密敏感数据
|
||||
encrypted_data = encryptor.encrypt("sensitive information")
|
||||
|
||||
# 解密
|
||||
decrypted_data = encryptor.decrypt(encrypted_data)
|
||||
```
|
||||
|
||||
### 生成新密钥
|
||||
|
||||
```python
|
||||
from app.core.encryption import Encryptor
|
||||
|
||||
new_key = Encryptor.generate_key()
|
||||
print(f"New encryption key: {new_key}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 👤 身份认证
|
||||
|
||||
### 用户角色
|
||||
|
||||
系统定义了 4 个角色(权限由低到高):
|
||||
|
||||
| 角色 | 权限说明 |
|
||||
|------|---------|
|
||||
| `VIEWER` | 仅查询权限 |
|
||||
| `USER` | 读写权限 |
|
||||
| `OPERATOR` | 操作员,可修改数据 |
|
||||
| `ADMIN` | 管理员,完全权限 |
|
||||
|
||||
### API 接口
|
||||
|
||||
#### 用户注册
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/register
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"username": "newuser",
|
||||
"email": "user@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}
|
||||
```
|
||||
|
||||
#### 用户登录(OAuth2 标准)
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/login
|
||||
Content-Type: application/x-www-form-urlencoded
|
||||
|
||||
username=admin&password=admin123
|
||||
```
|
||||
|
||||
响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800
|
||||
}
|
||||
```
|
||||
|
||||
#### 用户登录(简化版)
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/login/simple?username=admin&password=admin123
|
||||
```
|
||||
|
||||
#### 获取当前用户信息
|
||||
|
||||
```http
|
||||
GET /api/v1/auth/me
|
||||
Authorization: Bearer {access_token}
|
||||
```
|
||||
|
||||
#### 刷新 Token
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/refresh
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"refresh_token": "your-refresh-token"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔑 权限管理
|
||||
|
||||
### 在 API 中使用权限控制
|
||||
|
||||
#### 方式 1: 使用预定义依赖
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, Depends
|
||||
from app.auth.permissions import get_current_admin, get_current_operator
|
||||
from app.domain.schemas.user import UserInDB
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/admin-only")
|
||||
async def admin_endpoint(
|
||||
current_user: UserInDB = Depends(get_current_admin)
|
||||
):
|
||||
"""仅管理员可访问"""
|
||||
return {"message": "Admin access granted"}
|
||||
|
||||
@router.post("/operator-only")
|
||||
async def operator_endpoint(
|
||||
current_user: UserInDB = Depends(get_current_operator)
|
||||
):
|
||||
"""操作员及以上可访问"""
|
||||
return {"message": "Operator access granted"}
|
||||
```
|
||||
|
||||
#### 方式 2: 使用 require_role
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
@router.get("/viewer-access")
|
||||
async def viewer_endpoint(
|
||||
current_user: UserInDB = Depends(require_role(UserRole.VIEWER))
|
||||
):
|
||||
"""所有认证用户可访问"""
|
||||
return {"data": "visible to all"}
|
||||
```
|
||||
|
||||
#### 方式 3: 手动检查权限
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
from app.auth.permissions import check_resource_owner
|
||||
|
||||
@router.put("/users/{user_id}")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
):
|
||||
"""检查是否是资源拥有者或管理员"""
|
||||
if not check_resource_owner(user_id, current_user):
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# 执行更新操作
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 审计日志
|
||||
|
||||
### 自动审计
|
||||
|
||||
使用中间件自动记录关键操作,在 `main.py` 中添加:
|
||||
|
||||
```python
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
自动记录:
|
||||
- 所有 POST/PUT/DELETE 请求
|
||||
- 登录/登出事件
|
||||
- 关键资源访问
|
||||
|
||||
### 手动记录审计日志
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="project",
|
||||
resource_id="123",
|
||||
ip_address=request.client.host,
|
||||
request_data={"field": "value"},
|
||||
response_status=200
|
||||
)
|
||||
```
|
||||
|
||||
### 查询审计日志
|
||||
|
||||
#### 获取所有审计日志(仅管理员)
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs?skip=0&limit=100
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
#### 按条件过滤
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs?user_id=1&action=LOGIN&start_time=2024-01-01T00:00:00
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
#### 获取我的操作记录
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs/my
|
||||
Authorization: Bearer {access_token}
|
||||
```
|
||||
|
||||
#### 获取日志总数
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs/count?action=LOGIN
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💾 数据库迁移
|
||||
|
||||
### 用户表结构
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(50) UNIQUE NOT NULL,
|
||||
email VARCHAR(100) UNIQUE NOT NULL,
|
||||
hashed_password VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(20) DEFAULT 'USER' NOT NULL,
|
||||
is_active BOOLEAN DEFAULT TRUE NOT NULL,
|
||||
is_superuser BOOLEAN DEFAULT FALSE NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
### 审计日志表结构
|
||||
|
||||
```sql
|
||||
CREATE TABLE audit_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER REFERENCES users(id),
|
||||
username VARCHAR(50),
|
||||
action VARCHAR(50) NOT NULL,
|
||||
resource_type VARCHAR(50),
|
||||
resource_id VARCHAR(100),
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
request_method VARCHAR(10),
|
||||
request_path TEXT,
|
||||
request_data JSONB,
|
||||
response_status INTEGER,
|
||||
error_message TEXT,
|
||||
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 API 使用示例
|
||||
|
||||
### Python 客户端示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
BASE_URL = "http://localhost:8000/api/v1"
|
||||
|
||||
# 1. 登录
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/auth/login",
|
||||
data={"username": "admin", "password": "admin123"}
|
||||
)
|
||||
token = response.json()["access_token"]
|
||||
|
||||
# 2. 设置 Authorization Header
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# 3. 获取当前用户信息
|
||||
response = requests.get(f"{BASE_URL}/auth/me", headers=headers)
|
||||
print(response.json())
|
||||
|
||||
# 4. 创建新用户(需要管理员权限)
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/auth/register",
|
||||
headers=headers,
|
||||
json={
|
||||
"username": "newuser",
|
||||
"email": "new@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}
|
||||
)
|
||||
print(response.json())
|
||||
|
||||
# 5. 查询审计日志(需要管理员权限)
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/audit/logs?action=LOGIN",
|
||||
headers=headers
|
||||
)
|
||||
print(response.json())
|
||||
```
|
||||
|
||||
### cURL 示例
|
||||
|
||||
```bash
|
||||
# 登录
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
|
||||
# 使用 Token 访问受保护接口
|
||||
TOKEN="your-access-token"
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 注册新用户
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/register" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-d '{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛡️ 安全最佳实践
|
||||
|
||||
1. **密钥管理**
|
||||
- 绝不在代码中硬编码密钥
|
||||
- 定期轮换 JWT 密钥
|
||||
- 使用强随机密钥
|
||||
|
||||
2. **密码策略**
|
||||
- 最小长度 6 个字符(建议 12+)
|
||||
- 强制密码复杂度(可在注册时添加验证)
|
||||
- 定期提醒用户更换密码
|
||||
|
||||
3. **Token 管理**
|
||||
- Access Token 短期有效(默认 30 分钟)
|
||||
- Refresh Token 长期有效(默认 7 天)
|
||||
- 实施 Token 黑名单(可选)
|
||||
|
||||
4. **审计日志**
|
||||
- 审计日志不可删除
|
||||
- 定期归档旧日志
|
||||
- 监控异常登录行为
|
||||
|
||||
5. **权限控制**
|
||||
- 遵循最小权限原则
|
||||
- 定期审查用户权限
|
||||
- 记录所有权限变更
|
||||
|
||||
---
|
||||
|
||||
## 📚 相关文件
|
||||
|
||||
- **配置**: `app/core/config.py`
|
||||
- **加密**: `app/core/encryption.py`
|
||||
- **安全**: `app/core/security.py`
|
||||
- **审计**: `app/core/audit.py`
|
||||
- **认证**: `app/api/v1/endpoints/auth.py`
|
||||
- **权限**: `app/auth/permissions.py`
|
||||
- **用户管理**: `app/api/v1/endpoints/user_management.py`
|
||||
- **审计日志**: `app/api/v1/endpoints/audit.py`
|
||||
- **迁移脚本**: `migrations/`
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
### Q: 忘记密码怎么办?
|
||||
|
||||
A: 目前需要管理员通过数据库重置。未来可添加邮件重置功能。
|
||||
|
||||
```sql
|
||||
-- 重置密码为 "newpassword123"
|
||||
UPDATE users
|
||||
SET hashed_password = '$2b$12$...' -- 使用 bcrypt 生成哈希
|
||||
WHERE username = 'targetuser';
|
||||
```
|
||||
|
||||
### Q: 如何添加新角色?
|
||||
|
||||
A: 编辑 `app/domain/models/role.py` 中的 `UserRole` 枚举,并更新数据库约束。
|
||||
|
||||
### Q: 审计日志占用太多空间?
|
||||
|
||||
A: 建议定期归档旧日志到冷存储:
|
||||
|
||||
```sql
|
||||
-- 归档 90 天前的日志
|
||||
CREATE TABLE audit_logs_archive AS
|
||||
SELECT * FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
|
||||
|
||||
DELETE FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
如有问题,请查看:
|
||||
- 日志文件: `logs/`
|
||||
- 数据库表结构: `migrations/`
|
||||
- 单元测试: `tests/`
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from app.algorithms.data_cleaning import flow_data_clean, pressure_data_clean
|
||||
from app.algorithms.sensors import (
|
||||
from app.algorithms.cleaning import flow_data_clean, pressure_data_clean
|
||||
from app.algorithms.sensor import (
|
||||
pressure_sensor_placement_sensitivity,
|
||||
pressure_sensor_placement_kmeans,
|
||||
)
|
||||
from app.algorithms.valve_isolation import valve_isolation_analysis
|
||||
from app.algorithms.simulations import (
|
||||
from app.algorithms.isolation.valve import valve_isolation_analysis
|
||||
from app.algorithms.leakage import LeakageIdentifier
|
||||
from app.algorithms.health import PipelineHealthAnalyzer
|
||||
from app.algorithms.burst_location import run_burst_location
|
||||
from app.algorithms.simulation.scenarios import (
|
||||
convert_to_local_unit,
|
||||
burst_analysis,
|
||||
valve_close_analysis,
|
||||
@@ -27,4 +30,7 @@ __all__ = [
|
||||
"age_analysis",
|
||||
"pressure_regulation",
|
||||
"valve_isolation_analysis",
|
||||
"LeakageIdentifier",
|
||||
"PipelineHealthAnalyzer",
|
||||
"run_burst_location",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def fill_time_gaps(
|
||||
data: pd.DataFrame,
|
||||
time_col: str = "time",
|
||||
freq: str = "1min",
|
||||
short_gap_threshold: int = 10,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
补齐缺失时间戳并填补数据缺口。
|
||||
|
||||
Args:
|
||||
data: 包含时间列的 DataFrame
|
||||
time_col: 时间列名(默认 'time')
|
||||
freq: 重采样频率(默认 '1min')
|
||||
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
|
||||
|
||||
Returns:
|
||||
补齐时间后的 DataFrame(保留原时间列格式)
|
||||
"""
|
||||
if time_col not in data.columns:
|
||||
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
|
||||
|
||||
# 解析时间列并设为索引
|
||||
data = data.copy()
|
||||
data[time_col] = pd.to_datetime(data[time_col], utc=True)
|
||||
data_indexed = data.set_index(time_col)
|
||||
|
||||
# 生成完整时间范围
|
||||
full_range = pd.date_range(
|
||||
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
|
||||
)
|
||||
|
||||
# 重索引以补齐缺失时间点,同时保留原始时间戳
|
||||
combined_index = data_indexed.index.union(full_range).sort_values().unique()
|
||||
data_reindexed = data_indexed.reindex(combined_index)
|
||||
|
||||
# 按列处理缺口
|
||||
for col in data_reindexed.columns:
|
||||
# 识别缺失值位置
|
||||
is_missing = data_reindexed[col].isna()
|
||||
|
||||
# 计算连续缺失的长度
|
||||
missing_groups = (is_missing != is_missing.shift()).cumsum()
|
||||
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
|
||||
|
||||
# 短缺口:时间插值
|
||||
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
|
||||
if short_gap_mask.any():
|
||||
data_reindexed.loc[short_gap_mask, col] = (
|
||||
data_reindexed[col]
|
||||
.interpolate(method="time", limit_area="inside")
|
||||
.loc[short_gap_mask]
|
||||
)
|
||||
|
||||
# 长缺口:前向填充
|
||||
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
|
||||
if long_gap_mask.any():
|
||||
data_reindexed.loc[long_gap_mask, col] = (
|
||||
data_reindexed[col].ffill().loc[long_gap_mask]
|
||||
)
|
||||
|
||||
# 重置索引并恢复时间列(保留原格式)
|
||||
data_result = data_reindexed.reset_index()
|
||||
data_result.rename(columns={"index": time_col}, inplace=True)
|
||||
|
||||
# 保留时区信息
|
||||
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
# 修正时区格式(Python的%z输出为+0000,需转为+00:00)
|
||||
data_result[time_col] = data_result[time_col].str.replace(
|
||||
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
|
||||
)
|
||||
|
||||
return data_result
|
||||
|
||||
|
||||
def _cleanup_temp_files(prefix: str) -> None:
|
||||
"""清理 EPANET 仿真产生的临时文件。"""
|
||||
for ext in [".inp", ".rpt", ".bin", ".out"]:
|
||||
temp_file = prefix + ext
|
||||
if os.path.exists(temp_file):
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except OSError:
|
||||
pass
|
||||
@@ -1,3 +0,0 @@
|
||||
from .flow_data_clean import *
|
||||
from .pressure_data_clean import *
|
||||
from .pipeline_health_analyzer import *
|
||||
@@ -1,355 +0,0 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.impute import SimpleImputer
|
||||
import os
|
||||
|
||||
|
||||
def fill_time_gaps(
|
||||
data: pd.DataFrame,
|
||||
time_col: str = "time",
|
||||
freq: str = "1min",
|
||||
short_gap_threshold: int = 10,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
补齐缺失时间戳并填补数据缺口。
|
||||
|
||||
Args:
|
||||
data: 包含时间列的 DataFrame
|
||||
time_col: 时间列名(默认 'time')
|
||||
freq: 重采样频率(默认 '1min')
|
||||
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
|
||||
|
||||
Returns:
|
||||
补齐时间后的 DataFrame(保留原时间列格式)
|
||||
"""
|
||||
if time_col not in data.columns:
|
||||
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
|
||||
|
||||
# 解析时间列并设为索引
|
||||
data = data.copy()
|
||||
data[time_col] = pd.to_datetime(data[time_col], utc=True)
|
||||
data_indexed = data.set_index(time_col)
|
||||
|
||||
# 生成完整时间范围
|
||||
full_range = pd.date_range(
|
||||
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
|
||||
)
|
||||
|
||||
# 重索引以补齐缺失时间点,同时保留原始时间戳
|
||||
combined_index = data_indexed.index.union(full_range).sort_values().unique()
|
||||
data_reindexed = data_indexed.reindex(combined_index)
|
||||
|
||||
# 按列处理缺口
|
||||
for col in data_reindexed.columns:
|
||||
# 识别缺失值位置
|
||||
is_missing = data_reindexed[col].isna()
|
||||
|
||||
# 计算连续缺失的长度
|
||||
missing_groups = (is_missing != is_missing.shift()).cumsum()
|
||||
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
|
||||
|
||||
# 短缺口:时间插值
|
||||
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
|
||||
if short_gap_mask.any():
|
||||
data_reindexed.loc[short_gap_mask, col] = (
|
||||
data_reindexed[col]
|
||||
.interpolate(method="time", limit_area="inside")
|
||||
.loc[short_gap_mask]
|
||||
)
|
||||
|
||||
# 长缺口:前向填充
|
||||
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
|
||||
if long_gap_mask.any():
|
||||
data_reindexed.loc[long_gap_mask, col] = (
|
||||
data_reindexed[col].ffill().loc[long_gap_mask]
|
||||
)
|
||||
|
||||
# 重置索引并恢复时间列(保留原格式)
|
||||
data_result = data_reindexed.reset_index()
|
||||
data_result.rename(columns={"index": time_col}, inplace=True)
|
||||
|
||||
# 保留时区信息
|
||||
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
# 修正时区格式(Python的%z输出为+0000,需转为+00:00)
|
||||
data_result[time_col] = data_result[time_col].str.replace(
|
||||
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
|
||||
)
|
||||
|
||||
return data_result
|
||||
|
||||
|
||||
def clean_pressure_data_km(
|
||||
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
读取输入 CSV,基于 KMeans 检测异常并用滚动平均修复。输出为 <input_basename>_cleaned.xlsx(同目录)。
|
||||
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'。
|
||||
返回输出文件的绝对路径。
|
||||
|
||||
Args:
|
||||
input_csv_path: CSV 文件路径
|
||||
show_plot: 是否显示可视化
|
||||
fill_gaps: 是否先补齐时间缺口(默认 True)
|
||||
"""
|
||||
# 读取 CSV
|
||||
input_csv_path = os.path.abspath(input_csv_path)
|
||||
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
|
||||
# 补齐时间缺口(如果数据包含 time 列)
|
||||
if fill_gaps and "time" in data.columns:
|
||||
data = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 分离时间列和数值列
|
||||
time_col_data = None
|
||||
if "time" in data.columns:
|
||||
time_col_data = data["time"]
|
||||
data = data.drop(columns=["time"])
|
||||
# 标准化
|
||||
data_norm = (data - data.mean()) / data.std()
|
||||
|
||||
# 聚类与异常检测
|
||||
k = 3
|
||||
kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
|
||||
clusters = kmeans.fit_predict(data_norm)
|
||||
centers = kmeans.cluster_centers_
|
||||
|
||||
distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
|
||||
threshold = distances.mean() + 3 * distances.std()
|
||||
|
||||
anomaly_pos = np.where(distances > threshold)[0]
|
||||
anomaly_indices = data.index[anomaly_pos]
|
||||
|
||||
anomaly_details = {}
|
||||
for pos in anomaly_pos:
|
||||
row_norm = data_norm.iloc[pos]
|
||||
cluster_idx = clusters[pos]
|
||||
center = centers[cluster_idx]
|
||||
diff = abs(row_norm - center)
|
||||
main_sensor = diff.idxmax()
|
||||
anomaly_details[data.index[pos]] = main_sensor
|
||||
|
||||
# 修复:滚动平均(窗口可调)
|
||||
data_rolled = data.rolling(window=13, center=True, min_periods=1).mean()
|
||||
data_repaired = data.copy()
|
||||
for pos in anomaly_pos:
|
||||
label = data.index[pos]
|
||||
sensor = anomaly_details[label]
|
||||
data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
|
||||
|
||||
# 可选可视化(使用位置作为 x 轴)
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
if show_plot and len(data.columns) > 0:
|
||||
n = len(data)
|
||||
time = np.arange(n)
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data.columns:
|
||||
plt.plot(time, data[col].values, marker="o", markersize=3, label=col)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data.index[pos]]
|
||||
plt.plot(pos, data.iloc[pos][sensor], "ro", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("压力监测值")
|
||||
plt.title("各传感器折线图(红色标记主要异常点)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data_repaired.columns:
|
||||
plt.plot(
|
||||
time, data_repaired[col].values, marker="o", markersize=3, label=col
|
||||
)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data.index[pos]]
|
||||
plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("修复后压力监测值")
|
||||
plt.title("修复后各传感器折线图(绿色标记修复值)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
# 保存到 Excel:两个 sheet
|
||||
input_dir = os.path.dirname(os.path.abspath(input_csv_path))
|
||||
input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
|
||||
output_filename = f"{input_base}_cleaned.xlsx"
|
||||
output_path = os.path.join(input_dir, output_filename)
|
||||
|
||||
# 如果原始数据包含时间列,将其添加回结果
|
||||
data_for_save = data.copy()
|
||||
data_repaired_for_save = data_repaired.copy()
|
||||
if time_col_data is not None:
|
||||
data_for_save.insert(0, "time", time_col_data)
|
||||
data_repaired_for_save.insert(0, "time", time_col_data)
|
||||
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path) # 覆盖同名文件
|
||||
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
|
||||
data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False)
|
||||
data_repaired_for_save.to_excel(
|
||||
writer, sheet_name="cleaned_pressusre_data", index=False
|
||||
)
|
||||
|
||||
# 返回输出文件的绝对路径
|
||||
return os.path.abspath(output_path)
|
||||
|
||||
|
||||
def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> dict:
|
||||
"""
|
||||
接收一个 DataFrame 数据结构,使用KMeans聚类检测异常并用滚动平均修复。
|
||||
返回清洗后的字典数据结构。
|
||||
|
||||
Args:
|
||||
data: 输入 DataFrame(可包含 time 列)
|
||||
show_plot: 是否显示可视化
|
||||
"""
|
||||
# 使用传入的 DataFrame
|
||||
data = data.copy()
|
||||
|
||||
# 补齐时间缺口(如果启用且数据包含 time 列)
|
||||
data_filled = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 保存 time 列用于最后合并
|
||||
time_col_series = None
|
||||
if "time" in data_filled.columns:
|
||||
time_col_series = data_filled["time"]
|
||||
|
||||
# 移除 time 列用于后续清洗
|
||||
data_filled = data_filled.drop(columns=["time"])
|
||||
|
||||
# 标准化(使用填充后的数据)
|
||||
data_norm = (data_filled - data_filled.mean()) / data_filled.std()
|
||||
|
||||
# 添加:处理标准化后的 NaN(例如,标准差为0的列),防止异常数据,时间段内所有数据都相同导致计算结果为 NaN
|
||||
imputer = SimpleImputer(
|
||||
strategy="constant", fill_value=0, keep_empty_features=True
|
||||
) # 用 0 填充 NaN,包括全 NaN,并保留空特征
|
||||
data_norm = pd.DataFrame(
|
||||
imputer.fit_transform(data_norm),
|
||||
columns=data_norm.columns,
|
||||
index=data_norm.index,
|
||||
)
|
||||
|
||||
# 聚类与异常检测
|
||||
k = 3
|
||||
kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
|
||||
clusters = kmeans.fit_predict(data_norm)
|
||||
centers = kmeans.cluster_centers_
|
||||
|
||||
distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
|
||||
threshold = distances.mean() + 3 * distances.std()
|
||||
|
||||
anomaly_pos = np.where(distances > threshold)[0]
|
||||
anomaly_indices = data_filled.index[anomaly_pos]
|
||||
|
||||
anomaly_details = {}
|
||||
for pos in anomaly_pos:
|
||||
row_norm = data_norm.iloc[pos]
|
||||
cluster_idx = clusters[pos]
|
||||
center = centers[cluster_idx]
|
||||
diff = abs(row_norm - center)
|
||||
main_sensor = diff.idxmax()
|
||||
anomaly_details[data_filled.index[pos]] = main_sensor
|
||||
|
||||
# 修复:滚动平均(窗口可调)
|
||||
data_rolled = data_filled.rolling(window=13, center=True, min_periods=1).mean()
|
||||
data_repaired = data_filled.copy()
|
||||
for pos in anomaly_pos:
|
||||
label = data_filled.index[pos]
|
||||
sensor = anomaly_details[label]
|
||||
data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
|
||||
|
||||
# 可选可视化(使用位置作为 x 轴)
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
if show_plot and len(data.columns) > 0:
|
||||
n = len(data)
|
||||
time = np.arange(n)
|
||||
n_filled = len(data_filled)
|
||||
time_filled = np.arange(n_filled)
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data.columns:
|
||||
plt.plot(
|
||||
time, data[col].values, marker="o", markersize=3, label=col, alpha=0.5
|
||||
)
|
||||
for col in data_filled.columns:
|
||||
plt.plot(
|
||||
time_filled,
|
||||
data_filled[col].values,
|
||||
marker="x",
|
||||
markersize=3,
|
||||
label=f"{col}_filled",
|
||||
linestyle="--",
|
||||
)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data_filled.index[pos]]
|
||||
plt.plot(pos, data_filled.iloc[pos][sensor], "ro", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("压力监测值")
|
||||
plt.title("各传感器折线图(红色标记主要异常点,虚线为0值填充后)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data_repaired.columns:
|
||||
plt.plot(
|
||||
time_filled, data_repaired[col].values, marker="o", markersize=3, label=col
|
||||
)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data_filled.index[pos]]
|
||||
plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("修复后压力监测值")
|
||||
plt.title("修复后各传感器折线图(绿色标记修复值)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
# 将 time 列添加回结果
|
||||
if time_col_series is not None:
|
||||
data_repaired.insert(0, "time", time_col_series)
|
||||
|
||||
# 返回清洗后的字典
|
||||
return data_repaired
|
||||
|
||||
|
||||
# 测试
|
||||
# if __name__ == "__main__":
|
||||
# # 默认使用脚本目录下的 pressure_raw_data.csv
|
||||
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
|
||||
# out_path = clean_pressure_data_km(default_csv, show_plot=False)
|
||||
# print("保存路径:", out_path)
|
||||
|
||||
# 测试 clean_pressure_data_dict_km 函数
|
||||
if __name__ == "__main__":
|
||||
import random
|
||||
|
||||
# 读取 szh_pressure_scada.csv 文件
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
|
||||
data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
|
||||
# 排除 Time 列,随机选择 5 列
|
||||
columns_to_exclude = ["Time"]
|
||||
available_columns = [col for col in data.columns if col not in columns_to_exclude]
|
||||
selected_columns = random.sample(available_columns, 5)
|
||||
|
||||
# 将选中的列转换为字典
|
||||
data_dict = {col: data[col].tolist() for col in selected_columns}
|
||||
|
||||
print("选中的列:", selected_columns)
|
||||
print("原始数据长度:", len(data_dict[selected_columns[0]]))
|
||||
|
||||
# 调用函数进行清洗
|
||||
cleaned_dict = clean_pressure_data_df_km(data_dict, show_plot=True)
|
||||
|
||||
print("清洗后的字典键:", list(cleaned_dict.keys()))
|
||||
print("清洗后的数据长度:", len(cleaned_dict[selected_columns[0]]))
|
||||
print("测试完成:函数运行正常")
|
||||
@@ -1,557 +0,0 @@
|
||||
# 改进灵敏度法
|
||||
import networkx
|
||||
import numpy as np
|
||||
import pandas
|
||||
import wntr
|
||||
import pandas as pd
|
||||
import copy
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from app.services.tjnetwork import *
|
||||
import app.services.project_info as project_info
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
"""
|
||||
获取管网模型的节点坐标
|
||||
:param wn: 由wntr生成的模型
|
||||
:return: 节点坐标
|
||||
"""
|
||||
# site: pandas.Series
|
||||
# index:节点名称(wn.node_name_list)
|
||||
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z))
|
||||
site = wn.query_node_attribute('coordinates')
|
||||
# Coor: pandas.Series
|
||||
# index:与site相同(节点名称)。
|
||||
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3]))
|
||||
Coor = site.apply(lambda x: np.array(x)) # 将节点坐标转换为numpy数组
|
||||
# x, y: list[float]
|
||||
x = [] # 存储所有节点的 x 坐标
|
||||
y = [] # 存储所有节点的 y 坐标
|
||||
for i in range(0, len(Coor)):
|
||||
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
|
||||
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
|
||||
# xy: dict[str, list], x、y 坐标的字典
|
||||
xy = {'x': x, 'y': y}
|
||||
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
|
||||
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
|
||||
return Coor_node
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step2: KMeans 聚类
|
||||
# 将节点用kmeans根据坐标分为k组,存入字典g
|
||||
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
|
||||
"""
|
||||
使用KMeans聚类,将节点坐标分组
|
||||
:param coor: 存储所有节点的坐标数据
|
||||
:param knum: 需要分成的聚类数
|
||||
:return: 聚类结果字典
|
||||
"""
|
||||
g = {}
|
||||
# estimator: sklearn.cluster.KMeans,KMeans 聚类模型
|
||||
estimator = KMeans(n_clusters=knum)
|
||||
estimator.fit(coor)
|
||||
# label_pred: numpy.ndarray(int),每个点的类别标签
|
||||
label_pred = estimator.labels_
|
||||
for i in range(0, knum):
|
||||
g[i] = coor[label_pred == i].index.tolist()
|
||||
return g
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step3: wn_func类,水力计算
|
||||
# wn_func 主要用于计算:
|
||||
# 水力距离(hydraulic length):即节点之间的水力阻力。
|
||||
# 灵敏度分析(sensitivity analysis):用于优化测压点的布置。
|
||||
# 一些与水力相关的函数,包括 CtoS:求水力距离,stafun:求状态函数F
|
||||
# # diff:求F对P的导数,返回灵敏度矩阵A
|
||||
# # sensitivity:返回灵敏度和总灵敏度
|
||||
class wn_func(object):
|
||||
|
||||
# Step3.1: 初始化
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel):
|
||||
"""
|
||||
获取管网模型信息
|
||||
:param wn: 由wntr生成的模型
|
||||
"""
|
||||
# self.results: wntr.sim.results.SimulationResults,仿真结果,包含压力、流量、水头等数据
|
||||
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
|
||||
self.wn = wn
|
||||
# self.q:pandas.DataFrame,管道流量,索引为时间步长,列为管道名称
|
||||
self.q = self.results.link['flowrate']
|
||||
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
|
||||
ReservoirIndex = wn.reservoir_name_list
|
||||
Tankindex = wn.tank_name_list
|
||||
# 删除水库节点,删除与直接水库相连的虚拟管道
|
||||
# self.pipes: list[str],所有管道的名称
|
||||
self.pipes = wn.pipe_name_list
|
||||
# self.nodes: list[str],所有节点的名称
|
||||
self.nodes = wn.node_name_list
|
||||
# self.coordinates:pandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
|
||||
self.coordinates = wn.query_node_attribute('coordinates')
|
||||
# allpumps / allvalves: list[str],所有泵/阀门名称列表
|
||||
allpumps = wn.pump_name_list
|
||||
allvalves = wn.valve_name_list
|
||||
# pumpstnode / pumpednode / valvestnode / valveednode: list[str],存储泵和阀门 起终点节点的名称
|
||||
pumpstnode = []
|
||||
pumpednode = []
|
||||
valvestnode = []
|
||||
valveednode = []
|
||||
# Reservoirpipe / Reservoirednode: list[str],记录与水库相关的管道和节点
|
||||
Reservoirpipe = []
|
||||
Reservoirednode = []
|
||||
for pump in allpumps:
|
||||
pumpstnode.append(wn.links[pump].start_node.name)
|
||||
pumpednode.append(wn.links[pump].end_node.name)
|
||||
for valve in allvalves:
|
||||
valvestnode.append(wn.links[valve].start_node.name)
|
||||
valveednode.append(wn.links[valve].end_node.name)
|
||||
for pipe in self.pipes:
|
||||
if wn.links[pipe].start_node.name in ReservoirIndex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].start_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].end_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].start_node.name)
|
||||
# 泵的起终点、tank、reservoir
|
||||
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
|
||||
self.delnodes = list(
|
||||
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
|
||||
# 泵、起终点为tank、reservoir的管道
|
||||
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
|
||||
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
|
||||
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
|
||||
# self.L: list[float],所有管道的长度(以米为单位)
|
||||
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
|
||||
self.n = len(self.nodes)
|
||||
self.m = len(self.pipes)
|
||||
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
|
||||
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
|
||||
##
|
||||
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
|
||||
|
||||
# Step3.2: 计算水力距离
|
||||
def CtoS(self):
|
||||
"""
|
||||
计算水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# 水力距离:当行索引对应的节点为控制点时,列索引对应的节点距离控制点的(路径*水头损失)的最小值
|
||||
# nodes:list[str](节点名称)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
# pipes:list[str](管道名称)
|
||||
pipes = self.pipes
|
||||
wn = self.wn
|
||||
# n / m:int(节点数 / 管道数)
|
||||
n = self.n
|
||||
m = self.m
|
||||
s1 = [0] * m
|
||||
q = self.q
|
||||
L = self.L
|
||||
# H1:pandas.DataFrame,水头数据,索引为时间步长,列为节点名
|
||||
H1 = self.results.node['head'].T
|
||||
# hh:list[float],计算管道两端水头之差
|
||||
hh = []
|
||||
# 水头损失
|
||||
for p in pipes:
|
||||
h1 = self.wn.links[p].start_node.name
|
||||
h1 = H1.loc[str(h1)]
|
||||
h2 = self.wn.links[p].end_node.name
|
||||
h2 = H1.loc[str(h2)]
|
||||
hh.append(abs(h1 - h2))
|
||||
hh = np.array(hh)
|
||||
# headloss:pandas.DataFrame,管道水头损失矩阵
|
||||
headloss = pd.DataFrame(hh, index=pipes).T
|
||||
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
|
||||
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
|
||||
G = nx.DiGraph()
|
||||
for i in range(0, m):
|
||||
pipe = pipes[i]
|
||||
a = wn.links[pipe].start_node.name
|
||||
b = wn.links[pipe].end_node.name
|
||||
if q.loc[0, pipe] > 0:
|
||||
hf.loc[a, b] = headloss.loc[0, pipe]
|
||||
weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
|
||||
|
||||
else:
|
||||
hf.loc[b, a] = headloss.loc[0, pipe]
|
||||
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
|
||||
|
||||
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
|
||||
for a in nodes:
|
||||
if a in G.nodes:
|
||||
d = nx.shortest_path_length(G, source=a, weight='weight')
|
||||
for b in list(d.keys()):
|
||||
hydraulicL.loc[a, b] = d[b]
|
||||
|
||||
hydraulicL = hydraulicL.drop(self.delnodes)
|
||||
hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
|
||||
|
||||
# 求加权水力距离
|
||||
return hydraulicL, G
|
||||
|
||||
# Step3.3: 计算灵敏度矩阵
|
||||
# 获取关系矩阵
|
||||
def get_Conn(self):
|
||||
"""
|
||||
计算管网连接关系矩阵
|
||||
:return:
|
||||
"""
|
||||
m = self.wn.num_links
|
||||
n = self.wn.num_nodes
|
||||
p = self.wn.num_pumps
|
||||
v = self.wn.num_valves
|
||||
|
||||
self.nonjunc_index = []
|
||||
self.non_link_index = []
|
||||
for r in self.wn.reservoirs():
|
||||
self.nonjunc_index.append(r[0])
|
||||
for t in self.wn.tanks():
|
||||
self.nonjunc_index.append(t[0])
|
||||
# Conn:numpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
|
||||
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
|
||||
# NConn:numpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
|
||||
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
|
||||
# pipes:list[str],去除泵和阀门的管道列表
|
||||
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
|
||||
for pipe_name, pipe in pipes:
|
||||
start = self.wn.node_name_list.index(pipe.start_node_name)
|
||||
end = self.wn.node_name_list.index(pipe.end_node_name)
|
||||
p_index = self.wn.link_name_list.index(pipe_name)
|
||||
Conn[start, p_index] = -1
|
||||
Conn[end, p_index] = 1
|
||||
NConn[start, end] = 1
|
||||
NConn[end, start] = 1
|
||||
self.A = Conn
|
||||
link_name_list = [link for link in self.wn.link_name_list if
|
||||
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
|
||||
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
|
||||
self.A2 = self.A2.drop(self.delnodes)
|
||||
for pipe in self.delpipes:
|
||||
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
|
||||
self.A2 = self.A2.drop(columns=pipe)
|
||||
self.junc_list = self.A2.index
|
||||
self.A2 = np.mat(self.A2) # 节点管道关系
|
||||
self.A3 = NConn
|
||||
|
||||
def Jaco(self, hL: pandas.DataFrame):
|
||||
"""
|
||||
计算灵敏度矩阵(节点压力对粗糙度变化的响应)
|
||||
:param hL: 水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# global result
|
||||
# A:numpy.matrix, 节点-管道关系矩阵
|
||||
A = self.A2
|
||||
wn = self.wn
|
||||
|
||||
try:
|
||||
result = wntr.sim.EpanetSimulator(wn).run_sim()
|
||||
except EpanetException:
|
||||
pass
|
||||
finally:
|
||||
h = result.link['headloss'][self.pipes].values[0]
|
||||
q = result.link['flowrate'][self.pipes].values[0]
|
||||
l = self.wn.query_link_attribute('length')[self.pipes]
|
||||
C = self.wn.query_link_attribute('roughness')[self.pipes]
|
||||
# headloss:numpy.ndarray,水头损失数组
|
||||
headloss = np.array(h)
|
||||
# 调整流量方向
|
||||
for i in range(0, len(q)):
|
||||
if q[i] < 0:
|
||||
A[:, i] = -A[:, i]
|
||||
# q:numpy.ndarray,流量数组
|
||||
q = np.abs(q)
|
||||
# 两个灵敏度矩阵
|
||||
# B / S:numpy.matrix,灵敏度计算的中间矩阵
|
||||
B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
|
||||
S = np.mat(np.diag(q / C))
|
||||
# X:numpy.matrix, 灵敏度矩阵
|
||||
X = A * B * A.T
|
||||
try:
|
||||
det = np.linalg.det(X)
|
||||
except RuntimeError as e:
|
||||
sign, logdet = slogdet(X) # 防止溢出
|
||||
det = sign * np.exp(logdet)
|
||||
if det != 0:
|
||||
J_H_Cw = X.I * A * S
|
||||
# J_H_Q = -X.I
|
||||
J_q_Cw = S - B * A.T * X.I * A * S # 去掉了delnodes和delpipes
|
||||
# J_q_Q = B * A.T * X.I
|
||||
else: # 当X不可逆
|
||||
J_H_Cw = np.linalg.pinv(X) @ A @ S
|
||||
# J_H_Q = -np.linalg.pinv(X)
|
||||
J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
|
||||
# J_q_Q = B * A.T * np.linalg.pinv(X)
|
||||
|
||||
Sen_pressure = []
|
||||
S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist() # 修改为绝对值
|
||||
for ss in S_pressure:
|
||||
Sen_pressure.append(ss[0])
|
||||
# 求总灵敏度
|
||||
SS_pressure = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
|
||||
SS = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
|
||||
# SS[i,j]:节点nodes[i]的灵敏度*该节点到nodes[j]的水力距离
|
||||
return SS
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step4: 传感器布置优化
|
||||
# Sensorplacement
|
||||
# weight:分配权重
|
||||
# sensor:传感器布置的位置
|
||||
class Sensorplacement(wn_func):
|
||||
"""
|
||||
Sensorplacement 类继承了 wn_func 类,并且用于计算和优化传感器布置的位置。
|
||||
"""
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int):
|
||||
"""
|
||||
|
||||
:param wn: 由wntr生成的模型
|
||||
:param sensornum: 传感器的数量
|
||||
"""
|
||||
wn_func.__init__(self, wn)
|
||||
self.sensornum = sensornum
|
||||
|
||||
# 1.某个节点到所有节点的加权距离之和
|
||||
# 2.某个节点到该组内所有节点的加权距离之和
|
||||
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
|
||||
"""
|
||||
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
|
||||
:param SS: 灵敏度矩阵,每个节点的行和列代表不同节点,矩阵元素表示节点间的灵敏度。SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
|
||||
:param G: 加权图,表示管网的拓扑结构,每个节点通过管道连接。图的边的权重通常是根据水力距离或者流量等计算的
|
||||
:param group: 节点分组,字典的键是分组编号,值是该组的节点名称列表
|
||||
:return:
|
||||
"""
|
||||
# 传感器布置个数以及位置
|
||||
# W = self.weight()
|
||||
n = self.n - len(self.delnodes)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
for node in self.delnodes:
|
||||
nodes.remove(node)
|
||||
# sumSS:list[float],每个节点到其他节点的灵敏度之和。SS.iloc[i, :] 返回第 i 个节点与所有其他节点的灵敏度值,sum(SS.iloc[i, :]) 计算这些灵敏度值的总和。
|
||||
sumSS = []
|
||||
for i in range(0, n):
|
||||
sumSS.append(sum(SS.iloc[i, :]))
|
||||
# 一个整数范围,表示每个节点的索引,用作sumSS_ DataFrame的索引
|
||||
indices = range(0, n)
|
||||
# sumSS_:pandas.DataFrame,将 sumSS 转换成 DataFrame 格式,并且将节点的总灵敏度保存到 CSV 文件 sumSS_data.csv 中
|
||||
sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
|
||||
sumSS_.to_csv('sumSS_data.csv') # 存储节点总灵敏度
|
||||
# sumSS:pandas.DataFrame,sumSS 被转换为 DataFrame 类型,并且按总灵敏度(即灵敏度之和)降序排列。此时,sumSS 是按节点的灵敏度之和排序的 DataFrame
|
||||
sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
# sensorindex:list[str],用于存储根据灵敏度排序选出的传感器位置的节点名称,存储根据总灵敏度排序的节点列表,用于传感器布置
|
||||
sensorindex = []
|
||||
# sensorindex_2:list[str],用于存储每组内根据灵敏度排序选出的传感器位置的节点名称,存储每个组内根据灵敏度排序选择的传感器节点
|
||||
sensorindex_2 = []
|
||||
# group_S:dict[int, pandas.DataFrame],存储每个组内的灵敏度矩阵
|
||||
group_S = {}
|
||||
# group_sumSS:dict[int, list[float]],存储每个组内节点的总灵敏度,值为每个组内节点灵敏度之和的列表
|
||||
group_sumSS = {}
|
||||
for i in range(0, len(group)):
|
||||
for node in self.delnodes:
|
||||
# 这里的group[i]是每个组的节点列表,代码首先去除已经被标记为删除的节点self.delnodes
|
||||
if node in group[i]:
|
||||
group[i].remove(node)
|
||||
group_S[i] = SS.loc[group[i], group[i]]
|
||||
# 对每个组内的节点,计算组内节点的总灵敏度(group_sumSS[i])。它将每个组内节点的灵敏度值相加,并且按灵敏度降序排序
|
||||
group_sumSS[i] = []
|
||||
for j in range(0, len(group[i])):
|
||||
group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
|
||||
group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
|
||||
group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
|
||||
pass
|
||||
|
||||
# 1.选sumSS最大的节点,然后把这个节点所在的那个组删掉,就可以不再从这个组选点。再重新排序选sumSS最大的;
|
||||
# 2.在每组内选group_sumSS最大的节点
|
||||
# 在这个循环中,首先选择灵敏度最高的节点Smaxnode并添加到sensorindex。然后根据灵敏度排序,删除已选的节点并继续选择下一个灵敏度最大的节点。这个过程用于选择传感器的位置
|
||||
sensornum = self.sensornum
|
||||
for i in range(0, sensornum):
|
||||
# Smaxnode:str,最大灵敏度节点,sumSS.index[0] 表示灵敏度最高的节点
|
||||
Smaxnode = sumSS.index[0]
|
||||
sensorindex.append(Smaxnode)
|
||||
sensorindex_2.append(group_sumSS[i].index[0])
|
||||
|
||||
for key, value in group.items():
|
||||
if Smaxnode in value:
|
||||
sumSS = sumSS.drop(index=group[key])
|
||||
continue
|
||||
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
|
||||
return sensorindex, sensorindex_2
|
||||
|
||||
|
||||
# 2025/03/13
|
||||
def get_sensor_coord(name: str, sensor_num: int) -> dict[str, float]:
|
||||
"""
|
||||
获取布置测压点的坐标,初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
|
||||
:param name: 数据库名称
|
||||
:param sensor_num: 测压点数目
|
||||
:return: 测压点坐标字典
|
||||
"""
|
||||
# inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
inp_file_real = f'./db_inp/{name}.db.inp'
|
||||
# sensornum:int,需要布置的传感器数量
|
||||
# sensornum = sensor_num
|
||||
# wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
|
||||
# sim_real:wntr.sim.EpanetSimulator,创建一个水力仿真器对象
|
||||
sim_real = wntr.sim.EpanetSimulator(wn_real)
|
||||
# results_real:wntr.sim.results.SimulationResults,运行仿真并返回结果
|
||||
results_real = sim_real.run_sim()
|
||||
|
||||
# real_C:list[float],包含所有管道粗糙度的列表
|
||||
real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
# wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
wn_fun1 = wn_func(wn_real)
|
||||
# nodes:list[str],管网的节点名称列表
|
||||
nodes = wn_fun1.nodes
|
||||
# delnodes:list[str],被删除的节点(如水库、泵、阀门连接的节点等)
|
||||
delnodes = wn_fun1.delnodes
|
||||
# Coor_node:pandas.DataFrame
|
||||
Coor_node = getCoor(wn_real)
|
||||
Coor_node = Coor_node.drop(wn_fun1.delnodes)
|
||||
nodes = [node for node in wn_fun1.nodes if node not in delnodes]
|
||||
# coordinates:pandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
|
||||
coordinates = wn_fun1.coordinates
|
||||
|
||||
# 随机产生监测点
|
||||
# junctionnum:int,nodes 的长度,表示节点的数量
|
||||
junctionnum = len(nodes)
|
||||
# random_numbers:list[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
|
||||
# random_numbers = random.sample(range(junctionnum), sensor_num)
|
||||
# for i in range(sensor_num):
|
||||
# # print(random_numbers[i])
|
||||
|
||||
wn_fun1.get_Conn()
|
||||
# hL:pandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
|
||||
# G:networkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
|
||||
hL, G = wn_fun1.CtoS()
|
||||
# SS:pandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
|
||||
SS = wn_fun1.Jaco(hL)
|
||||
# group:dict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
|
||||
group = kgroup(Coor_node, sensor_num)
|
||||
# wn_fun:Sensorplacement(继承自wn_func)
|
||||
# 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
|
||||
wn_fun = Sensorplacement(wn_real, sensor_num)
|
||||
wn_fun.__dict__.update(wn_fun1.__dict__)
|
||||
# sensorindex:list[str],初始传感器布置位置的节点名称
|
||||
# sensorindex_2:list[str],根据分组选择的传感器位置
|
||||
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
|
||||
sensor_coord = {}
|
||||
# 重新打开数据库
|
||||
if is_project_open(name=name):
|
||||
close_project(name=name)
|
||||
open_project(name=name)
|
||||
for node_id in sensorindex:
|
||||
sensor_coord[node_id] = get_node_coord(name=name, node_id=node_id)
|
||||
close_project(name=name)
|
||||
# print(sensor_coord)
|
||||
return sensor_coord
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sensor_coord = get_sensor_coord(name=project_info.name, sensor_num=20)
|
||||
print(sensor_coord)
|
||||
# '''
|
||||
# 初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
|
||||
# '''
|
||||
#
|
||||
# # inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
# inp_file_real = './db_inp/bb.db.inp'
|
||||
# # sensornum:int,需要布置的传感器数量
|
||||
# sensornum = 20
|
||||
# # wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
# wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
|
||||
# # sim_real:wntr.sim.EpanetSimulator,创建一个水力仿真器对象
|
||||
# sim_real = wntr.sim.EpanetSimulator(wn_real)
|
||||
# # results_real:wntr.sim.results.SimulationResults,运行仿真并返回结果
|
||||
# results_real = sim_real.run_sim()
|
||||
#
|
||||
# # real_C:list[float],包含所有管道粗糙度的列表
|
||||
# real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
# # wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
# wn_fun1 = wn_func(wn_real)
|
||||
# # nodes:list[str],管网的节点名称列表
|
||||
# nodes = wn_fun1.nodes
|
||||
# # delnodes:list[str],被删除的节点(如水库、泵、阀门连接的节点等)
|
||||
# delnodes = wn_fun1.delnodes
|
||||
# # Coor_node:pandas.DataFrame
|
||||
# Coor_node = getCoor(wn_real)
|
||||
# Coor_node = Coor_node.drop(wn_fun1.delnodes)
|
||||
# nodes = [node for node in wn_fun1.nodes if node not in delnodes]
|
||||
# # coordinates:pandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
|
||||
# coordinates = wn_fun1.coordinates
|
||||
#
|
||||
# # 随机产生监测点
|
||||
# # junctionnum:int,nodes 的长度,表示节点的数量
|
||||
# junctionnum = len(nodes)
|
||||
# # random_numbers:list[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
|
||||
# random_numbers = random.sample(range(junctionnum), sensornum)
|
||||
# for i in range(sensornum):
|
||||
# print(random_numbers[i])
|
||||
#
|
||||
# wn_fun1.get_Conn()
|
||||
# # hL:pandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
|
||||
# # G:networkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
|
||||
# hL, G = wn_fun1.CtoS()
|
||||
# # SS:pandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
|
||||
# SS = wn_fun1.Jaco(hL)
|
||||
# # group:dict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
|
||||
# group = kgroup(Coor_node, sensornum)
|
||||
# # wn_fun:Sensorplacement(继承自wn_func)
|
||||
# # 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
|
||||
# wn_fun = Sensorplacement(wn_real, sensornum)
|
||||
# wn_fun.__dict__.update(wn_fun1.__dict__)
|
||||
# # sensorindex:list[str],初始传感器布置位置的节点名称
|
||||
# # sensorindex_2:list[str],根据分组选择的传感器位置
|
||||
# sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensornum), "个测压点,测压点位置:", sensorindex)
|
||||
|
||||
# # 分区画图
|
||||
# colorlist = ['lightpink', 'coral', 'rosybrown', 'olive', 'powderblue', 'lightskyblue', 'steelblue', 'peachpuff','brown','silver','indigo','lime','gold','violet','maroon','navy','teal','magenta','cyan',
|
||||
# 'burlywood', 'tan', 'slategrey', 'thistle', 'lightseagreen', 'lightgreen', 'red','blue','yellow','orange','purple','grey','green','pink','lightblue','beige','chartreuse','turquoise','lavender','fuchsia','coral']
|
||||
# G = wn_real.to_graph()
|
||||
# G = G.to_undirected() # 变为无向图
|
||||
# pos = nx.get_node_attributes(G, 'pos')
|
||||
# pass
|
||||
# for i in range(0, sensornum):
|
||||
# ax = plt.gca()
|
||||
# ax.set_title(inp_file_real + str(sensornum))
|
||||
# nodes = nx.draw_networkx_nodes(G, pos, nodelist=group[i], node_color=colorlist[i], node_size=20)
|
||||
# nodes = nx.draw_networkx_nodes(G, pos,
|
||||
# nodelist=sensorindex_2, node_color='black', node_size=70, node_shape='*'
|
||||
# )
|
||||
# edges = nx.draw_networkx_edges(G, pos)
|
||||
# ax.spines['top'].set_visible(False)
|
||||
# ax.spines['right'].set_visible(False)
|
||||
# ax.spines['bottom'].set_visible(False)
|
||||
# ax.spines['left'].set_visible(False)
|
||||
# plt.savefig(inp_file_real + str(sensornum) + ".png")
|
||||
# plt.show()
|
||||
#
|
||||
# wntr.graphics.plot_network(wn_real, node_attribute=sensorindex_2, node_size=50, node_labels=False,
|
||||
# title=inp_file_real + '_Projetion' + str(sensornum))
|
||||
# plt.savefig(inp_file_real + '_S' + str(sensornum) + ".png")
|
||||
# plt.show()
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.algorithms.burst_detection.burst_detector import BurstDetector
|
||||
|
||||
__all__ = ["BurstDetector"]
|
||||
@@ -0,0 +1,259 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.fft import fft, ifft
|
||||
from sklearn.ensemble import IsolationForest
|
||||
|
||||
PressureDataInput = (
|
||||
pd.DataFrame
|
||||
| dict[str, list[Any]]
|
||||
| list[dict[str, Any]]
|
||||
| list[list[Any]]
|
||||
| np.ndarray
|
||||
)
|
||||
IGNORED_OBSERVATION_COLUMNS = {"time", "timestamp", "datetime", "date"}
|
||||
|
||||
|
||||
class BurstDetector:
|
||||
"""FFT + IsolationForest based burst detection for daily aligned pressure data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mu: int = 100,
|
||||
points_per_day: int = 1440,
|
||||
iforest_params: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if points_per_day <= 0:
|
||||
raise ValueError("points_per_day 必须大于 0。")
|
||||
if mu <= 0:
|
||||
raise ValueError("mu 必须大于 0。")
|
||||
|
||||
self.mu = int(mu)
|
||||
self.points_per_day = int(points_per_day)
|
||||
self.iforest_params = {
|
||||
"n_estimators": 50,
|
||||
"random_state": 42,
|
||||
"contamination": "auto",
|
||||
}
|
||||
if iforest_params:
|
||||
self.iforest_params.update(iforest_params)
|
||||
|
||||
self.data: np.ndarray | None = None
|
||||
self.sensor_names: list[str] = []
|
||||
self.high_freq_features: np.ndarray | None = None
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
data_source: PressureDataInput,
|
||||
*,
|
||||
sensor_nodes: list[str] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
标准化输入观测数据为 DataFrame。
|
||||
|
||||
支持的 `data_source` 格式:
|
||||
- `pd.DataFrame`
|
||||
每一列代表一个传感器,每一行代表一个时间点。
|
||||
- `dict[str, list[Any]]`
|
||||
键为传感器 ID,值为该传感器按时间顺序排列的压力序列。
|
||||
例如:`{"J1": [101.2, 101.0], "J2": [99.8, 99.7]}`。
|
||||
- `list[dict[str, Any]]`
|
||||
每个字典代表一个时间点,键为传感器 ID,值为该时刻压力。
|
||||
例如:`[{"J1": 101.2, "J2": 99.8}, {"J1": 101.0, "J2": 99.7}]`。
|
||||
- `list[list[Any]]`
|
||||
二维列表,格式为 `(时间点数, 传感器数)`。
|
||||
例如:`[[101.2, 99.8], [101.0, 99.7]]`。
|
||||
- `np.ndarray`
|
||||
二维数组,形状必须为 `(时间点数, 传感器数)`。
|
||||
|
||||
参数:
|
||||
- `sensor_nodes`:
|
||||
可选的传感器列筛选列表。传入后,数据中必须包含这些列名。
|
||||
|
||||
返回:
|
||||
- 标准化后的 `pd.DataFrame`,列为传感器,行为时间点。
|
||||
"""
|
||||
if isinstance(data_source, np.ndarray):
|
||||
observation_df = pd.DataFrame(data_source)
|
||||
elif isinstance(data_source, pd.DataFrame):
|
||||
observation_df = data_source.copy()
|
||||
else:
|
||||
observation_df = pd.DataFrame(data_source)
|
||||
|
||||
return self._normalize_observation_frame(
|
||||
observation_df=observation_df, sensor_nodes=sensor_nodes
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
observed_pressure_data: PressureDataInput,
|
||||
*,
|
||||
sensor_nodes: list[str] | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对输入压力序列按天切片,并提取每天末时刻的高频特征。
|
||||
|
||||
`observed_pressure_data` 的格式与 `load_data()` 一致,统一要求:
|
||||
- 数据必须表示为“行=时间点、列=传感器”。
|
||||
- 总行数必须是 `points_per_day` 的整数倍。
|
||||
- 至少需要 2 天数据,即总行数 `>= 2 * points_per_day`。
|
||||
|
||||
例如:
|
||||
- 当 `points_per_day=1440` 时,15 天数据的形状通常为 `(21600, 传感器数)`。
|
||||
- 若传入 `sensor_nodes=["J1", "J2"]`,则输入中必须存在 `J1/J2` 两列。
|
||||
|
||||
返回:
|
||||
- `np.ndarray`,形状为 `(天数, 传感器数)`,
|
||||
每个值表示对应传感器在当天末时刻提取出的高频分量。
|
||||
"""
|
||||
observation_df = self.load_data(
|
||||
observed_pressure_data,
|
||||
sensor_nodes=sensor_nodes,
|
||||
)
|
||||
matrix = observation_df.to_numpy(dtype=float)
|
||||
total_points, sensor_count = matrix.shape
|
||||
if sensor_count == 0:
|
||||
raise ValueError("压力观测数据中未找到可用传感器列。")
|
||||
if total_points < self.points_per_day * 2:
|
||||
raise ValueError("至少需要 2 天的观测数据才能执行爆管侦测。")
|
||||
if total_points % self.points_per_day != 0:
|
||||
raise ValueError("观测数据长度必须能被每日采样点数整除,以便按天切分。")
|
||||
|
||||
day_count = total_points // self.points_per_day
|
||||
high_freq_features = np.zeros((day_count, sensor_count), dtype=float)
|
||||
|
||||
for sensor_idx in range(sensor_count):
|
||||
sensor_series = matrix[:, sensor_idx]
|
||||
for day_idx in range(day_count):
|
||||
start = day_idx * self.points_per_day
|
||||
end = (day_idx + 1) * self.points_per_day
|
||||
day_data = sensor_series[start:end]
|
||||
mirrored_data = np.concatenate([day_data, day_data[::-1]])
|
||||
transformed = fft(mirrored_data)
|
||||
transformed[self.mu : len(mirrored_data) - self.mu + 1] = 0
|
||||
low_freq = ifft(transformed).real
|
||||
high_freq = day_data - low_freq[: self.points_per_day]
|
||||
high_freq_features[day_idx, sensor_idx] = float(high_freq[-1])
|
||||
|
||||
self.data = matrix
|
||||
self.sensor_names = [str(column) for column in observation_df.columns]
|
||||
self.high_freq_features = high_freq_features
|
||||
return high_freq_features
|
||||
|
||||
def detect(self) -> pd.DataFrame:
|
||||
if self.high_freq_features is None:
|
||||
raise ValueError("特征未提取。请先调用 process()。")
|
||||
|
||||
day_count = self.high_freq_features.shape[0]
|
||||
if day_count < 2:
|
||||
raise ValueError("孤立森林至少需要 2 天特征数据。")
|
||||
|
||||
clf = IsolationForest(
|
||||
n_estimators=self.iforest_params.get("n_estimators", 50),
|
||||
max_samples=day_count,
|
||||
random_state=self.iforest_params.get("random_state", 42),
|
||||
contamination=self.iforest_params.get("contamination", "auto"),
|
||||
**{
|
||||
key: value
|
||||
for key, value in self.iforest_params.items()
|
||||
if key not in {"n_estimators", "random_state", "contamination"}
|
||||
},
|
||||
)
|
||||
clf.fit(self.high_freq_features)
|
||||
|
||||
scores = clf.decision_function(self.high_freq_features)
|
||||
predictions = clf.predict(self.high_freq_features)
|
||||
result_df = pd.DataFrame(
|
||||
{
|
||||
"Day": range(1, day_count + 1),
|
||||
"Score": scores.astype(float),
|
||||
"Prediction": predictions.astype(int),
|
||||
}
|
||||
)
|
||||
result_df["IsBurst"] = result_df["Prediction"].eq(-1)
|
||||
result_df.attrs["sensor_nodes"] = self.sensor_names.copy()
|
||||
result_df.attrs["high_freq_features"] = self.high_freq_features.copy()
|
||||
result_df.attrs["day_count"] = day_count
|
||||
result_df.attrs["points_per_day"] = self.points_per_day
|
||||
result_df.attrs["sample_count"] = (
|
||||
int(self.data.shape[0]) if self.data is not None else 0
|
||||
)
|
||||
return result_df
|
||||
|
||||
def run_detection(
|
||||
self,
|
||||
observed_pressure_data: PressureDataInput,
|
||||
*,
|
||||
sensor_nodes: list[str] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
执行完整爆管侦测流程。
|
||||
|
||||
输入格式与 `process()` 相同:
|
||||
- `DataFrame` / `dict[str, list[Any]]` / `list[dict[str, Any]]` / `list[list[Any]]` / `np.ndarray`
|
||||
- 行表示时间点,列表示传感器
|
||||
- 总行数必须能被 `points_per_day` 整除
|
||||
|
||||
返回结果包含列:
|
||||
- `Day`: 第几天(从 1 开始)
|
||||
- `Score`: IsolationForest 异常分数,越小越异常
|
||||
- `Prediction`: `-1` 表示异常,`1` 表示正常
|
||||
- `IsBurst`: 是否判定为异常日
|
||||
"""
|
||||
self.process(observed_pressure_data, sensor_nodes=sensor_nodes)
|
||||
return self.detect()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_observation_frame(
|
||||
*,
|
||||
observation_df: pd.DataFrame,
|
||||
sensor_nodes: list[str] | None,
|
||||
) -> pd.DataFrame:
|
||||
if observation_df.empty:
|
||||
raise ValueError("压力观测数据为空。")
|
||||
|
||||
normalized_df = observation_df.copy()
|
||||
normalized_df.columns = [str(column) for column in normalized_df.columns]
|
||||
normalized_df = normalized_df.drop(
|
||||
columns=[
|
||||
column
|
||||
for column in normalized_df.columns
|
||||
if column.lower() in IGNORED_OBSERVATION_COLUMNS
|
||||
or column.lower().startswith("unnamed:")
|
||||
],
|
||||
errors="ignore",
|
||||
)
|
||||
|
||||
if sensor_nodes:
|
||||
selected_columns = [str(node) for node in sensor_nodes]
|
||||
missing_columns = [
|
||||
column
|
||||
for column in selected_columns
|
||||
if column not in normalized_df.columns
|
||||
]
|
||||
if missing_columns:
|
||||
preview = ", ".join(missing_columns[:10])
|
||||
raise ValueError(f"观测数据缺少传感器列: {preview}")
|
||||
normalized_df = normalized_df.loc[:, selected_columns]
|
||||
else:
|
||||
candidate_df = normalized_df.apply(pd.to_numeric, errors="coerce")
|
||||
normalized_df = candidate_df.loc[:, candidate_df.notna().any(axis=0)]
|
||||
|
||||
if normalized_df.empty:
|
||||
raise ValueError("未识别到可用的数值型压力观测列。")
|
||||
|
||||
normalized_df = normalized_df.apply(pd.to_numeric, errors="coerce")
|
||||
invalid_columns = [
|
||||
column
|
||||
for column in normalized_df.columns
|
||||
if normalized_df[column].isna().any()
|
||||
]
|
||||
if invalid_columns:
|
||||
preview = ", ".join(invalid_columns[:10])
|
||||
raise ValueError(f"压力观测数据包含非数值或缺失值: {preview}")
|
||||
|
||||
return normalized_df.reset_index(drop=True)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .burst_location import run_burst_location
|
||||
|
||||
__all__ = ["run_burst_location"]
|
||||
@@ -0,0 +1,342 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from multiprocessing import cpu_count
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from app.algorithms.burst_location import leak_simulator
|
||||
|
||||
from .burst_locator import (
|
||||
DN_search_multi_simple_add_flow_count_new,
|
||||
)
|
||||
from .network_model import (
|
||||
_build_node_pipe_maps,
|
||||
cal_node_coordinate,
|
||||
construct_graph,
|
||||
load_inp,
|
||||
read_inf_inp,
|
||||
read_inf_inp_other,
|
||||
)
|
||||
|
||||
DEFAULT_N_WORKERS = max(1, min(cpu_count() - 1, 4))
|
||||
# DEFAULT_N_WORKERS = max(1, cpu_count() - 1)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _read_id_list_json(path):
|
||||
if path is None:
|
||||
return None
|
||||
data = json.loads(Path(path).read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
return [str(item) for item in data]
|
||||
if isinstance(data, dict):
|
||||
if "ids" in data and isinstance(data["ids"], list):
|
||||
return [str(item) for item in data["ids"]]
|
||||
raise ValueError(f"ID JSON must be list or dict with key 'ids': {path}")
|
||||
raise ValueError(f"Unsupported ID JSON format: {path}")
|
||||
|
||||
|
||||
def _read_series_csv(path):
|
||||
if path is None:
|
||||
return None
|
||||
df = pd.read_csv(path)
|
||||
if df.shape[1] < 2:
|
||||
raise ValueError(f"CSV must contain at least two columns (id,value): {path}")
|
||||
if {"id", "value"}.issubset(df.columns):
|
||||
id_col, value_col = "id", "value"
|
||||
else:
|
||||
id_col, value_col = df.columns[0], df.columns[1]
|
||||
series = pd.Series(
|
||||
df[value_col].values, index=df[id_col].astype(str).values, dtype=float
|
||||
)
|
||||
return series
|
||||
|
||||
|
||||
def _align_scada_series(
|
||||
series: pd.Series, ids: Iterable[str], series_name: str
|
||||
) -> pd.Series:
|
||||
ids = [str(item) for item in ids]
|
||||
aligned = series.copy()
|
||||
aligned.index = aligned.index.map(str)
|
||||
missing_ids = [item for item in ids if item not in aligned.index]
|
||||
if missing_ids:
|
||||
preview = ", ".join(missing_ids[:10])
|
||||
raise ValueError(f"{series_name} missing IDs: {preview}")
|
||||
aligned = pd.to_numeric(aligned.loc[ids], errors="coerce")
|
||||
invalid_ids = aligned[aligned.isna()].index.tolist()
|
||||
if invalid_ids:
|
||||
preview = ", ".join(invalid_ids[:10])
|
||||
raise ValueError(
|
||||
f"{series_name} contains non-numeric values for IDs: {preview}"
|
||||
)
|
||||
return aligned
|
||||
|
||||
|
||||
def _validate_flow_inputs(
|
||||
flow_scada_ids: list[str] | None,
|
||||
burst_flow: pd.Series | None,
|
||||
normal_flow: pd.Series | None,
|
||||
) -> tuple[bool, list[str]]:
|
||||
has_any_flow = any(
|
||||
value is not None for value in [flow_scada_ids, burst_flow, normal_flow]
|
||||
)
|
||||
has_all_flow = all(
|
||||
value is not None for value in [flow_scada_ids, burst_flow, normal_flow]
|
||||
)
|
||||
if has_any_flow and not has_all_flow:
|
||||
raise ValueError(
|
||||
"flow_scada_ids, burst_flow, and normal_flow must be provided together."
|
||||
)
|
||||
|
||||
if not has_all_flow:
|
||||
return False, []
|
||||
|
||||
flow_ids = [str(item) for item in (flow_scada_ids or [])]
|
||||
if len(flow_ids) == 0:
|
||||
raise ValueError("flow_scada_ids cannot be empty when flow data is provided.")
|
||||
return True, flow_ids
|
||||
|
||||
|
||||
def _build_top_candidates(similarity_series: pd.Series) -> list[dict[str, Any]]:
|
||||
top_series = similarity_series.iloc[:10]
|
||||
return [
|
||||
{"pipe_id": str(pipe_id), "similarity": float(score)}
|
||||
for pipe_id, score in top_series.items()
|
||||
]
|
||||
|
||||
|
||||
def run_burst_location(
|
||||
wn_inp_path: str,
|
||||
pressure_scada_ids: list[str],
|
||||
burst_pressure: pd.Series,
|
||||
normal_pressure: pd.Series,
|
||||
burst_leakage: float,
|
||||
flow_scada_ids: list[str] | None = None,
|
||||
burst_flow: pd.Series | None = None,
|
||||
normal_flow: pd.Series | None = None,
|
||||
min_dpressure: float = 2.0,
|
||||
basic_pressure: float = 10.0,
|
||||
n_workers: int = DEFAULT_N_WORKERS,
|
||||
partition_on_full_graph: bool = True,
|
||||
visualize_partition: bool = True,
|
||||
visualize_pause_seconds: float = 0.3,
|
||||
final_candidates_csv_path: (
|
||||
str | None
|
||||
) = "temp/burst_location/final_round_candidates.csv",
|
||||
) -> dict[str, Any]:
|
||||
if pressure_scada_ids is None or len(pressure_scada_ids) == 0:
|
||||
raise ValueError("pressure_scada_ids cannot be empty.")
|
||||
if burst_pressure is None or normal_pressure is None:
|
||||
raise ValueError("burst_pressure and normal_pressure are required.")
|
||||
|
||||
has_all_flow, flow_ids = _validate_flow_inputs(
|
||||
flow_scada_ids=flow_scada_ids,
|
||||
burst_flow=burst_flow,
|
||||
normal_flow=normal_flow,
|
||||
)
|
||||
|
||||
inp_path = Path(wn_inp_path)
|
||||
wn = load_inp(
|
||||
inp_name=inp_path.name,
|
||||
inp_location=str(inp_path.parent) + "/",
|
||||
inp_time=0,
|
||||
driven_mode="PDD",
|
||||
require_p=float(basic_pressure),
|
||||
minimum_p=0.0,
|
||||
)
|
||||
|
||||
(
|
||||
all_node,
|
||||
_,
|
||||
node_coordinates,
|
||||
all_pipe,
|
||||
_,
|
||||
_,
|
||||
pipe_length,
|
||||
pipe_diameter,
|
||||
) = read_inf_inp(wn)
|
||||
|
||||
candidate_pipe, _ = leak_simulator.cal_possible_pipe(
|
||||
burst_leakage, all_pipe, pipe_diameter
|
||||
)
|
||||
|
||||
_, pipe_start_node_all, pipe_end_node_all = read_inf_inp_other(wn)
|
||||
node_x, node_y = cal_node_coordinate(all_node, node_coordinates)
|
||||
G0 = construct_graph(wn)
|
||||
node_pipe_dic, couple_node_length = _build_node_pipe_maps(
|
||||
all_node,
|
||||
all_pipe,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
pipe_length,
|
||||
)
|
||||
all_node_series = pd.Series(range(len(all_node)), index=all_node)
|
||||
|
||||
pressure_ids = [str(item) for item in pressure_scada_ids]
|
||||
normal_pressure_aligned = _align_scada_series(
|
||||
normal_pressure, pressure_ids, "normal_pressure"
|
||||
)
|
||||
burst_pressure_aligned = _align_scada_series(
|
||||
burst_pressure, pressure_ids, "burst_pressure"
|
||||
)
|
||||
pressure_normal = normal_pressure_aligned.to_frame().T
|
||||
pressure_monitor = burst_pressure_aligned.to_frame().T
|
||||
pressure_predict = pressure_normal.copy()
|
||||
timestep_list = list(pressure_normal.index)
|
||||
|
||||
if has_all_flow:
|
||||
normal_flow_aligned = _align_scada_series(normal_flow, flow_ids, "normal_flow")
|
||||
burst_flow_aligned = _align_scada_series(burst_flow, flow_ids, "burst_flow")
|
||||
flow_normal = normal_flow_aligned.to_frame().T
|
||||
flow_monitor = burst_flow_aligned.to_frame().T
|
||||
flow_predict = flow_normal.copy()
|
||||
similarity_mode = "CDF"
|
||||
max_flow = flow_normal.iloc[0, :].abs()
|
||||
else:
|
||||
flow_normal = pd.DataFrame(index=timestep_list)
|
||||
flow_monitor = pd.DataFrame(index=timestep_list)
|
||||
flow_predict = pd.DataFrame(index=timestep_list)
|
||||
similarity_mode = "CAD_new_gy"
|
||||
max_flow = pd.Series(dtype=float)
|
||||
|
||||
stage_timing: dict[str, Any] = {}
|
||||
try:
|
||||
(
|
||||
located_pipe,
|
||||
elapsed_seconds,
|
||||
simulation_times,
|
||||
_,
|
||||
similarity_series,
|
||||
exit_condition,
|
||||
final_candidates_csv,
|
||||
) = DN_search_multi_simple_add_flow_count_new(
|
||||
wn=wn,
|
||||
wn_inp_path=str(inp_path),
|
||||
G0=G0,
|
||||
all_node=all_node,
|
||||
node_x=node_x,
|
||||
node_y=node_y,
|
||||
pipe_start_node_all=pipe_start_node_all,
|
||||
pipe_end_node_all=pipe_end_node_all,
|
||||
pipe_diameter=pipe_diameter,
|
||||
couple_node_length=couple_node_length,
|
||||
node_pipe_dic=node_pipe_dic,
|
||||
all_node_series=all_node_series,
|
||||
top_group_ratio=0.3,
|
||||
top_pipe_num_max=80,
|
||||
top_pipe_num_min=10,
|
||||
candidate_pipe_input_initial=candidate_pipe,
|
||||
similarity_mode=similarity_mode,
|
||||
pressure_monitor=pressure_monitor,
|
||||
pressure_predict=pressure_predict,
|
||||
pressure_normal=pressure_normal,
|
||||
pressure_leak_all=None,
|
||||
flow_monitor=flow_monitor,
|
||||
flow_predict=flow_predict,
|
||||
flow_normal=flow_normal,
|
||||
flow_leak_all=None,
|
||||
timestep_list=timestep_list,
|
||||
max_flow=max_flow,
|
||||
group_basic_num=30,
|
||||
Top_sensor_num=min(5, len(pressure_ids)),
|
||||
if_gy=0,
|
||||
pressure_threshold=float(min_dpressure),
|
||||
leak_mag=float(burst_leakage),
|
||||
n_workers=max(1, int(n_workers)),
|
||||
stage_timing=stage_timing,
|
||||
partition_on_full_graph=partition_on_full_graph,
|
||||
visualize_partition=visualize_partition,
|
||||
visualize_pause_seconds=visualize_pause_seconds,
|
||||
final_candidates_csv_path=final_candidates_csv_path,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("Burst location algorithm execution failed.")
|
||||
raise RuntimeError(f"Failed to run burst location algorithm: {exc}") from exc
|
||||
|
||||
return {
|
||||
"located_pipe": located_pipe,
|
||||
"burst_leakage": float(burst_leakage),
|
||||
"elapsed_seconds": elapsed_seconds,
|
||||
"simulation_times": int(simulation_times),
|
||||
"top_candidates": _build_top_candidates(similarity_series),
|
||||
"similarity_mode": similarity_mode,
|
||||
"exit_condition": exit_condition,
|
||||
"final_candidates_csv": final_candidates_csv,
|
||||
"stage_timing_seconds": stage_timing,
|
||||
}
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser(description="爆管定位主函数入口")
|
||||
parser.add_argument("--wn-inp", required=True, help="EPANET inp 文件路径")
|
||||
parser.add_argument(
|
||||
"--pressure-ids-json", required=True, help="压力SCADA ID列表 JSON 文件"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flow-ids-json", default=None, help="(可选)流量SCADA ID列表 JSON 文件"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--burst-pressure-csv", required=True, help="爆管时压力 CSV(id,value)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normal-pressure-csv", required=True, help="正常时压力 CSV(id,value)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--burst-flow-csv", default=None, help="(可选)爆管时流量 CSV(id,value)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normal-flow-csv", default=None, help="(可选)正常时流量 CSV(id,value)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--burst-leakage", type=float, required=True, help="爆管漏损流量"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-dpressure",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="(可选)最小压降阈值,默认 2.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--basic-pressure",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="(可选)基础服务压力,默认 10.0",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-workers",
|
||||
type=int,
|
||||
default=DEFAULT_N_WORKERS,
|
||||
help="(可选)特征中心模拟进程数,默认 max(1, min(cpu_count()-1, 4))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--final-candidates-csv-path",
|
||||
default="temp/burst_location/final_round_candidates.csv",
|
||||
help="(可选)最后一轮候选管道明细 CSV 输出路径",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
result = run_burst_location(
|
||||
wn_inp_path=args.wn_inp,
|
||||
pressure_scada_ids=_read_id_list_json(args.pressure_ids_json),
|
||||
burst_pressure=_read_series_csv(args.burst_pressure_csv),
|
||||
normal_pressure=_read_series_csv(args.normal_pressure_csv),
|
||||
burst_leakage=args.burst_leakage,
|
||||
flow_scada_ids=_read_id_list_json(args.flow_ids_json),
|
||||
burst_flow=_read_series_csv(args.burst_flow_csv),
|
||||
normal_flow=_read_series_csv(args.normal_flow_csv),
|
||||
min_dpressure=args.min_dpressure,
|
||||
basic_pressure=args.basic_pressure,
|
||||
n_workers=args.n_workers,
|
||||
final_candidates_csv_path=args.final_candidates_csv_path,
|
||||
)
|
||||
print(json.dumps(result, ensure_ascii=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,772 @@
|
||||
"""爆管定位主模块。"""
|
||||
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from time import perf_counter
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .leak_simulator import cal_signature_pipe_multi_pf
|
||||
from .network_partitioner import (
|
||||
cal_group_num,
|
||||
metis_grouping_pipe_weight,
|
||||
visualize_metis_partition,
|
||||
)
|
||||
from .similarity_calculator import (
|
||||
adjust_ratio,
|
||||
cal_similarity_all_multi_new_sq_improve_double_lzr,
|
||||
decode_mode,
|
||||
extra_judge,
|
||||
update_similarity,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_signatures_for_centers(
|
||||
wn,
|
||||
wn_inp_path,
|
||||
center_list, # 本轮要用到的中心(list[str])
|
||||
pressure_leak_all,
|
||||
flow_leak_all, # 全量缓存(可为空 DF)
|
||||
timestep_list, # 你现有的时序列表
|
||||
pressure_monitor,
|
||||
flow_monitor, # 用来推断传感器列名
|
||||
leak_mag,
|
||||
n_workers=1,
|
||||
):
|
||||
"""
|
||||
只为缺失的中心补算 SLF(调用你现有的 cal_signature_pipe_multi_pf),
|
||||
并把补算结果并回缓存。返回:
|
||||
pressure_leak_subset, flow_leak_subset, pressure_leak_all_new, flow_leak_all_new
|
||||
其中 subset 只包含 center_list 的行(顺序与 center_list 保持一致)。
|
||||
"""
|
||||
|
||||
center_list = _dedupe_preserve_order(center_list)
|
||||
|
||||
# 1) 推断传感器列名(与现有数据保持一致)
|
||||
sensor_name_all = list(pressure_monitor.columns)
|
||||
sensor_f_name_all = (
|
||||
list(flow_monitor.columns)
|
||||
if (flow_monitor is not None and hasattr(flow_monitor, "columns"))
|
||||
else []
|
||||
)
|
||||
|
||||
# 2) 取出缓存里已经有的中心(考虑 MultiIndex 的第 0 层为 pipe)
|
||||
def _existing_pipes(df):
|
||||
if df is None or len(df) == 0:
|
||||
return set()
|
||||
idx = df.index
|
||||
if isinstance(idx, pd.MultiIndex):
|
||||
return set(idx.get_level_values(0))
|
||||
else:
|
||||
return set(idx)
|
||||
|
||||
exist_p = _existing_pipes(pressure_leak_all)
|
||||
need = [p for p in center_list if p not in exist_p]
|
||||
|
||||
# 3) 若有缺失中心,仅为这些中心补算一次
|
||||
if len(need) > 0:
|
||||
p_new, _ = cal_signature_pipe_multi_pf(
|
||||
wn,
|
||||
leak_mag,
|
||||
need,
|
||||
timestep_list,
|
||||
sensor_name_all,
|
||||
n_workers=n_workers,
|
||||
wn_inp_path=wn_inp_path,
|
||||
)
|
||||
# 初始化空缓存时,做一次“同构化”
|
||||
if pressure_leak_all is None or len(pressure_leak_all) == 0:
|
||||
pressure_leak_all = p_new
|
||||
else:
|
||||
pressure_leak_all = pd.concat([pressure_leak_all, p_new], axis=0)
|
||||
# if (flow_leak_all is None or len(flow_leak_all) == 0) and f_new is not None:
|
||||
# flow_leak_all = f_new
|
||||
# elif f_new is not None:
|
||||
# flow_leak_all = pd.concat([flow_leak_all, f_new], axis=0)
|
||||
|
||||
# 去重(如果既有缓存里不小心有重复中心)
|
||||
if isinstance(pressure_leak_all.index, pd.MultiIndex):
|
||||
pressure_leak_all = pressure_leak_all[
|
||||
~pressure_leak_all.index.duplicated(keep="last")
|
||||
]
|
||||
if flow_leak_all is not None and len(flow_leak_all) > 0:
|
||||
flow_leak_all = flow_leak_all[
|
||||
~flow_leak_all.index.duplicated(keep="last")
|
||||
]
|
||||
else:
|
||||
pressure_leak_all = pressure_leak_all[
|
||||
~pressure_leak_all.index.duplicated(keep="last")
|
||||
]
|
||||
if flow_leak_all is not None and len(flow_leak_all) > 0:
|
||||
flow_leak_all = flow_leak_all[
|
||||
~flow_leak_all.index.duplicated(keep="last")
|
||||
]
|
||||
|
||||
# 4) 从更新后的缓存里,取出这轮需要的中心子集(顺序与 center_list 一致)
|
||||
if isinstance(pressure_leak_all.index, pd.MultiIndex):
|
||||
pressure_subset = pressure_leak_all.loc[center_list]
|
||||
flow_subset = (
|
||||
flow_leak_all.loc[center_list]
|
||||
if (flow_leak_all is not None and len(flow_leak_all) > 0)
|
||||
else None
|
||||
)
|
||||
else:
|
||||
pressure_subset = pressure_leak_all.loc[center_list, :]
|
||||
flow_subset = (
|
||||
flow_leak_all.loc[center_list, :]
|
||||
if (flow_leak_all is not None and len(flow_leak_all) > 0)
|
||||
else None
|
||||
)
|
||||
|
||||
return pressure_subset, flow_subset, pressure_leak_all, flow_leak_all
|
||||
|
||||
|
||||
def area_output_num_ki_improve(
|
||||
candidate_center,
|
||||
candidate_group,
|
||||
similarity,
|
||||
new_all_node,
|
||||
top_group_ratio,
|
||||
top_pipe_num_max,
|
||||
top_pipe_num_min,
|
||||
cut_ratio,
|
||||
):
|
||||
final_area = []
|
||||
final_center = []
|
||||
all_node_iter = []
|
||||
if similarity.index.is_unique == False:
|
||||
total_center_num = len(set(similarity.index))
|
||||
else:
|
||||
total_center_num = len(similarity.index)
|
||||
|
||||
next_group_num = min(
|
||||
total_center_num, math.ceil(total_center_num / cut_ratio * top_group_ratio)
|
||||
)
|
||||
|
||||
for i in range(next_group_num):
|
||||
top_center = similarity.index[i]
|
||||
top_center_index = find_list_repeat(candidate_center, top_center)
|
||||
for j in range(len(top_center_index)):
|
||||
final_area = final_area + candidate_group[top_center_index[j]]
|
||||
all_node_iter = all_node_iter + list(new_all_node[top_center_index[j]])
|
||||
final_center.append(top_center)
|
||||
final_area = sorted(set(final_area))
|
||||
|
||||
if len(final_area) > top_pipe_num_max:
|
||||
if_end = 0
|
||||
elif len(final_area) > top_pipe_num_min:
|
||||
if_end = 1
|
||||
elif total_center_num == next_group_num:
|
||||
if_end = 1
|
||||
else:
|
||||
if_end = 1
|
||||
for i in np.arange(next_group_num, total_center_num, 1):
|
||||
|
||||
before_list = copy.deepcopy(final_area)
|
||||
top_center = similarity.index[i]
|
||||
top_center_index = candidate_center.index(top_center)
|
||||
temp_group = final_area + candidate_group[top_center_index]
|
||||
temp_area = sorted(set(temp_group))
|
||||
|
||||
if len(temp_area) < top_pipe_num_min:
|
||||
|
||||
final_center.append(top_center)
|
||||
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
|
||||
final_area = temp_area
|
||||
|
||||
elif len(temp_area) < top_pipe_num_max:
|
||||
final_center.append(top_center)
|
||||
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
|
||||
final_area = temp_area
|
||||
break
|
||||
else:
|
||||
a = len(temp_area) - top_pipe_num_max
|
||||
b = top_pipe_num_min - len(before_list)
|
||||
|
||||
if a >= b:
|
||||
final_area = before_list
|
||||
else:
|
||||
final_center.append(top_center)
|
||||
all_node_iter = all_node_iter + list(new_all_node[top_center_index])
|
||||
final_area = temp_area
|
||||
break
|
||||
|
||||
final_center = sorted(set(final_center))
|
||||
all_node_iter = sorted(set(all_node_iter))
|
||||
return final_area, final_center, all_node_iter, if_end
|
||||
|
||||
|
||||
def find_list_repeat(candidate_center, target):
|
||||
repeated_list = []
|
||||
for index, nums in enumerate(candidate_center):
|
||||
if nums == target:
|
||||
repeated_list.append(index)
|
||||
return repeated_list
|
||||
|
||||
|
||||
def _dedupe_preserve_order(items):
|
||||
seen = set()
|
||||
output = []
|
||||
for item in items:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
output.append(item)
|
||||
return output
|
||||
|
||||
|
||||
def _accumulate_stage(stage_timing, stage_name, started_at):
|
||||
stage_timing[stage_name] = stage_timing.get(stage_name, 0.0) + (
|
||||
perf_counter() - started_at
|
||||
)
|
||||
|
||||
|
||||
def _write_last_round_candidates_csv(
|
||||
csv_path,
|
||||
exit_condition,
|
||||
iteration_count,
|
||||
similarity_mode,
|
||||
candidate_details,
|
||||
fallback_similarity,
|
||||
):
|
||||
if not csv_path:
|
||||
return None
|
||||
timestamp_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
base_path, ext = os.path.splitext(csv_path)
|
||||
ext = ext or ".csv"
|
||||
output_path = f"{base_path}_{timestamp_suffix}{ext}"
|
||||
if candidate_details is not None and len(candidate_details) > 0:
|
||||
export_df = candidate_details.copy()
|
||||
if export_df.index.name == "pipe_id":
|
||||
export_df = export_df.reset_index()
|
||||
else:
|
||||
export_df = pd.DataFrame(
|
||||
{
|
||||
"pipe_id": [str(pipe_id) for pipe_id in fallback_similarity.index],
|
||||
"final_similarity": [float(value) for value in fallback_similarity.values],
|
||||
}
|
||||
)
|
||||
export_df["exit_condition"] = exit_condition
|
||||
export_df["iterations"] = int(iteration_count)
|
||||
export_df["similarity_mode"] = similarity_mode
|
||||
parent_dir = os.path.dirname(output_path)
|
||||
if parent_dir:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
export_df.to_csv(output_path, index=False, encoding="utf-8-sig")
|
||||
return output_path
|
||||
|
||||
|
||||
def cal_DtoTop1(
|
||||
G0, pipe_leak, located_pipe, pipe_start_node_all, pipe_end_node_all, pipe_length
|
||||
):
|
||||
if pipe_leak == located_pipe:
|
||||
result_DtoTop1 = 0
|
||||
result_DtoTop1_num = 0
|
||||
else:
|
||||
pipe_leak_start_node = pipe_start_node_all[pipe_leak]
|
||||
pipe_leak_end_node = pipe_end_node_all[pipe_leak]
|
||||
located_pipe_start_node = pipe_start_node_all[located_pipe]
|
||||
located_pipe_end_node = pipe_end_node_all[located_pipe]
|
||||
DtoTop1_series = pd.Series(dtype=object)
|
||||
DtoTop1_num_series = pd.Series(dtype=object)
|
||||
DtoTop1_series["ss"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_start_node, located_pipe_start_node, weight="weight"
|
||||
)
|
||||
DtoTop1_series["se"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_start_node, located_pipe_end_node, weight="weight"
|
||||
)
|
||||
DtoTop1_series["es"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_end_node, located_pipe_start_node, weight="weight"
|
||||
)
|
||||
DtoTop1_series["ee"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_end_node, located_pipe_end_node, weight="weight"
|
||||
)
|
||||
DtoTop1_num_series["ss"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_start_node, located_pipe_start_node
|
||||
)
|
||||
DtoTop1_num_series["se"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_start_node, located_pipe_end_node
|
||||
)
|
||||
DtoTop1_num_series["es"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_end_node, located_pipe_start_node
|
||||
)
|
||||
DtoTop1_num_series["ee"] = nx.shortest_path_length(
|
||||
G0, pipe_leak_end_node, located_pipe_end_node
|
||||
)
|
||||
|
||||
if DtoTop1_num_series.min() == 0:
|
||||
result_DtoTop1_num = 1
|
||||
result_DtoTop1 = DtoTop1_series.max() / 2
|
||||
else:
|
||||
result_DtoTop1_num = DtoTop1_num_series.min() + 1
|
||||
DtoTop1_type = DtoTop1_series.argmin()
|
||||
result_DtoTop1 = (
|
||||
DtoTop1_series[DtoTop1_type]
|
||||
+ (pipe_length[pipe_leak] + pipe_length[located_pipe]) / 2
|
||||
)
|
||||
return result_DtoTop1, result_DtoTop1_num
|
||||
|
||||
|
||||
def cal_RR(located_pipe, similarity_sp):
|
||||
if located_pipe in similarity_sp.index:
|
||||
rank = similarity_sp.index.get_loc(located_pipe)
|
||||
RR = rank / len(similarity_sp.index)
|
||||
else:
|
||||
RR = 1.1
|
||||
return RR
|
||||
|
||||
|
||||
def cal_cover(similarity, leak_pipe):
|
||||
if leak_pipe in list(similarity.index):
|
||||
cover = 1
|
||||
else:
|
||||
cover = 0
|
||||
return cover
|
||||
|
||||
|
||||
def cal_SD(located_pipe, real_pipe, pipe_x, pipe_y):
|
||||
dx = pipe_x[located_pipe] - pipe_x[real_pipe]
|
||||
dy = pipe_y[located_pipe] - pipe_y[real_pipe]
|
||||
SD = math.sqrt(dx * dx + dy * dy)
|
||||
return SD
|
||||
|
||||
|
||||
def DN_search_multi_simple_add_flow_count_new(
|
||||
wn,
|
||||
wn_inp_path,
|
||||
G0,
|
||||
all_node,
|
||||
node_x,
|
||||
node_y,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
pipe_diameter,
|
||||
couple_node_length,
|
||||
node_pipe_dic,
|
||||
all_node_series,
|
||||
top_group_ratio,
|
||||
top_pipe_num_max,
|
||||
top_pipe_num_min,
|
||||
candidate_pipe_input_initial,
|
||||
similarity_mode,
|
||||
pressure_monitor,
|
||||
pressure_predict,
|
||||
pressure_normal,
|
||||
pressure_leak_all,
|
||||
flow_monitor,
|
||||
flow_predict,
|
||||
flow_normal,
|
||||
flow_leak_all,
|
||||
timestep_list,
|
||||
max_flow,
|
||||
group_basic_num,
|
||||
Top_sensor_num,
|
||||
if_gy,
|
||||
pressure_threshold,
|
||||
leak_mag,
|
||||
n_workers=1,
|
||||
stage_timing=None,
|
||||
partition_on_full_graph=True,
|
||||
visualize_partition=False,
|
||||
visualize_pause_seconds=0.3,
|
||||
final_candidates_csv_path=None,
|
||||
):
|
||||
if stage_timing is None:
|
||||
stage_timing = {}
|
||||
exit_condition = "unknown"
|
||||
final_candidates_csv = None
|
||||
iter_count = 0
|
||||
all_node_iter = copy.deepcopy(all_node)
|
||||
candidate_pipe_input = copy.deepcopy(candidate_pipe_input_initial) # 可能漏损管段
|
||||
t1 = datetime.now()
|
||||
if_flow, if_only_cos, if_only_flow = decode_mode(similarity_mode) # 定位方法
|
||||
# threshold
|
||||
if if_only_flow == 1:
|
||||
dpressure = (flow_predict - flow_monitor).mean()
|
||||
dpressure = dpressure.abs()
|
||||
effective_sensor = list(dpressure.index)
|
||||
|
||||
else:
|
||||
dpressure = (pressure_predict - pressure_monitor).mean()
|
||||
dpressure = dpressure.abs()
|
||||
dpressure = dpressure[dpressure > pressure_threshold]
|
||||
effective_sensor = list(dpressure.index)
|
||||
simulation_times = 0 # 模拟次数
|
||||
if len(dpressure) > 0:
|
||||
break_flag = 0
|
||||
last_round_candidate_details = None
|
||||
|
||||
cos_h = 0
|
||||
dis_h = 0
|
||||
dis_f_h = 0
|
||||
if_compalsive = 0
|
||||
record_center_dataset = []
|
||||
record_center_set = set()
|
||||
# iter
|
||||
while 1:
|
||||
final_area = []
|
||||
final_center = []
|
||||
group_num = cal_group_num(candidate_pipe_input, group_basic_num)
|
||||
partition_nodes = all_node if partition_on_full_graph else all_node_iter
|
||||
|
||||
# group 分组,得出候选漏损中心
|
||||
stage_start = perf_counter()
|
||||
(
|
||||
candidate_center_list,
|
||||
candidate_group_list,
|
||||
new_all_node,
|
||||
candidate_center_candidates,
|
||||
) = (
|
||||
metis_grouping_pipe_weight(
|
||||
G0,
|
||||
wn,
|
||||
partition_nodes,
|
||||
candidate_pipe_input,
|
||||
group_num,
|
||||
node_x,
|
||||
node_y,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
node_pipe_dic,
|
||||
all_node_series,
|
||||
couple_node_length,
|
||||
pipe_diameter,
|
||||
)
|
||||
)
|
||||
_accumulate_stage(stage_timing, "group_partitioning", stage_start)
|
||||
if visualize_partition:
|
||||
visualize_metis_partition(
|
||||
G0,
|
||||
candidate_center_list,
|
||||
candidate_group_list,
|
||||
node_x,
|
||||
node_y,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
title=(
|
||||
f"METIS Partition Iteration {iter_count + 1} | "
|
||||
f"candidate pipes={len(candidate_pipe_input)} "
|
||||
f"groups={len(candidate_group_list)}"
|
||||
),
|
||||
block=False,
|
||||
pause_seconds=visualize_pause_seconds,
|
||||
)
|
||||
simulation_times = simulation_times + len(candidate_center_list)
|
||||
# pick_pressure_leak
|
||||
# pressure_leak = pressure_leak_all.loc[candidate_center_list].loc[:, :]
|
||||
# flow_leak = flow_leak_all.loc[candidate_center_list].loc[:, :]
|
||||
# —— 新增泄漏量(保持你现在的一致,或从外部传入)——
|
||||
# —— 只为缺失中心补算,然后取本轮需要的中心子集 ——
|
||||
stage_start = perf_counter()
|
||||
pressure_leak, flow_leak, pressure_leak_all, flow_leak_all = (
|
||||
_ensure_signatures_for_centers(
|
||||
wn=wn,
|
||||
wn_inp_path=wn_inp_path,
|
||||
center_list=candidate_center_list,
|
||||
pressure_leak_all=pressure_leak_all,
|
||||
flow_leak_all=flow_leak_all,
|
||||
timestep_list=timestep_list,
|
||||
pressure_monitor=pressure_monitor,
|
||||
flow_monitor=flow_monitor,
|
||||
leak_mag=leak_mag,
|
||||
n_workers=n_workers,
|
||||
)
|
||||
)
|
||||
_accumulate_stage(stage_timing, "signature_for_candidates", stage_start)
|
||||
|
||||
# pressure_leak_f= pressure_leak.swaplevel()
|
||||
|
||||
# --------------------------------------------------------
|
||||
|
||||
add_center = []
|
||||
leak_center_dict = dict()
|
||||
for i in range(len(candidate_center_list)):
|
||||
primary_center = candidate_center_list[i]
|
||||
houxuan_center = [
|
||||
center
|
||||
for center in candidate_center_candidates[i]
|
||||
if center != primary_center
|
||||
]
|
||||
candidate_group_set = set(candidate_group_list[i])
|
||||
for each_center in record_center_dataset:
|
||||
if (
|
||||
each_center in candidate_group_set
|
||||
and each_center != primary_center
|
||||
):
|
||||
houxuan_center.append(each_center)
|
||||
add_center = add_center + houxuan_center
|
||||
leak_center_dict[primary_center] = _dedupe_preserve_order(
|
||||
houxuan_center + [primary_center]
|
||||
)
|
||||
add_center = _dedupe_preserve_order(add_center)
|
||||
for each_group_centers in candidate_center_candidates:
|
||||
for each_center in each_group_centers:
|
||||
if each_center not in record_center_set:
|
||||
record_center_dataset.append(each_center)
|
||||
record_center_set.add(each_center)
|
||||
for each_center in add_center:
|
||||
if each_center not in record_center_set:
|
||||
record_center_dataset.append(each_center)
|
||||
record_center_set.add(each_center)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# --------------------------------------------------------
|
||||
# if len(add_center) > 0:
|
||||
# s3 = pressure_leak_all.loc[add_center]
|
||||
# pressure_leak = pd.concat([pressure_leak, s3])
|
||||
# s4 = flow_leak_all.loc[add_center]
|
||||
# flow_leak = pd.concat([flow_leak, s4])
|
||||
# --------------------------------------------------------
|
||||
# 只为 add_center 里还没算过的中心补算,并与本轮中心合并
|
||||
if len(add_center) > 0:
|
||||
stage_start = perf_counter()
|
||||
pressure_add, flow_add, pressure_leak_all, flow_leak_all = (
|
||||
_ensure_signatures_for_centers(
|
||||
wn=wn,
|
||||
wn_inp_path=wn_inp_path,
|
||||
center_list=add_center,
|
||||
pressure_leak_all=pressure_leak_all,
|
||||
flow_leak_all=flow_leak_all,
|
||||
timestep_list=timestep_list,
|
||||
pressure_monitor=pressure_monitor,
|
||||
flow_monitor=flow_monitor,
|
||||
leak_mag=leak_mag, # 与上面一致
|
||||
n_workers=n_workers,
|
||||
)
|
||||
)
|
||||
_accumulate_stage(
|
||||
stage_timing, "signature_for_extra_centers", stage_start
|
||||
)
|
||||
pressure_leak = pd.concat([pressure_leak, pressure_add], axis=0)
|
||||
if (flow_leak is not None) and (flow_add is not None):
|
||||
flow_leak = pd.concat([flow_leak, flow_add], axis=0)
|
||||
# --------------------------------------------------------
|
||||
#
|
||||
if len(candidate_pipe_input) < 1.2 * top_pipe_num_max / top_group_ratio:
|
||||
if_compalsive = 1
|
||||
cos_h, dis_h, dis_f_h = adjust_ratio(similarity_mode, cos_h, dis_h, dis_f_h)
|
||||
candidate_center_list_sup = _dedupe_preserve_order(
|
||||
candidate_center_list + add_center
|
||||
)
|
||||
stage_start = perf_counter()
|
||||
similarity, cos_h, dis_h, dis_f_h, break_flag, similarity_details = (
|
||||
cal_similarity_all_multi_new_sq_improve_double_lzr(
|
||||
candidate_center_list_sup,
|
||||
similarity_mode,
|
||||
pressure_leak,
|
||||
pressure_monitor,
|
||||
pressure_predict,
|
||||
pressure_normal,
|
||||
if_flow,
|
||||
if_only_cos,
|
||||
if_only_flow,
|
||||
flow_leak,
|
||||
flow_monitor,
|
||||
flow_predict,
|
||||
flow_normal,
|
||||
timestep_list,
|
||||
Top_sensor_num,
|
||||
if_gy,
|
||||
effective_sensor,
|
||||
cos_h,
|
||||
dis_h,
|
||||
dis_f_h,
|
||||
if_compalsive,
|
||||
max_flow,
|
||||
)
|
||||
)
|
||||
last_round_candidate_details = similarity_details
|
||||
_accumulate_stage(stage_timing, "similarity_ranking", stage_start)
|
||||
if break_flag == 1:
|
||||
exit_condition = "similarity_break_flag"
|
||||
break
|
||||
|
||||
new_similarity = update_similarity(
|
||||
candidate_center_list, similarity, leak_center_dict
|
||||
)
|
||||
if len(candidate_pipe_input) > top_pipe_num_max / top_group_ratio:
|
||||
cut_ratio, new_similarity = extra_judge(new_similarity)
|
||||
else:
|
||||
cut_ratio = 1
|
||||
stage_start = perf_counter()
|
||||
final_area_t, final_center_t, all_node_new_1, if_end = (
|
||||
area_output_num_ki_improve(
|
||||
candidate_center_list,
|
||||
candidate_group_list,
|
||||
new_similarity,
|
||||
new_all_node,
|
||||
top_group_ratio,
|
||||
top_pipe_num_max,
|
||||
top_pipe_num_min,
|
||||
cut_ratio,
|
||||
)
|
||||
)
|
||||
_accumulate_stage(stage_timing, "candidate_area_selection", stage_start)
|
||||
final_area = final_area + final_area_t
|
||||
final_center = final_center + final_center_t
|
||||
|
||||
final_area = sorted(set(final_area))
|
||||
final_center = sorted(set(final_center))
|
||||
if if_end == 1:
|
||||
exit_condition = "candidate_area_if_end"
|
||||
break
|
||||
elif len(candidate_pipe_input) == len(final_area):
|
||||
exit_condition = "candidate_size_no_change"
|
||||
break
|
||||
else:
|
||||
candidate_pipe_input = final_area
|
||||
if not partition_on_full_graph:
|
||||
all_node_iter = all_node_new_1
|
||||
iter_count += 1
|
||||
sys.stdout.write(
|
||||
"\r"
|
||||
+ "已经完成"
|
||||
+ str(iter_count)
|
||||
+ "次迭代计算"
|
||||
+ "候选节点"
|
||||
+ str(len(final_area))
|
||||
+ "个"
|
||||
)
|
||||
# if break_flag == 0:
|
||||
# final_area_pipe = copy.deepcopy(final_area)
|
||||
# simulation_times = simulation_times + len(final_area)
|
||||
# pressure_leak_sp = pressure_leak_all.loc[final_area_pipe].loc[:, :]
|
||||
# flow_leak_sp = flow_leak_all.loc[final_area_pipe].loc[:, :]
|
||||
# similarity_sp, cos_h, dis_h, dis_f_h, break_flag = cal_similarity_all_multi_new_sq_improve_double_lzr(
|
||||
# final_area_pipe, similarity_mode, pressure_leak_sp,
|
||||
# pressure_monitor, pressure_predict, pressure_normal, if_flow,
|
||||
# if_only_cos, if_only_flow,
|
||||
# flow_leak_sp, flow_monitor, flow_predict, flow_normal,
|
||||
# timestep_list, Top_sensor_num, if_gy, effective_sensor, cos_h, dis_h, dis_f_h, if_compalsive, max_flow)
|
||||
if break_flag == 0:
|
||||
final_area_pipe = list(final_area) # 确保是 list
|
||||
|
||||
# 只为还没算过的管段补齐 SLF(按需计算)
|
||||
stage_start = perf_counter()
|
||||
pressure_leak_sp, flow_leak_sp, pressure_leak_all, flow_leak_all = (
|
||||
_ensure_signatures_for_centers(
|
||||
wn=wn,
|
||||
wn_inp_path=wn_inp_path,
|
||||
center_list=final_area_pipe, # 这次要用的“最终区域里的所有管段”
|
||||
pressure_leak_all=pressure_leak_all, # 累积缓存(会被更新)
|
||||
flow_leak_all=flow_leak_all,
|
||||
timestep_list=timestep_list,
|
||||
pressure_monitor=pressure_monitor,
|
||||
flow_monitor=flow_monitor,
|
||||
leak_mag=leak_mag,
|
||||
n_workers=n_workers,
|
||||
)
|
||||
)
|
||||
_accumulate_stage(stage_timing, "signature_for_final_area", stage_start)
|
||||
|
||||
# 如果你要精确统计模拟次数,这里可以加上“本次新补的数量”,
|
||||
# 做法:让 _ensure_signatures_for_centers 额外返回 need_cnt,再 simulation_times += need_cnt
|
||||
|
||||
stage_start = perf_counter()
|
||||
(
|
||||
similarity_sp,
|
||||
cos_h,
|
||||
dis_h,
|
||||
dis_f_h,
|
||||
break_flag,
|
||||
similarity_details,
|
||||
) = (
|
||||
cal_similarity_all_multi_new_sq_improve_double_lzr(
|
||||
final_area_pipe,
|
||||
similarity_mode,
|
||||
pressure_leak_sp,
|
||||
pressure_monitor,
|
||||
pressure_predict,
|
||||
pressure_normal,
|
||||
if_flow,
|
||||
if_only_cos,
|
||||
if_only_flow,
|
||||
flow_leak_sp,
|
||||
flow_monitor,
|
||||
flow_predict,
|
||||
flow_normal,
|
||||
timestep_list,
|
||||
Top_sensor_num,
|
||||
if_gy,
|
||||
effective_sensor,
|
||||
cos_h,
|
||||
dis_h,
|
||||
dis_f_h,
|
||||
if_compalsive,
|
||||
max_flow,
|
||||
)
|
||||
)
|
||||
last_round_candidate_details = similarity_details
|
||||
_accumulate_stage(stage_timing, "similarity_final", stage_start)
|
||||
|
||||
else:
|
||||
dpressure = (pressure_predict - pressure_monitor).mean()
|
||||
dpressure = dpressure.abs()
|
||||
simulation_times = simulation_times + len(dpressure.index)
|
||||
similarity_sp = pd.Series(dtype=float)
|
||||
for each_node in dpressure.index:
|
||||
pipe = node_pipe_dic[each_node][0]
|
||||
similarity_sp.loc[pipe] = dpressure.loc[each_node]
|
||||
similarity_sp = similarity_sp.sort_values(ascending=False, kind="mergesort")
|
||||
t2 = datetime.now()
|
||||
final_area_pipe = []
|
||||
|
||||
sys.stdout.write(
|
||||
"\r"
|
||||
+ "已经完成"
|
||||
+ str(iter_count + 1)
|
||||
+ "次迭代计算"
|
||||
+ "候选节点"
|
||||
+ str(len(final_area_pipe))
|
||||
+ "个"
|
||||
)
|
||||
t2 = datetime.now()
|
||||
dt = (t2 - t1).seconds
|
||||
final_candidates_csv = _write_last_round_candidates_csv(
|
||||
csv_path=final_candidates_csv_path,
|
||||
exit_condition=exit_condition,
|
||||
iteration_count=iter_count + 1,
|
||||
similarity_mode=similarity_mode,
|
||||
candidate_details=last_round_candidate_details,
|
||||
fallback_similarity=similarity_sp,
|
||||
)
|
||||
else:
|
||||
exit_condition = "no_effective_sensor_after_threshold"
|
||||
dpressure = (pressure_predict - pressure_monitor).mean()
|
||||
dpressure = dpressure.abs()
|
||||
|
||||
similarity_sp = pd.Series(dtype=float)
|
||||
for each_node in dpressure.index:
|
||||
pipe = node_pipe_dic[each_node][0]
|
||||
similarity_sp.loc[pipe] = dpressure.loc[each_node]
|
||||
similarity_sp = similarity_sp.sort_values(ascending=False, kind="mergesort")
|
||||
t2 = datetime.now()
|
||||
dt = (t2 - t1).seconds
|
||||
final_candidates_csv = _write_last_round_candidates_csv(
|
||||
csv_path=final_candidates_csv_path,
|
||||
exit_condition=exit_condition,
|
||||
iteration_count=0,
|
||||
similarity_mode=similarity_mode,
|
||||
candidate_details=None,
|
||||
fallback_similarity=similarity_sp,
|
||||
)
|
||||
stage_timing["iterations"] = iter_count + 1 if len(dpressure) > 0 else 0
|
||||
stage_timing["total_elapsed_seconds"] = float(dt)
|
||||
stage_timing["exit_condition"] = exit_condition
|
||||
stage_timing["final_candidates_csv"] = final_candidates_csv
|
||||
return (
|
||||
similarity_sp.index[0],
|
||||
dt,
|
||||
simulation_times,
|
||||
wn,
|
||||
similarity_sp,
|
||||
exit_condition,
|
||||
final_candidates_csv,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,563 @@
|
||||
"""漏损模拟模块。"""
|
||||
|
||||
import math
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import wntr
|
||||
|
||||
from app.algorithms._utils import _cleanup_temp_files
|
||||
|
||||
_PIPE2LEAKNODE = None
|
||||
_SIGNATURE_WORKER_DATA = {}
|
||||
|
||||
|
||||
def _make_temp_prefix(tag):
|
||||
temp_dir = os.path.abspath(os.path.join("temp", "burst_location"))
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
safe_tag = str(tag).replace(os.sep, "_").replace(" ", "_")
|
||||
return os.path.join(temp_dir, f"{safe_tag}_{os.getpid()}")
|
||||
|
||||
|
||||
def _snapshot_hydraulic_options(wn):
|
||||
options = wn.options
|
||||
return {
|
||||
"demand_model": options.hydraulic.demand_model,
|
||||
"duration": float(options.time.duration),
|
||||
"hydraulic_timestep": float(options.time.hydraulic_timestep),
|
||||
"pattern_timestep": float(options.time.pattern_timestep),
|
||||
"report_timestep": float(options.time.report_timestep),
|
||||
"required_pressure": float(options.hydraulic.required_pressure),
|
||||
"minimum_pressure": float(options.hydraulic.minimum_pressure),
|
||||
}
|
||||
|
||||
|
||||
def _apply_hydraulic_options(wn, option_values):
|
||||
options = wn.options
|
||||
options.hydraulic.demand_model = option_values["demand_model"]
|
||||
options.time.duration = option_values["duration"]
|
||||
options.time.hydraulic_timestep = option_values["hydraulic_timestep"]
|
||||
options.time.pattern_timestep = option_values["pattern_timestep"]
|
||||
options.time.report_timestep = option_values["report_timestep"]
|
||||
options.hydraulic.required_pressure = option_values["required_pressure"]
|
||||
options.hydraulic.minimum_pressure = option_values["minimum_pressure"]
|
||||
|
||||
|
||||
def simple_add_leak(wn, leak_mag, leak_pipe):
|
||||
whole_inf = dict()
|
||||
leak_pipe_self = wn.get_link(leak_pipe)
|
||||
pipe_diameter = leak_pipe_self.diameter
|
||||
pipe_length = leak_pipe_self.length
|
||||
pipe_roughness = leak_pipe_self.roughness
|
||||
pipe_minor_loss = leak_pipe_self.minor_loss
|
||||
# pipe_status = leak_pipe_self.status
|
||||
# pipe_check_valve = leak_pipe_self.check_valve
|
||||
pipe_start_node = leak_pipe_self.start_node_name
|
||||
pipe_end_node = leak_pipe_self.end_node_name
|
||||
|
||||
# close the pipe
|
||||
# leak_pipe_self.status = 'Closed'
|
||||
wn.remove_link(leak_pipe)
|
||||
# add the pipe
|
||||
add_pipe1 = leak_pipe + "A"
|
||||
add_pipe2 = leak_pipe + "B"
|
||||
add_node = leak_pipe + "_"
|
||||
|
||||
start_n = wn.get_node(pipe_start_node)
|
||||
end_n = wn.get_node(pipe_end_node)
|
||||
if start_n.node_type == "Reservoir":
|
||||
end_n_elevation = end_n.elevation
|
||||
start_n_elevation = end_n_elevation
|
||||
elif end_n.node_type == "Reservoir":
|
||||
start_n_elevation = start_n.elevation
|
||||
end_n_elevation = start_n_elevation
|
||||
else:
|
||||
end_n_elevation = end_n.elevation
|
||||
start_n_elevation = start_n.elevation
|
||||
elevation_self = (start_n_elevation + end_n_elevation) / 2
|
||||
coordinates_self = (
|
||||
(start_n.coordinates[0] + end_n.coordinates[0]) / 2,
|
||||
(start_n.coordinates[1] + end_n.coordinates[1]),
|
||||
)
|
||||
|
||||
wn.add_junction(
|
||||
add_node, base_demand=0, elevation=elevation_self, coordinates=coordinates_self
|
||||
)
|
||||
leak_node = wn.get_node(add_node)
|
||||
wn.add_pipe(
|
||||
add_pipe1,
|
||||
start_node_name=pipe_start_node,
|
||||
end_node_name=add_node,
|
||||
length=pipe_length / 2,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
wn.add_pipe(
|
||||
add_pipe2,
|
||||
start_node_name=pipe_end_node,
|
||||
end_node_name=add_node,
|
||||
length=pipe_length / 2,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
|
||||
# simulation
|
||||
leak_node.add_demand(base=leak_mag, pattern_name="add_leak")
|
||||
|
||||
whole_inf["leak_node_name"] = add_node
|
||||
whole_inf["add_pipe1"] = add_pipe1
|
||||
whole_inf["add_pipe2"] = add_pipe2
|
||||
whole_inf["leak_pipe"] = leak_pipe
|
||||
whole_inf["pipe_start_node"] = pipe_start_node
|
||||
whole_inf["pipe_end_node"] = pipe_end_node
|
||||
whole_inf["pipe_length"] = pipe_length
|
||||
whole_inf["pipe_diameter"] = pipe_diameter
|
||||
whole_inf["pipe_roughness"] = pipe_roughness
|
||||
whole_inf["pipe_minor_loss"] = pipe_diameter
|
||||
return wn, whole_inf, add_pipe1
|
||||
|
||||
|
||||
def simple_recover_wn(wn, whole_inf):
|
||||
leak_node = wn.get_node(whole_inf["leak_node_name"])
|
||||
|
||||
del leak_node.demand_timeseries_list[-1]
|
||||
# update
|
||||
wn.remove_link(whole_inf["add_pipe1"])
|
||||
wn.remove_link(whole_inf["add_pipe2"])
|
||||
wn.remove_node(whole_inf["leak_node_name"])
|
||||
# open the pipe
|
||||
# leak_pipe_self.status = 'Open'
|
||||
wn.add_pipe(
|
||||
whole_inf["leak_pipe"],
|
||||
start_node_name=whole_inf["pipe_start_node"],
|
||||
end_node_name=whole_inf["pipe_end_node"],
|
||||
length=whole_inf["pipe_length"],
|
||||
diameter=whole_inf["pipe_diameter"],
|
||||
roughness=whole_inf["pipe_roughness"],
|
||||
minor_loss=whole_inf["pipe_minor_loss"],
|
||||
)
|
||||
|
||||
return wn
|
||||
|
||||
|
||||
def disable_all_controls_temporarily(wn):
|
||||
"""返回(控制名, 控制对象)的列表,之后可用 restore_controls 还原。"""
|
||||
removed = []
|
||||
# WNTR 的控制都在 wn.control_name_list / wn.get_control / wn.remove_control
|
||||
for cname in list(wn.control_name_list):
|
||||
ctrl = wn.get_control(cname)
|
||||
removed.append((cname, ctrl))
|
||||
wn.remove_control(cname)
|
||||
return removed
|
||||
|
||||
|
||||
def restore_controls(wn, removed):
|
||||
"""把先前禁用的控制全部加回去。"""
|
||||
for cname, ctrl in removed:
|
||||
wn.add_control(cname, ctrl)
|
||||
|
||||
|
||||
def set_pipe2leaknode_mapping(mapping):
|
||||
global _PIPE2LEAKNODE
|
||||
_PIPE2LEAKNODE = mapping
|
||||
|
||||
|
||||
def _get_or_create_leak_demand_ts(leak_node):
|
||||
"""
|
||||
返回:泄漏专用 demand 的下标 idx。
|
||||
若不存在,以 category='leak' 新建一条 base=0.0 的 demand。
|
||||
"""
|
||||
# 先尝试找到已有的 'leak' 分类
|
||||
for i, ts in enumerate(leak_node.demand_timeseries_list):
|
||||
# WNTR 的 Demand object 存在 category 属性
|
||||
if getattr(ts, "category", None) == "leak":
|
||||
return i
|
||||
# 没有则新建(base=0.0,后续临时改 base_value)
|
||||
leak_node.add_demand(base=0.0, pattern_name=None, category="leak")
|
||||
return len(leak_node.demand_timeseries_list) - 1
|
||||
|
||||
|
||||
def ensure_mid_node(wn, leak_pipe):
|
||||
add_pipe1 = f"{leak_pipe}A"
|
||||
add_pipe2 = f"{leak_pipe}B"
|
||||
add_node = f"{leak_pipe}__mid"
|
||||
|
||||
if add_node in wn.node_name_list:
|
||||
return add_node
|
||||
|
||||
if leak_pipe in wn.link_name_list:
|
||||
leak_pipe_self = wn.get_link(leak_pipe)
|
||||
pipe_diameter = leak_pipe_self.diameter
|
||||
pipe_length = leak_pipe_self.length
|
||||
pipe_roughness = leak_pipe_self.roughness
|
||||
pipe_minor_loss = leak_pipe_self.minor_loss
|
||||
pipe_start_node = leak_pipe_self.start_node_name
|
||||
pipe_end_node = leak_pipe_self.end_node_name
|
||||
|
||||
start_n = wn.get_node(pipe_start_node)
|
||||
end_n = wn.get_node(pipe_end_node)
|
||||
if start_n.node_type == "Reservoir":
|
||||
end_elev = end_n.elevation
|
||||
start_elev = end_elev
|
||||
elif end_n.node_type == "Reservoir":
|
||||
start_elev = start_n.elevation
|
||||
end_elev = start_elev
|
||||
else:
|
||||
end_elev = end_n.elevation
|
||||
start_elev = start_n.elevation
|
||||
elev_mid = (start_elev + end_elev) / 2.0
|
||||
x_mid = (start_n.coordinates[0] + end_n.coordinates[0]) / 2.0
|
||||
y_mid = (start_n.coordinates[1] + end_n.coordinates[1]) / 2.0
|
||||
|
||||
wn.remove_link(leak_pipe)
|
||||
wn.add_junction(
|
||||
add_node, base_demand=0.0, elevation=elev_mid, coordinates=(x_mid, y_mid)
|
||||
)
|
||||
wn.add_pipe(
|
||||
add_pipe1,
|
||||
start_node_name=pipe_start_node,
|
||||
end_node_name=add_node,
|
||||
length=pipe_length / 2.0,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
wn.add_pipe(
|
||||
add_pipe2,
|
||||
start_node_name=add_node,
|
||||
end_node_name=pipe_end_node,
|
||||
length=pipe_length / 2.0,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
return add_node
|
||||
|
||||
# 若 A/B 已存在但中点不在,建议确认网络一致性
|
||||
raise KeyError(f"Cannot ensure mid node for pipe '{leak_pipe}'.")
|
||||
|
||||
|
||||
def leak_simulation_pipe_dd_multi_pf(
|
||||
wn, leak_mag, leak_pipe, sensor_name, file_prefix=None
|
||||
):
|
||||
"""
|
||||
优化版:
|
||||
- 不再 remove/add link/node
|
||||
- 直接在预插入的中点泄漏节点上设置 base_demand = leak_mag;仿真后设回 0
|
||||
"""
|
||||
wn.options.hydraulic.demand_model = "DD"
|
||||
|
||||
# 确保中点节点存在
|
||||
leak_node_name = ensure_mid_node(wn, leak_pipe)
|
||||
leak_node = wn.get_node(leak_node_name)
|
||||
|
||||
# 拿到泄漏专用的 demand time-series 下标
|
||||
leak_idx = _get_or_create_leak_demand_ts(leak_node)
|
||||
ts_obj = leak_node.demand_timeseries_list[leak_idx]
|
||||
|
||||
# 记录原值(通常是 0.0)
|
||||
orig_base = ts_obj.base_value
|
||||
try:
|
||||
# 打开泄漏:只改 base_value,不碰 base_demand(只读)
|
||||
ts_obj.base_value = float(leak_mag)
|
||||
|
||||
# 仿真
|
||||
sim = wntr.sim.EpanetSimulator(wn)
|
||||
if file_prefix is None:
|
||||
results = sim.run_sim()
|
||||
else:
|
||||
results = sim.run_sim(file_prefix=file_prefix)
|
||||
|
||||
# 输出(保持列顺序)
|
||||
pressure_output = results.node["pressure"].loc[:, sensor_name]
|
||||
# flow_output = results.link['flowrate'].loc[:, sensor_f_name]
|
||||
return wn, pressure_output
|
||||
finally:
|
||||
# 关闭泄漏:还原 base_value
|
||||
ts_obj.base_value = orig_base
|
||||
if file_prefix is not None:
|
||||
_cleanup_temp_files(file_prefix)
|
||||
|
||||
|
||||
def prepare_leak_infrastructure(wn, candidate_pipes):
|
||||
"""
|
||||
把 candidate_pipes 每条管段切成两段,并在中点插入一个泄漏节点(base_demand=0)。
|
||||
返回一个映射:pipe_id -> leak_node_name
|
||||
注意:只做一次;后续仿真通过在该节点设置 base_demand 实现“打开泄漏”,结束后恢复为 0。
|
||||
"""
|
||||
pipe2leaknode = {}
|
||||
for leak_pipe in candidate_pipes:
|
||||
if leak_pipe in pipe2leaknode:
|
||||
continue
|
||||
|
||||
leak_pipe_self = wn.get_link(leak_pipe)
|
||||
pipe_diameter = leak_pipe_self.diameter
|
||||
pipe_length = leak_pipe_self.length
|
||||
pipe_roughness = leak_pipe_self.roughness
|
||||
pipe_minor_loss = leak_pipe_self.minor_loss
|
||||
pipe_start_node = leak_pipe_self.start_node_name
|
||||
pipe_end_node = leak_pipe_self.end_node_name
|
||||
|
||||
# 计算中点高程/坐标(与原逻辑一致)
|
||||
start_n = wn.get_node(pipe_start_node)
|
||||
end_n = wn.get_node(pipe_end_node)
|
||||
if start_n.node_type == "Reservoir":
|
||||
end_elev = end_n.elevation
|
||||
start_elev = end_elev
|
||||
elif end_n.node_type == "Reservoir":
|
||||
start_elev = start_n.elevation
|
||||
end_elev = start_elev
|
||||
else:
|
||||
end_elev = end_n.elevation
|
||||
start_elev = start_n.elevation
|
||||
elev_mid = (start_elev + end_elev) / 2.0
|
||||
x_mid = (start_n.coordinates[0] + end_n.coordinates[0]) / 2.0
|
||||
y_mid = (start_n.coordinates[1] + end_n.coordinates[1]) / 2.0
|
||||
|
||||
# 先删原管,再加中点与两段半长管(只做一次)
|
||||
wn.remove_link(leak_pipe)
|
||||
|
||||
add_pipe1 = f"{leak_pipe}A"
|
||||
add_pipe2 = f"{leak_pipe}B"
|
||||
add_node = f"{leak_pipe}__mid" # 唯一命名,后面直接用它当泄漏节点
|
||||
|
||||
wn.add_junction(
|
||||
add_node, base_demand=0.0, elevation=elev_mid, coordinates=(x_mid, y_mid)
|
||||
)
|
||||
wn.add_pipe(
|
||||
add_pipe1,
|
||||
start_node_name=pipe_start_node,
|
||||
end_node_name=add_node,
|
||||
length=pipe_length / 2.0,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
wn.add_pipe(
|
||||
add_pipe2,
|
||||
start_node_name=add_node,
|
||||
end_node_name=pipe_end_node,
|
||||
length=pipe_length / 2.0,
|
||||
diameter=pipe_diameter,
|
||||
roughness=pipe_roughness,
|
||||
minor_loss=pipe_minor_loss,
|
||||
)
|
||||
|
||||
pipe2leaknode[leak_pipe] = add_node
|
||||
|
||||
return pipe2leaknode
|
||||
|
||||
|
||||
def normal_simulation_pf(
|
||||
wn, drive_mode, sensor_name, sensor_f_name, inp_time, require_p, minimum_p
|
||||
):
|
||||
# inp_time = 0
|
||||
if drive_mode == "PDD": # 需水量根据节点压力动态调整
|
||||
wn.options.hydraulic.demand_model = "PDD"
|
||||
wn.options.hydraulic.required_pressure = require_p
|
||||
wn.options.hydraulic.minimum_pressure = minimum_p
|
||||
elif drive_mode == "DD": # 需水量固定,与压力无关
|
||||
wn.options.hydraulic.demand_model = "DD"
|
||||
sim = wntr.sim.EpanetSimulator(wn)
|
||||
results = sim.run_sim()
|
||||
pressure_all = results.node["pressure"][sensor_name]
|
||||
pressure = pressure_all.iloc[inp_time]
|
||||
demand_all = results.node["demand"]
|
||||
demand = demand_all.iloc[inp_time]
|
||||
sum_demand = cal_sum_demand(demand)
|
||||
flow_all = results.link["flowrate"][sensor_f_name]
|
||||
flow = flow_all.iloc[inp_time]
|
||||
|
||||
top_sensor = pressure.idxmin()
|
||||
basic_p = results.node["pressure"]
|
||||
basic_p = basic_p.iloc[inp_time]
|
||||
return pressure, flow, basic_p, top_sensor, sum_demand
|
||||
|
||||
|
||||
def normal_simulation_multi_pf(
|
||||
wn, drive_mode, sensor_name, sensor_f_name, require_p, minimum_p
|
||||
):
|
||||
# inp_time = 0
|
||||
if drive_mode == "PDD":
|
||||
wn.options.hydraulic.demand_model = "PDD"
|
||||
wn.options.hydraulic.required_pressure = require_p
|
||||
wn.options.hydraulic.minimum_pressure = minimum_p
|
||||
elif drive_mode == "DD":
|
||||
wn.options.hydraulic.demand_model = "DD"
|
||||
sim = wntr.sim.EpanetSimulator(wn)
|
||||
results = sim.run_sim()
|
||||
pressure_all = results.node["pressure"][sensor_name]
|
||||
pressure = pressure_all
|
||||
demand_all = results.node["demand"]
|
||||
demand = demand_all
|
||||
flow = results.link["flowrate"][sensor_f_name]
|
||||
sum_demand = pd.Series(dtype=object)
|
||||
for i in range(len(demand.index)):
|
||||
sum_demand[str(demand.index[i])] = cal_sum_demand(demand.iloc[i])
|
||||
if type(pressure) == pd.core.series.Series:
|
||||
top_sensor = pressure.idxmin()
|
||||
else:
|
||||
mean_pressure = pressure.mean()
|
||||
top_sensor = mean_pressure.idxmin()
|
||||
basic_p = results.node["pressure"]
|
||||
return pressure, flow, basic_p, top_sensor, sum_demand
|
||||
|
||||
|
||||
def simple_simulation_pf(wn, sensor_name, sensor_f_name, leak_pipe, add_pipe1):
|
||||
sim = wntr.sim.EpanetSimulator(wn)
|
||||
results = sim.run_sim()
|
||||
pressure_all = results.node["pressure"][sensor_name]
|
||||
if len(leak_pipe) > 0 and leak_pipe in sensor_f_name:
|
||||
|
||||
f_sensor_name = [add_pipe1 if i == leak_pipe else i for i in sensor_f_name]
|
||||
flow_all = results.link["flowrate"][f_sensor_name]
|
||||
flow_all.columns = sensor_f_name
|
||||
else:
|
||||
flow_all = results.link["flowrate"][sensor_f_name]
|
||||
return pressure_all, flow_all
|
||||
|
||||
|
||||
def cal_sum_demand(demand):
|
||||
sum_demand = 0
|
||||
for i in range(len(demand)):
|
||||
if demand.iloc[i] > 0:
|
||||
sum_demand += demand.iloc[i]
|
||||
return sum_demand
|
||||
|
||||
|
||||
def cal_signature_pipe_multi_pf(
|
||||
wn,
|
||||
leak_mag,
|
||||
candidate_center,
|
||||
timestep_list,
|
||||
sensor_name,
|
||||
n_workers=1,
|
||||
wn_inp_path=None,
|
||||
):
|
||||
candidate_center_num = len(candidate_center)
|
||||
pressure_leak = pd.DataFrame(
|
||||
index=pd.MultiIndex.from_product([candidate_center, timestep_list]),
|
||||
columns=sensor_name,
|
||||
)
|
||||
# flow_leak = pd.DataFrame(index=pd.MultiIndex.from_product([candidate_center, timestep_list]),
|
||||
# columns=sensor_f_name)
|
||||
pressure_leak = pressure_leak.sort_index()
|
||||
# flow_leak = flow_leak.sort_index()
|
||||
can_parallel = (
|
||||
n_workers > 1
|
||||
and candidate_center_num > 1
|
||||
and wn_inp_path is not None
|
||||
and len(str(wn_inp_path)) > 0
|
||||
)
|
||||
if can_parallel:
|
||||
option_values = _snapshot_hydraulic_options(wn)
|
||||
worker_count = min(n_workers, candidate_center_num)
|
||||
start_methods = mp.get_all_start_methods()
|
||||
context_name = "spawn" if "spawn" in start_methods else start_methods[0]
|
||||
with mp.get_context(context_name).Pool(
|
||||
processes=worker_count,
|
||||
initializer=_signature_worker_init,
|
||||
initargs=(
|
||||
str(wn_inp_path),
|
||||
float(leak_mag),
|
||||
list(sensor_name),
|
||||
option_values,
|
||||
list(candidate_center),
|
||||
),
|
||||
) as pool:
|
||||
for i, (center_name, pressure_array) in enumerate(
|
||||
pool.imap(_signature_worker_run_center, candidate_center)
|
||||
):
|
||||
pressure_leak.loc[(center_name, slice(None)), :] = pressure_array
|
||||
sys.stdout.write("\r" + "已经完成计算" + str(i + 1) + "个特征中心")
|
||||
else:
|
||||
# Pre-insert all mid-nodes so every simulation sees the same topology
|
||||
for center in candidate_center:
|
||||
ensure_mid_node(wn, center)
|
||||
for i in range(candidate_center_num):
|
||||
temp_prefix = _make_temp_prefix(f"sig_{i}")
|
||||
wn, pressure_output = leak_simulation_pipe_dd_multi_pf(
|
||||
wn,
|
||||
leak_mag,
|
||||
candidate_center[i],
|
||||
sensor_name,
|
||||
file_prefix=temp_prefix,
|
||||
)
|
||||
# leak_or_not_list.append(leak_or_not)
|
||||
pressure_leak.loc[(candidate_center[i], slice(None)), :] = (
|
||||
pressure_output.to_numpy()
|
||||
)
|
||||
# flow_leak.loc[candidate_center[i]].loc[:, :] = flow_output
|
||||
sys.stdout.write("\r" + "已经完成计算" + str(i + 1) + "个特征中心")
|
||||
return pressure_leak, candidate_center
|
||||
|
||||
|
||||
def _signature_worker_init(
|
||||
inp_path, leak_mag, sensor_name, option_values, candidate_centers=None
|
||||
):
|
||||
global _SIGNATURE_WORKER_DATA
|
||||
wn = wntr.network.WaterNetworkModel(inp_path)
|
||||
_apply_hydraulic_options(wn, option_values)
|
||||
# Pre-insert ALL mid-nodes so every simulation runs on the same topology,
|
||||
# regardless of which worker handles which task.
|
||||
if candidate_centers is not None:
|
||||
for center in candidate_centers:
|
||||
ensure_mid_node(wn, center)
|
||||
_SIGNATURE_WORKER_DATA = {
|
||||
"wn": wn,
|
||||
"leak_mag": leak_mag,
|
||||
"sensor_name": sensor_name,
|
||||
}
|
||||
|
||||
|
||||
def _signature_worker_run_center(center_name):
|
||||
data = _SIGNATURE_WORKER_DATA
|
||||
temp_prefix = _make_temp_prefix(f"sig_worker_{center_name}")
|
||||
_, pressure_output = leak_simulation_pipe_dd_multi_pf(
|
||||
data["wn"],
|
||||
data["leak_mag"],
|
||||
center_name,
|
||||
data["sensor_name"],
|
||||
file_prefix=temp_prefix,
|
||||
)
|
||||
return center_name, pressure_output.to_numpy()
|
||||
|
||||
|
||||
def pick_pipe(all_pipes, pipe_diameter, limited_diameter):
|
||||
candidate_pipe = []
|
||||
for each_pipe in all_pipes:
|
||||
if pipe_diameter[each_pipe] >= limited_diameter:
|
||||
candidate_pipe.append(each_pipe)
|
||||
return candidate_pipe
|
||||
|
||||
|
||||
def cal_possible_pipe(leak_flow, all_pipe, pipe_diameter):
|
||||
basic_pressure = 10 # 基础压力
|
||||
discharge_coeff = 0.6 # 经验系数
|
||||
break_area_ratio = 1 # 爆管面积比 0.5 1.25
|
||||
break_area = leak_flow / (
|
||||
discharge_coeff * math.sqrt(2 * basic_pressure * 9.81)
|
||||
) # 爆管面积 m3/h
|
||||
"""break_area_diameter = math.sqrt(4 * break_area / math.pi)
|
||||
min_diameter = (math.ceil(1000 * break_area_diameter / break_area_ratio)) / 1000"""
|
||||
break_area_diameter = math.sqrt(
|
||||
4 * break_area / math.pi / break_area_ratio
|
||||
) # 爆管直径
|
||||
min_diameter = (math.ceil(1000 * break_area_diameter)) / 1000 # 向上取整
|
||||
new_all_pipe = pick_pipe(all_pipe, pipe_diameter, min_diameter)
|
||||
return new_all_pipe, min_diameter
|
||||
|
||||
|
||||
def extract_links(data, link_types, direction):
|
||||
return [
|
||||
link
|
||||
for res_data in data.values()
|
||||
for link_type in link_types
|
||||
for link in res_data[link_type][direction]
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
"""管网模型读取与图构建模块。"""
|
||||
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
import wntr
|
||||
|
||||
|
||||
def load_inp(inp_name, inp_location, inp_time, driven_mode, require_p, minimum_p):
|
||||
inp_file = inp_location + inp_name
|
||||
wn = wntr.network.WaterNetworkModel(inp_file)
|
||||
if driven_mode == "PDD":
|
||||
wn.options.hydraulic.demand_model = "PDD"
|
||||
wn.options.hydraulic.required_pressure = require_p
|
||||
wn.options.hydraulic.minimum_pressure = minimum_p
|
||||
else:
|
||||
wn.options.hydraulic.demand_model = "DD"
|
||||
return wn
|
||||
|
||||
|
||||
def read_inf_inp(wn):
|
||||
all_node = wn.node_name_list
|
||||
node_elevation = wn.query_node_attribute("elevation")
|
||||
node_coordinates = wn.query_node_attribute("coordinates")
|
||||
|
||||
all_pipe = wn.pipe_name_list
|
||||
# 改_wz__________________________________
|
||||
n_pipe = []
|
||||
for p in all_pipe:
|
||||
pipe = wn.get_link(p)
|
||||
if pipe.initial_status == 0: # 状态为'Closed'
|
||||
n_pipe.append(p)
|
||||
candidate_pipe_init = sorted(set(all_pipe) - set(n_pipe))
|
||||
pipe_start_node = wn.query_link_attribute(
|
||||
"start_node_name", link_type=wntr.network.model.Pipe
|
||||
)
|
||||
pipe_end_node = wn.query_link_attribute(
|
||||
"end_node_name", link_type=wntr.network.model.Pipe
|
||||
)
|
||||
pipe_length = wn.query_link_attribute("length")
|
||||
pipe_diameter = wn.query_link_attribute("diameter")
|
||||
|
||||
return (
|
||||
all_node,
|
||||
node_elevation,
|
||||
node_coordinates,
|
||||
candidate_pipe_init,
|
||||
pipe_start_node,
|
||||
pipe_end_node,
|
||||
pipe_length,
|
||||
pipe_diameter,
|
||||
)
|
||||
|
||||
|
||||
def read_inf_inp_other(wn):
|
||||
all_link = wn.link_name_list
|
||||
pipe_start_node_all = wn.query_link_attribute("start_node_name")
|
||||
pipe_end_node_all = wn.query_link_attribute("end_node_name")
|
||||
|
||||
return all_link, pipe_start_node_all, pipe_end_node_all
|
||||
|
||||
|
||||
def construct_graph(wn):
|
||||
length = wn.query_link_attribute("length")
|
||||
G = wn.get_graph(wn, link_weight=length)
|
||||
# 转为无向图
|
||||
G0 = G.to_undirected()
|
||||
# A0 = np.array(nx.adjacency_graph(G0).todense())
|
||||
return G0 # , A0
|
||||
|
||||
|
||||
def cal_pipe_coordinate(all_pipe, pipe_start_node, pipe_end_node, node_coordinates):
|
||||
pipe_num = len(all_pipe)
|
||||
pipe_coordinates = np.zeros([pipe_num, 2])
|
||||
pipe_x = copy.deepcopy(pipe_start_node)
|
||||
pipe_y = copy.deepcopy(pipe_start_node)
|
||||
for i in range(pipe_num):
|
||||
temp_pipe = all_pipe[i]
|
||||
pipe_x[temp_pipe] = (
|
||||
node_coordinates[pipe_start_node[temp_pipe]][0]
|
||||
+ node_coordinates[pipe_end_node[temp_pipe]][0]
|
||||
) / 2
|
||||
pipe_y[temp_pipe] = (
|
||||
node_coordinates[pipe_start_node[temp_pipe]][1]
|
||||
+ node_coordinates[pipe_end_node[temp_pipe]][1]
|
||||
) / 2
|
||||
return pipe_x, pipe_y
|
||||
|
||||
|
||||
def cal_node_coordinate(all_node, node_coordinates):
|
||||
node_x = copy.deepcopy(node_coordinates)
|
||||
node_y = copy.deepcopy(node_coordinates)
|
||||
for i in range(len(node_x)):
|
||||
temp_node = all_node[i]
|
||||
node_x[temp_node] = node_coordinates[temp_node][0]
|
||||
node_y[temp_node] = node_coordinates[temp_node][1]
|
||||
return node_x, node_y
|
||||
|
||||
|
||||
def produce_pattern_value(wn, all_node):
|
||||
wn_o = copy.deepcopy(wn)
|
||||
# 改_wz_____________________________
|
||||
# sample_node = wn_o.get_node(all_node[0])
|
||||
# num_categories = len(sample_node.demand_timeseries_list)
|
||||
num_categories = 1
|
||||
columns = [f"D{i}" for i in range(num_categories)]
|
||||
basic_demand_pd = pd.DataFrame(index=all_node, columns=columns)
|
||||
for each in all_node:
|
||||
node = wn_o.get_node(each)
|
||||
for i in range(num_categories):
|
||||
basic_demand_pd.loc[each, columns[i]] = node.demand_timeseries_list[
|
||||
i
|
||||
].base_value
|
||||
|
||||
return basic_demand_pd
|
||||
|
||||
|
||||
def _build_node_pipe_maps(
|
||||
all_nodes, candidate_pipes, pipe_start_node, pipe_end_node, pipe_length
|
||||
):
|
||||
node_pipe_dic = {node: [] for node in all_nodes}
|
||||
couple_node_length = {}
|
||||
for pipe in candidate_pipes:
|
||||
start_node = pipe_start_node[pipe]
|
||||
end_node = pipe_end_node[pipe]
|
||||
if start_node in node_pipe_dic:
|
||||
node_pipe_dic[start_node].append(pipe)
|
||||
if end_node in node_pipe_dic:
|
||||
node_pipe_dic[end_node].append(pipe)
|
||||
length = float(pipe_length[pipe])
|
||||
couple_node_length[f"{start_node},{end_node}"] = length
|
||||
couple_node_length[f"{end_node},{start_node}"] = length
|
||||
return node_pipe_dic, couple_node_length
|
||||
|
||||
|
||||
@@ -0,0 +1,456 @@
|
||||
"""管网分区模块。"""
|
||||
|
||||
import math
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import networkx as networkx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pymetis
|
||||
from scipy.sparse import coo_matrix, csr_matrix
|
||||
from scipy.sparse.csgraph import connected_components
|
||||
|
||||
|
||||
def _to_metis_edge_weight(edge_weight):
|
||||
weight = float(edge_weight)
|
||||
if not math.isfinite(weight):
|
||||
raise ValueError(f"Invalid non-finite METIS edge weight: {edge_weight}")
|
||||
# pymetis expects integer edge weights.
|
||||
return max(1, int(round(weight)))
|
||||
|
||||
|
||||
def _dedupe_preserve_order(items):
|
||||
seen = set()
|
||||
output = []
|
||||
for item in items:
|
||||
if item in seen:
|
||||
continue
|
||||
seen.add(item)
|
||||
output.append(item)
|
||||
return output
|
||||
|
||||
|
||||
def pick_center_pipe(node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node):
|
||||
candidate_pipe_list = list(candidate_pipe)
|
||||
start_nodes = pipe_start_node[candidate_pipe_list]
|
||||
end_nodes = pipe_end_node[candidate_pipe_list]
|
||||
|
||||
x_vals = (node_x[start_nodes].to_numpy() + node_x[end_nodes].to_numpy()) / 2.0
|
||||
y_vals = (node_y[start_nodes].to_numpy() + node_y[end_nodes].to_numpy()) / 2.0
|
||||
mean_x = float(np.mean(x_vals))
|
||||
mean_y = float(np.mean(y_vals))
|
||||
distance = np.abs(x_vals - mean_x) + np.abs(y_vals - mean_y)
|
||||
center_idx = int(np.argmin(distance))
|
||||
return candidate_pipe_list[center_idx]
|
||||
|
||||
|
||||
def pick_max_diameter_pipe(candidate_pipe, pipe_diameter):
|
||||
candidate_pipe_list = list(candidate_pipe)
|
||||
diameters = pd.to_numeric(
|
||||
pipe_diameter.reindex(candidate_pipe_list), errors="coerce"
|
||||
).dropna()
|
||||
if len(diameters) != len(candidate_pipe_list):
|
||||
missing = sorted(set(candidate_pipe_list) - set(diameters.index))
|
||||
preview = ", ".join(map(str, missing[:10]))
|
||||
raise ValueError(f"Missing or invalid diameter for pipes: {preview}")
|
||||
|
||||
max_diameter = float(diameters.max())
|
||||
max_diameter_pipes = sorted(
|
||||
[pipe for pipe, diameter in diameters.items() if float(diameter) == max_diameter],
|
||||
key=str,
|
||||
)
|
||||
return max_diameter_pipes[0]
|
||||
|
||||
|
||||
def pick_dual_center_pipes(
|
||||
node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node, pipe_diameter
|
||||
):
|
||||
geometric_center = pick_center_pipe(
|
||||
node_x, node_y, candidate_pipe, pipe_start_node, pipe_end_node
|
||||
)
|
||||
diameter_center = pick_max_diameter_pipe(candidate_pipe, pipe_diameter)
|
||||
return _dedupe_preserve_order([geometric_center, diameter_center])
|
||||
|
||||
|
||||
def find_new_center_pipe(
|
||||
node_x,
|
||||
node_y,
|
||||
candidate_pipe,
|
||||
pipe_start_node,
|
||||
pipe_end_node,
|
||||
pipe_diameter,
|
||||
record_center,
|
||||
):
|
||||
new_candidate_pipe = sorted(set(candidate_pipe) - set(record_center))
|
||||
if new_candidate_pipe == []:
|
||||
new_candidate_pipe = candidate_pipe
|
||||
center_t = pick_center_pipe(
|
||||
node_x,
|
||||
node_y,
|
||||
new_candidate_pipe,
|
||||
pipe_start_node,
|
||||
pipe_end_node,
|
||||
)
|
||||
return center_t
|
||||
|
||||
|
||||
def cal_area_node_linked_pipe(nodeset, node_pipe_dic):
|
||||
pipeset = []
|
||||
for temp_node in nodeset:
|
||||
pipeset.extend(node_pipe_dic[temp_node])
|
||||
return pipeset
|
||||
|
||||
|
||||
def metis_grouping_pipe_weight(
|
||||
G0,
|
||||
wn,
|
||||
all_node_iter,
|
||||
candidate_pipe_input,
|
||||
group_num,
|
||||
node_x,
|
||||
node_y,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
node_pipe_dic,
|
||||
all_node_series,
|
||||
couple_node_length,
|
||||
pipe_diameter,
|
||||
):
|
||||
all_node_iter_series_new = all_node_series[all_node_iter]
|
||||
all_node_iter_series_new = all_node_iter_series_new.sort_values(ascending=True)
|
||||
all_node_iter_new = list(all_node_iter_series_new.index)
|
||||
|
||||
G1 = G0.subgraph(all_node_iter_new)
|
||||
delimiter = " "
|
||||
adjacency_list = []
|
||||
node_dict = {}
|
||||
c_new = 0
|
||||
for each_node in all_node_iter_new:
|
||||
node_dict[each_node] = c_new
|
||||
c_new = c_new + 1
|
||||
correspond_dic = {}
|
||||
count_node = 0
|
||||
w = []
|
||||
for node_name in all_node_iter_new:
|
||||
neighbors = G1[node_name]
|
||||
w_temp = []
|
||||
n_t = [node_dict[node_name]]
|
||||
for neighbor_name in sorted(neighbors.keys()):
|
||||
edge_data = neighbors[neighbor_name]
|
||||
edge_key = f"{node_name},{neighbor_name}"
|
||||
reverse_edge_key = f"{neighbor_name},{node_name}"
|
||||
if edge_key in couple_node_length:
|
||||
edge_weight = couple_node_length[edge_key]
|
||||
elif reverse_edge_key in couple_node_length:
|
||||
edge_weight = couple_node_length[reverse_edge_key]
|
||||
elif edge_data.get("weight") is not None:
|
||||
edge_weight = float(edge_data["weight"])
|
||||
else:
|
||||
# Ignore graph edges that are outside candidate pipes and have no usable
|
||||
# partition weight (e.g. some non-pipe links in mixed network graphs).
|
||||
continue
|
||||
w_temp.append(_to_metis_edge_weight(edge_weight))
|
||||
n_t.append(node_dict[neighbor_name])
|
||||
w.append(w_temp)
|
||||
correspond_dic[n_t[0]] = count_node
|
||||
count_node = count_node + 1
|
||||
# del n_t[0]
|
||||
adjacency_list.append(n_t)
|
||||
adjacency_list_new = [[] * 1 for i in range(len(adjacency_list))]
|
||||
w_new = [[] * 1 for i in range(len(adjacency_list))]
|
||||
for i in range(len(adjacency_list)):
|
||||
adjacency_list_new[int(adjacency_list[i][0])] = adjacency_list[i]
|
||||
w_new[int(adjacency_list[i][0])] = w[i]
|
||||
|
||||
for i in range(len(adjacency_list)):
|
||||
del adjacency_list_new[i][0]
|
||||
xadj = [0]
|
||||
w_f = []
|
||||
final_adjacency_list = []
|
||||
for i in range(len(adjacency_list_new)):
|
||||
final_adjacency_list = final_adjacency_list + adjacency_list_new[i]
|
||||
xadj.append(len(final_adjacency_list))
|
||||
w_f = w_f + w_new[i]
|
||||
|
||||
# (edgecuts, parts) = pymetis.part_graph(nparts=group_num, adjacency=adjacency_list_new)
|
||||
metis_options = pymetis.Options()
|
||||
metis_options.seed = 42
|
||||
(edgecuts, parts) = pymetis.part_graph(
|
||||
nparts=group_num,
|
||||
adjncy=final_adjacency_list,
|
||||
xadj=xadj,
|
||||
eweights=w_f,
|
||||
options=metis_options,
|
||||
)
|
||||
# (edgecuts, parts) = pymetis.part_graph(nparts=group_num, adjacency=adjacency_list_new)
|
||||
candidate_group_list = [[] * 1 for i in range(group_num)]
|
||||
for i in range(len(all_node_iter_new)):
|
||||
candidate_group_list[parts[i]].append(all_node_iter_new[i])
|
||||
|
||||
"""parts_new = np.zeros(len(candidate_node_input), dtype=int)
|
||||
for i in range(len(candidate_group_list)):
|
||||
temp_group = candidate_group_list[i]
|
||||
for each_node in temp_group:
|
||||
parts_new[node_dict[each_node]] = i
|
||||
parts_new = list(parts_new)"""
|
||||
|
||||
new_center = []
|
||||
new_group = []
|
||||
new_center_candidates = []
|
||||
new_all_node = []
|
||||
candidate_pipe_set = set(candidate_pipe_input)
|
||||
all_grouped_pipe = []
|
||||
for i in range(group_num):
|
||||
# 构建子图
|
||||
G_sub = G0.subgraph(candidate_group_list[i])
|
||||
|
||||
# 计算联通子图
|
||||
sub_graphs = networkx.connected_components(G_sub)
|
||||
if networkx.number_connected_components(G_sub) == 1:
|
||||
# 求交集
|
||||
nodeset = G_sub.nodes()
|
||||
pipeset_set = set(cal_area_node_linked_pipe(nodeset, node_pipe_dic))
|
||||
candidate_pipe = sorted(pipeset_set.intersection(candidate_pipe_set))
|
||||
|
||||
# 判断集合是否保留
|
||||
if len(candidate_pipe) > 0:
|
||||
# 保留 计算中心
|
||||
center_t = pick_center_pipe(
|
||||
node_x,
|
||||
node_y,
|
||||
candidate_pipe,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
)
|
||||
center_candidates_t = pick_dual_center_pipes(
|
||||
node_x,
|
||||
node_y,
|
||||
candidate_pipe,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
pipe_diameter,
|
||||
)
|
||||
|
||||
# 更新
|
||||
new_center.append(center_t)
|
||||
new_center_candidates.append(center_candidates_t)
|
||||
new_group.append(candidate_pipe)
|
||||
new_all_node.append(nodeset)
|
||||
all_grouped_pipe = all_grouped_pipe + candidate_pipe
|
||||
else:
|
||||
for c in sorted(sub_graphs, key=lambda c: min(c)):
|
||||
G_temp = G0.subgraph(c)
|
||||
nodeset = G_temp.nodes()
|
||||
pipeset = cal_area_node_linked_pipe(nodeset, node_pipe_dic)
|
||||
pipeset_set = set(pipeset)
|
||||
|
||||
# 求交集
|
||||
candidate_pipe = sorted(pipeset_set.intersection(candidate_pipe_set))
|
||||
# print(len(candidate_node))
|
||||
# 判断集合是否保留
|
||||
if len(candidate_pipe) > 0:
|
||||
# 保留 计算中心
|
||||
center_t = pick_center_pipe(
|
||||
node_x,
|
||||
node_y,
|
||||
candidate_pipe,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
)
|
||||
center_candidates_t = pick_dual_center_pipes(
|
||||
node_x,
|
||||
node_y,
|
||||
candidate_pipe,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
pipe_diameter,
|
||||
)
|
||||
# 更新
|
||||
new_center.append(center_t)
|
||||
new_center_candidates.append(center_candidates_t)
|
||||
new_group.append(candidate_pipe)
|
||||
new_all_node.append(nodeset)
|
||||
all_grouped_pipe = all_grouped_pipe + candidate_pipe
|
||||
record_center = []
|
||||
c_g = 0
|
||||
for each_group in new_group:
|
||||
if len(each_group) < 3:
|
||||
record_center.append(new_center[c_g])
|
||||
c_g += 1
|
||||
c_g = 0
|
||||
for each_group in new_group:
|
||||
if len(each_group) >= 3:
|
||||
if new_center[c_g] in record_center:
|
||||
new_center[c_g] = find_new_center_pipe(
|
||||
node_x,
|
||||
node_y,
|
||||
each_group,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
pipe_diameter,
|
||||
record_center,
|
||||
)
|
||||
new_center_candidates[c_g] = _dedupe_preserve_order(
|
||||
[new_center[c_g]] + list(new_center_candidates[c_g])
|
||||
)
|
||||
record_center.append(new_center[c_g])
|
||||
c_g += 1
|
||||
|
||||
# visualize_metis_partition(
|
||||
# G0, new_center, new_group,
|
||||
# node_x, node_y,
|
||||
# pipe_start_node_all, pipe_end_node_all
|
||||
# )
|
||||
return new_center, new_group, new_all_node, new_center_candidates
|
||||
|
||||
|
||||
def visualize_metis_partition(
|
||||
G,
|
||||
center_pipes,
|
||||
pipe_groups,
|
||||
node_x,
|
||||
node_y,
|
||||
pipe_start_node_all,
|
||||
pipe_end_node_all,
|
||||
title: str | None = None,
|
||||
block: bool = True,
|
||||
pause_seconds: float | None = None,
|
||||
):
|
||||
"""
|
||||
可视化METIS分区结果(单图模式)
|
||||
参数:
|
||||
G: 原始管网图(nx.Graph)
|
||||
center_pipes: 中心管道列表(list)
|
||||
pipe_groups: 分组管道列表(list of lists)
|
||||
node_x: 节点X坐标字典(dict)
|
||||
node_y: 节点Y坐标字典(dict)
|
||||
pipe_start_node_all: 管道起点字典(dict)
|
||||
pipe_end_node_all: 管道终点字典(dict)
|
||||
"""
|
||||
fig = plt.figure("metis_partition_convergence", figsize=(22.51, 12.48))
|
||||
fig.clf()
|
||||
ax = fig.add_subplot(111)
|
||||
if not block:
|
||||
plt.ion()
|
||||
|
||||
# 生成颜色映射(自动扩展颜色数量)
|
||||
colors = plt.cm.tab20(np.linspace(0, 1, len(pipe_groups)))
|
||||
|
||||
# --- 绘制背景管网(灰色半透明) ---
|
||||
for edge in G.edges():
|
||||
start_node, end_node = edge
|
||||
ax.plot(
|
||||
[node_x[start_node], node_x[end_node]],
|
||||
[node_y[start_node], node_y[end_node]],
|
||||
color="lightgray",
|
||||
linewidth=0.5,
|
||||
alpha=0.3,
|
||||
zorder=1, # 确保背景在底层
|
||||
)
|
||||
|
||||
# --- 绘制各分区管道(彩色)---
|
||||
legend_handles = [] # 用于图例的句柄
|
||||
for i, (group, center) in enumerate(zip(pipe_groups, center_pipes)):
|
||||
color = colors[i % len(colors)] # 循环使用颜色
|
||||
|
||||
# 绘制分组管道
|
||||
for pipe in group:
|
||||
start = pipe_start_node_all[pipe]
|
||||
end = pipe_end_node_all[pipe]
|
||||
line = ax.plot(
|
||||
[node_x[start], node_x[end]],
|
||||
[node_y[start], node_y[end]],
|
||||
color=color,
|
||||
linewidth=2.5,
|
||||
alpha=0.8,
|
||||
zorder=2,
|
||||
)
|
||||
# 只为每个分组的第一个管道添加图例句柄
|
||||
if pipe == group[0]:
|
||||
legend_handles.append(line[0])
|
||||
|
||||
# 高亮中心管道(红色虚线)
|
||||
if center in pipe_start_node_all and center in pipe_end_node_all:
|
||||
start = pipe_start_node_all[center]
|
||||
end = pipe_end_node_all[center]
|
||||
ax.plot(
|
||||
[node_x[start], node_x[end]],
|
||||
[node_y[start], node_y[end]],
|
||||
color="red",
|
||||
linewidth=4,
|
||||
linestyle="--",
|
||||
dash_capstyle="round",
|
||||
zorder=3, # 确保中心管道在最顶层
|
||||
)
|
||||
|
||||
# --- 添加图例和标注 ---
|
||||
# 分组图例
|
||||
if legend_handles:
|
||||
group_labels = [f"Group {i + 1}" for i in range(len(pipe_groups))]
|
||||
ax.legend(
|
||||
legend_handles,
|
||||
group_labels,
|
||||
loc="upper right",
|
||||
title="Partitions",
|
||||
fontsize=8,
|
||||
title_fontsize=10,
|
||||
)
|
||||
|
||||
# 中心管道标注(可选)
|
||||
for i, center in enumerate(center_pipes):
|
||||
if center in pipe_start_node_all:
|
||||
x = (
|
||||
node_x[pipe_start_node_all[center]] + node_x[pipe_end_node_all[center]]
|
||||
) / 2
|
||||
y = (
|
||||
node_y[pipe_start_node_all[center]] + node_y[pipe_end_node_all[center]]
|
||||
) / 2
|
||||
ax.text(
|
||||
x,
|
||||
y,
|
||||
f"C{i + 1}",
|
||||
color="red",
|
||||
fontsize=10,
|
||||
ha="center",
|
||||
va="center",
|
||||
bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
|
||||
)
|
||||
|
||||
# --- 图形美化 ---
|
||||
ax.set_title(title or "Water Network Partitioning Overview", fontsize=14, pad=20)
|
||||
ax.set_xlabel("X Coordinate", fontsize=10)
|
||||
ax.set_ylabel("Y Coordinate", fontsize=10)
|
||||
ax.grid(True, alpha=0.2, linestyle=":")
|
||||
fig.tight_layout()
|
||||
|
||||
# 显示图形并强制刷新,避免迭代显示滞后一轮。
|
||||
plt.show(block=block)
|
||||
if not block:
|
||||
fig.canvas.draw_idle()
|
||||
fig.canvas.flush_events()
|
||||
pause_value = 0.001 if pause_seconds is None else max(0.0, float(pause_seconds))
|
||||
plt.pause(max(0.001, pause_value))
|
||||
elif pause_seconds is not None:
|
||||
plt.pause(max(0.0, float(pause_seconds)))
|
||||
return fig
|
||||
|
||||
|
||||
def generate_adjlist_with_all_edges(G, delimiter):
|
||||
for s, nbrs in G.adjacency():
|
||||
line = str(s) + delimiter
|
||||
for t, data in nbrs.items():
|
||||
line += str(t) + delimiter
|
||||
yield line[: -len(delimiter)]
|
||||
|
||||
|
||||
def cal_group_num(candidate_node_input, cal_group_num):
|
||||
candidate_node_num = len(candidate_node_input)
|
||||
if candidate_node_num > 100:
|
||||
group_num_input = cal_group_num # 30
|
||||
else:
|
||||
group_num_input = 10
|
||||
|
||||
return group_num_input
|
||||
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""噪声生成模块。"""
|
||||
|
||||
import copy
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .leak_simulator import simple_add_leak, simple_recover_wn, simple_simulation_pf
|
||||
|
||||
|
||||
def add_noise_pd(data, noise_type, noise_para):
|
||||
output_data = copy.deepcopy(data)
|
||||
if type(output_data) == pd.core.frame.Series:
|
||||
if noise_type == "uni":
|
||||
for x in output_data.index:
|
||||
noise = (np.random.random() - 0.5) * 2
|
||||
output_data[x] = output_data[x] + noise * noise_para
|
||||
|
||||
elif noise_type == "gauss":
|
||||
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
|
||||
output_data = output_data + noise
|
||||
elif type(output_data) == pd.core.frame.DataFrame:
|
||||
if noise_type == "uni":
|
||||
noise = (np.random.random(size=output_data.shape) - 0.5) * 2
|
||||
output_data = output_data + noise * noise_para
|
||||
|
||||
elif noise_type == "gauss":
|
||||
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
|
||||
output_data = output_data + noise
|
||||
return output_data
|
||||
|
||||
|
||||
def add_noise_number(data, noise_type, noise_para):
|
||||
output_data = copy.deepcopy(data)
|
||||
if noise_type == "uni":
|
||||
noise = (np.random.random() - 0.5) * 2
|
||||
output_data = output_data + noise * noise_para
|
||||
|
||||
elif noise_type == "gauss":
|
||||
noise = random.gauss(0, noise_para)
|
||||
output_data = output_data + noise
|
||||
return output_data
|
||||
|
||||
|
||||
def add_noise_number_flow(data, noise_para_mean, noise_para_std1, noise_para_std2):
|
||||
output_data = copy.deepcopy(data)
|
||||
noise_flag1 = np.random.random() - 0.5
|
||||
if noise_flag1 < 0:
|
||||
noise = noise_para_mean - abs(np.random.normal(loc=0, scale=noise_para_std1))
|
||||
else:
|
||||
noise = noise_para_mean + abs(np.random.normal(loc=0, scale=noise_para_std2))
|
||||
noise_flag2 = np.random.random() - 0.5
|
||||
if noise_flag2 < 0:
|
||||
noise_f = noise * (-1)
|
||||
else:
|
||||
noise_f = noise
|
||||
output_data = output_data + noise_f
|
||||
return output_data
|
||||
|
||||
|
||||
def produce_noise_number(noise_type, noise_para):
|
||||
if noise_type == "uni":
|
||||
noise = (np.random.random() - 0.5) * 2
|
||||
noise = noise * noise_para
|
||||
elif noise_type == "gauss":
|
||||
noise = random.gauss(0, noise_para)
|
||||
else:
|
||||
noise = 0
|
||||
return noise
|
||||
|
||||
|
||||
def add_noise_percentage_pd(data, noise_type, noise_para):
|
||||
output_data = copy.deepcopy(data)
|
||||
|
||||
if type(output_data) == pd.core.frame.Series:
|
||||
if noise_type == "uni":
|
||||
for x in output_data.index:
|
||||
noise = (np.random.random() - 0.5) * 2
|
||||
output_data[x] = output_data[x] * (1 + noise * noise_para / 100)
|
||||
|
||||
elif noise_type == "gauss":
|
||||
for x in output_data.index:
|
||||
noise = np.random.gauss(0, noise_para)
|
||||
output_data[x] = output_data[x] * (1 + noise / 100)
|
||||
# std_noise = noise.std()
|
||||
elif type(output_data) == pd.core.frame.DataFrame:
|
||||
if noise_type == "uni":
|
||||
noise = (np.random.random(size=output_data.shape) - 0.5) * 2
|
||||
output_data = output_data * (1 + noise * noise_para / 100)
|
||||
|
||||
elif noise_type == "gauss":
|
||||
noise = np.random.normal(loc=0, scale=noise_para, size=output_data.shape)
|
||||
output_data = output_data * (1 + noise / 100)
|
||||
# std_noise = noise.std().mean()
|
||||
return output_data
|
||||
|
||||
|
||||
def add_noise_in_wn_pf(
|
||||
wn,
|
||||
pipe_c_noise,
|
||||
timestep_list,
|
||||
pipe_coefficient,
|
||||
sensor_name,
|
||||
sensor_f_name,
|
||||
all_node,
|
||||
basic_demand_pd,
|
||||
noise_type,
|
||||
noise_para,
|
||||
leak_pipe,
|
||||
leak_flow,
|
||||
):
|
||||
|
||||
wn.options.time.duration = 0
|
||||
pipe_roughness_change = add_noise_pd(pipe_coefficient, noise_type, pipe_c_noise)
|
||||
wn = change_para_of_wn(wn, pipe_roughness_change)
|
||||
record_pressure = pd.DataFrame(index=timestep_list, columns=sensor_name)
|
||||
record_flow = pd.DataFrame(index=timestep_list, columns=sensor_f_name)
|
||||
record_noise_all = pd.DataFrame(
|
||||
index=pd.MultiIndex.from_product([timestep_list, all_node]),
|
||||
columns=basic_demand_pd.columns,
|
||||
)
|
||||
record_noise_all = record_noise_all.sort_index()
|
||||
|
||||
# normal 获取添加噪声后的监测点数据
|
||||
for i in range(len(timestep_list)):
|
||||
wn, record_noise = change_node_demand(
|
||||
wn, basic_demand_pd, all_node, noise_type, noise_para
|
||||
)
|
||||
record_noise_all.loc[timestep_list[i]].loc[:, :] = record_noise
|
||||
pressure_temp, flow_temp = simple_simulation_pf(
|
||||
wn, sensor_name, sensor_f_name, [], []
|
||||
)
|
||||
record_pressure.iloc[i, :] = pressure_temp
|
||||
record_flow.iloc[i, :] = flow_temp
|
||||
|
||||
# leak_simulation 获取添加漏损后的监测点数据
|
||||
record_pressure_leak = pd.DataFrame(index=timestep_list, columns=sensor_name)
|
||||
record_flow_leak = pd.DataFrame(index=timestep_list, columns=sensor_f_name)
|
||||
# 改_wz_________________________________________
|
||||
# add leak
|
||||
wn, whole_inf, add_pipe1 = simple_add_leak(wn, leak_flow, leak_pipe)
|
||||
# simulation
|
||||
for i in range(len(timestep_list)):
|
||||
record_noise = record_noise_all.loc[timestep_list[i]]
|
||||
wn = change_node_demand_leak(wn, record_noise, all_node)
|
||||
pressure_temp, flow_temp = simple_simulation_pf(
|
||||
wn, sensor_name, sensor_f_name, leak_pipe, add_pipe1
|
||||
)
|
||||
record_pressure_leak.iloc[i, :] = pressure_temp
|
||||
record_flow_leak.iloc[i, :] = flow_temp
|
||||
# delete leak
|
||||
wn = simple_recover_wn(wn, whole_inf)
|
||||
return wn, record_pressure, record_flow, record_pressure_leak, record_flow_leak
|
||||
|
||||
|
||||
def change_node_demand(wn, basic_demand_pd, all_node, noise_type, noise_para):
|
||||
# 改_wz_____________________________________
|
||||
record_noise = pd.DataFrame(index=all_node, columns=basic_demand_pd.columns)
|
||||
for each_node in all_node:
|
||||
node = wn.get_node(each_node)
|
||||
num_columns = len(basic_demand_pd.columns)
|
||||
# 处理前N-1列(如果有)
|
||||
for i in range(num_columns - 1):
|
||||
# 获取原始值并添加噪声
|
||||
record_noise.loc[each_node].iloc[i] = (
|
||||
1 + produce_noise_number(noise_type, noise_para)
|
||||
) * basic_demand_pd.loc[each_node].iloc[i]
|
||||
node.demand_timeseries_list[i].base_value = record_noise.loc[
|
||||
each_node
|
||||
].iloc[i]
|
||||
# 处理最后一列(当列数>=1时)
|
||||
if num_columns >= 1:
|
||||
last_col = basic_demand_pd.columns[-1]
|
||||
original_last = basic_demand_pd.loc[each_node, last_col]
|
||||
record_noise.loc[each_node, last_col] = original_last
|
||||
node.demand_timeseries_list[-1].base_value = original_last
|
||||
|
||||
return wn, record_noise
|
||||
|
||||
|
||||
def change_node_demand_leak(wn, record_noise, all_node):
|
||||
sample_node = wn.get_node(all_node[0])
|
||||
# num_categories = len(sample_node.demand_timeseries_list)
|
||||
num_categories = 1
|
||||
for each in all_node:
|
||||
node = wn.get_node(each)
|
||||
for i in range(num_categories):
|
||||
node.demand_timeseries_list[i].base_value = record_noise.loc[each].iloc[i]
|
||||
return wn
|
||||
|
||||
|
||||
def change_para_of_wn(wn, pipe_roughness_change):
|
||||
for pipe_name, pipe in wn.pipes():
|
||||
pipe.roughness = pipe_roughness_change[pipe_name]
|
||||
return wn
|
||||
|
||||
|
||||
@@ -0,0 +1,858 @@
|
||||
"""相似性计算模块。"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def cal_similarity_simple_return_dd(
|
||||
similarity_mode,
|
||||
monitor_p,
|
||||
predict_p,
|
||||
normal_p,
|
||||
leak_p,
|
||||
monitor_p_all,
|
||||
predict_p_all,
|
||||
normal_p_all,
|
||||
leak_p_all,
|
||||
important_sensor,
|
||||
mean_dpressure,
|
||||
dpressure_std,
|
||||
dpressure_std_all,
|
||||
if_gy=0,
|
||||
cos_or_flow=1,
|
||||
):
|
||||
# cos_or_flow 用于 CAF
|
||||
dpressure_s = normal_p - leak_p
|
||||
dpressure = predict_p - monitor_p
|
||||
act_dpressure = pd.Series(dtype=object)
|
||||
for i in range(len(leak_p.index)):
|
||||
if dpressure_std.iloc[i] > -200: # 0.001:
|
||||
if if_gy == 1:
|
||||
act_dpressure[leak_p.index[i]] = (
|
||||
leak_p.iloc[i] - monitor_p.iloc[i]
|
||||
) / dpressure_std.iloc[i]
|
||||
else:
|
||||
act_dpressure[leak_p.index[i]] = leak_p.iloc[i] - monitor_p.iloc[i]
|
||||
|
||||
if similarity_mode == "COS" or (similarity_mode == "CAF" and cos_or_flow == 1):
|
||||
|
||||
"""if leak_p.min()<0:
|
||||
none_flag = 1
|
||||
similarity_cos = 0
|
||||
similarity_dis = 0
|
||||
else:"""
|
||||
none_flag = 0
|
||||
sensor_for_cos = sorted(
|
||||
set(dpressure_s.index).intersection(set(act_dpressure.index))
|
||||
)
|
||||
"""if len(dpressure_s) ==0 or len(dpressure) ==0:
|
||||
jj=9
|
||||
else:"""
|
||||
try:
|
||||
s1 = np.dot(
|
||||
np.transpose(dpressure_s.loc[sensor_for_cos]),
|
||||
dpressure.loc[sensor_for_cos],
|
||||
)
|
||||
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
|
||||
dpressure.loc[sensor_for_cos]
|
||||
)
|
||||
if s2 == 0:
|
||||
s2 = s2 + 0.0001
|
||||
similarity_cos = s1 / s2
|
||||
similarity_dis = 0
|
||||
except Exception as e:
|
||||
print(dpressure_s)
|
||||
print(sensor_for_cos)
|
||||
print(act_dpressure)
|
||||
print(dpressure_std)
|
||||
print(dpressure)
|
||||
|
||||
elif similarity_mode == "DIS" or (similarity_mode == "CAF" and cos_or_flow == 2):
|
||||
"""if leak_p.min()<0:
|
||||
none_flag = 1
|
||||
else:"""
|
||||
none_flag = 0
|
||||
important_sensor = sorted(
|
||||
set(important_sensor).intersection(set(act_dpressure.index))
|
||||
)
|
||||
part_dpressure = dpressure_s[important_sensor] - dpressure[important_sensor]
|
||||
similarity_pre_DIS = np.linalg.norm(part_dpressure)
|
||||
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
|
||||
similarity_dis = similarity_pre_DIS
|
||||
similarity_cos = 0
|
||||
|
||||
elif similarity_mode == "CAD_new":
|
||||
act_dpressure = leak_p - monitor_p
|
||||
"""if leak_p.min() < 0:
|
||||
none_flag = 1
|
||||
similarity_cos = 0
|
||||
similarity_dis =0
|
||||
else:"""
|
||||
none_flag = 0
|
||||
# cos
|
||||
s1 = np.dot(np.transpose(dpressure_s), dpressure)
|
||||
s2 = np.linalg.norm(dpressure_s) * np.linalg.norm(dpressure)
|
||||
|
||||
if s2 == 0:
|
||||
s2 = s2 + 0.0001
|
||||
similarity_cos = s1 / s2
|
||||
|
||||
# DIS
|
||||
part_dpressure = act_dpressure.loc[important_sensor]
|
||||
similarity_pre_DIS = np.linalg.norm(part_dpressure)
|
||||
similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
|
||||
similarity_dis = similarity_pre_DIS
|
||||
|
||||
elif similarity_mode == "CAD_new_gy" or similarity_mode == "CDF":
|
||||
# cos
|
||||
sensor_for_cos = sorted(
|
||||
set(dpressure_s.index).intersection(set(act_dpressure.index))
|
||||
)
|
||||
if len(sensor_for_cos) == 0 and len(dpressure_s) == 0:
|
||||
similarity_cos = 0
|
||||
elif len(sensor_for_cos) == 0 and len(dpressure_s) > 0:
|
||||
sensor_for_cos = list(dpressure_s.index)
|
||||
none_flag = 0
|
||||
s1 = np.dot(
|
||||
np.transpose(dpressure_s.loc[sensor_for_cos]),
|
||||
dpressure.loc[sensor_for_cos],
|
||||
)
|
||||
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
|
||||
dpressure.loc[sensor_for_cos]
|
||||
)
|
||||
|
||||
if s2 == 0:
|
||||
s2 = s2 + 0.0001
|
||||
similarity_cos = s1 / s2
|
||||
else:
|
||||
none_flag = 0
|
||||
s1 = np.dot(
|
||||
np.transpose(dpressure_s.loc[sensor_for_cos]),
|
||||
dpressure.loc[sensor_for_cos],
|
||||
)
|
||||
s2 = np.linalg.norm(dpressure_s.loc[sensor_for_cos]) * np.linalg.norm(
|
||||
dpressure.loc[sensor_for_cos]
|
||||
)
|
||||
|
||||
if s2 == 0:
|
||||
s2 = s2 + 0.0001
|
||||
similarity_cos = s1 / s2
|
||||
|
||||
# DIS
|
||||
important_sensor_new = sorted(
|
||||
set(important_sensor).intersection(set(act_dpressure.index))
|
||||
)
|
||||
if len(important_sensor_new) == 0:
|
||||
important_sensor_new = important_sensor
|
||||
act_dpressure = pd.Series(dtype=object)
|
||||
for i in range(len(leak_p_all.index)):
|
||||
# if dpressure_std.iloc [i] > -200: # 0.001:
|
||||
if if_gy == 1:
|
||||
act_dpressure[leak_p_all.index[i]] = (
|
||||
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
|
||||
) / dpressure_std_all.iloc[i]
|
||||
else:
|
||||
act_dpressure[leak_p_all.index[i]] = (
|
||||
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
|
||||
)
|
||||
# part_dpressure = act_dpressure.loc[important_sensor_new]
|
||||
|
||||
part_dpressure = (
|
||||
dpressure.loc[important_sensor_new] - dpressure_s.loc[important_sensor_new]
|
||||
)
|
||||
similarity_pre_DIS = np.linalg.norm(part_dpressure) ## chang test
|
||||
# part_dpressure = dpressure_s.loc[important_sensor]-dpressure.loc[important_sensor]
|
||||
# similarity_pre_DIS = np.linalg.norm(part_dpressure)
|
||||
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
|
||||
similarity_dis = similarity_pre_DIS
|
||||
elif similarity_mode == "OF":
|
||||
# cos
|
||||
similarity_cos = 0
|
||||
none_flag = 0
|
||||
# DIS
|
||||
important_sensor_new = sorted(
|
||||
set(important_sensor).intersection(set(act_dpressure.index))
|
||||
)
|
||||
if len(important_sensor_new) == 0:
|
||||
important_sensor_new = important_sensor
|
||||
act_dpressure = pd.Series(dtype=object)
|
||||
for i in range(len(leak_p_all.index)):
|
||||
# if dpressure_std.iloc [i] > -200: # 0.001:
|
||||
if if_gy == 1:
|
||||
act_dpressure[leak_p_all.index[i]] = (
|
||||
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
|
||||
) / dpressure_std_all.iloc[i]
|
||||
else:
|
||||
act_dpressure[leak_p_all.index[i]] = (
|
||||
leak_p_all.iloc[i] - monitor_p_all.iloc[i]
|
||||
)
|
||||
# part_dpressure = act_dpressure.loc[important_sensor_new]
|
||||
|
||||
part_dpressure = (
|
||||
dpressure.loc[important_sensor_new] - dpressure_s.loc[important_sensor_new]
|
||||
)
|
||||
similarity_pre_DIS = np.linalg.norm(part_dpressure) ## chang test
|
||||
# part_dpressure = dpressure_s.loc[important_sensor]-dpressure.loc[important_sensor]
|
||||
# similarity_pre_DIS = np.linalg.norm(part_dpressure)
|
||||
# similarity_pre_DIS_later = 1 / (1 + similarity_pre_DIS)
|
||||
similarity_dis = similarity_pre_DIS
|
||||
|
||||
return similarity_cos, similarity_dis, none_flag
|
||||
|
||||
|
||||
def adjust(
|
||||
similarity_cos,
|
||||
similarity_dis,
|
||||
record_success_candidate,
|
||||
record_success_no_candidate,
|
||||
):
|
||||
if len(record_success_no_candidate) > 0:
|
||||
for each in record_success_no_candidate:
|
||||
similarity_cos[each] = similarity_cos[record_success_candidate].min() * 0.9
|
||||
similarity_dis[each] = similarity_dis[record_success_candidate].max() * 1.1
|
||||
return similarity_cos, similarity_dis
|
||||
|
||||
|
||||
def cal_sq_all_multi(
|
||||
similarity_cos,
|
||||
similarity_dis,
|
||||
similarity_f,
|
||||
candidate_pipe,
|
||||
timestep_list_spc,
|
||||
if_flow,
|
||||
if_only_cos,
|
||||
if_only_flow,
|
||||
cos_h_input,
|
||||
dis_h_input,
|
||||
dis_f_h_input,
|
||||
if_compalsive,
|
||||
cos_sensor_num,
|
||||
flow_sensor_num,
|
||||
):
|
||||
"""融合多种相似性并输出按时刻与候选管段组织的综合相似度。
|
||||
|
||||
该函数会根据模式开关(是否仅流量、是否仅 COS、是否包含流量)对
|
||||
`similarity_cos`、`similarity_dis`、`similarity_f` 做标准化,并计算
|
||||
权重 `sq_cos/sq_dis/sq_f` 后进行加权融合。
|
||||
|
||||
Args:
|
||||
similarity_cos: 压力余弦相似性(DataFrame/Series,通常为时刻 x 候选管段)。
|
||||
similarity_dis: 压力距离相似性(DataFrame/Series,通常为时刻 x 候选管段)。
|
||||
similarity_f: 流量距离相似性(DataFrame/Series,通常为时刻 x 候选管段)。
|
||||
candidate_pipe: 候选管段列表,用于输出列索引。
|
||||
timestep_list_spc: 时刻列表,用于输出行索引。
|
||||
if_flow: 是否启用流量相似性(1 启用,0 禁用)。
|
||||
if_only_cos: 相似性模式标识(0: COS+DIS;1: COS;其他值按分支定义处理)。
|
||||
if_only_flow: 是否仅使用流量相似性(1 是,0 否)。
|
||||
cos_h_input: 外部给定的 COS 权重(强制权重模式下使用)。
|
||||
dis_h_input: 外部给定的 DIS 权重(强制权重模式下使用)。
|
||||
dis_f_h_input: 外部给定的流量权重(强制权重模式下使用)。
|
||||
if_compalsive: 是否使用外部强制权重(1 使用输入权重,0 自动计算权重)。
|
||||
cos_sensor_num: 压力传感器数量,用于权重调整。
|
||||
flow_sensor_num: 流量传感器数量,用于权重调整。
|
||||
|
||||
Returns:
|
||||
tuple[pd.DataFrame | pd.Series, float, float, float]:
|
||||
- output_similarity_pd: 综合相似性结果。
|
||||
- sq_cos: 最终 COS 权重。
|
||||
- sq_dis: 最终 DIS 权重。
|
||||
- sq_f: 最终流量权重。
|
||||
"""
|
||||
if if_only_flow == 1:
|
||||
similarity_f, h_f = cal_sq_single_array(
|
||||
similarity_f.values.reshape((-1, 1)), if_direct=2
|
||||
)
|
||||
sq_cos = 0
|
||||
sq_dis = 0
|
||||
sq_f = 1
|
||||
similarity_all = similarity_f * sq_f
|
||||
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
|
||||
output_similarity_pd = pd.DataFrame(
|
||||
output_similarity, index=timestep_list_spc, columns=candidate_pipe
|
||||
)
|
||||
else:
|
||||
if if_only_cos == 0:
|
||||
if if_flow == 1:
|
||||
# standerdize
|
||||
similarity_cos, h_cos = cal_sq_single_array(
|
||||
similarity_cos.values.reshape((-1, 1)), if_direct=1
|
||||
)
|
||||
similarity_dis, h_dis = cal_sq_single_array(
|
||||
similarity_dis.values.reshape((-1, 1)), if_direct=2
|
||||
)
|
||||
similarity_f, h_f = cal_sq_single_array(
|
||||
similarity_f.values.reshape((-1, 1)), if_direct=2
|
||||
)
|
||||
if if_compalsive == 1:
|
||||
sq_cos = cos_h_input
|
||||
sq_dis = dis_h_input
|
||||
sq_f = dis_f_h_input
|
||||
else:
|
||||
"""sq_cos = h_cos/(h_cos +h_dis +h_f )
|
||||
sq_dis = h_dis/(h_cos +h_dis +h_f )
|
||||
sq_f = h_f/(h_cos +h_dis +h_f )"""
|
||||
sq_cos, sq_dis, sq_f = add_weight_for_SQ(
|
||||
h_cos, h_dis, h_f, cos_sensor_num, flow_sensor_num
|
||||
)
|
||||
|
||||
"""if cos_sensor_num == 2 and sq_cos>0.2:
|
||||
sq_cos = 0.2
|
||||
sq_dis = 0.8*h_dis / (h_dis + h_f)
|
||||
sq_f = 0.8*h_f / (h_dis + h_f)
|
||||
if cos_sensor_num == 1 and sq_dis > 0.3:
|
||||
sq_cos = 0.1
|
||||
sq_dis = 0.3
|
||||
sq_f = 0.6"""
|
||||
sq_cos, sq_dis, sq_f = adjust_ratio("CDF", sq_cos, sq_dis, sq_f)
|
||||
if cos_sensor_num <= 1:
|
||||
sq_cos = 0
|
||||
# similarity
|
||||
|
||||
similarity_all = (
|
||||
similarity_cos * sq_cos
|
||||
+ similarity_dis * sq_dis
|
||||
+ similarity_f * sq_f
|
||||
)
|
||||
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
|
||||
output_similarity_pd = pd.DataFrame(
|
||||
output_similarity, index=timestep_list_spc, columns=candidate_pipe
|
||||
)
|
||||
else:
|
||||
# standerdize
|
||||
similarity_cos, h_cos = cal_sq_single_array(
|
||||
similarity_cos.values.reshape((-1, 1)), if_direct=1
|
||||
)
|
||||
similarity_dis, h_dis = cal_sq_single_array(
|
||||
similarity_dis.values.reshape((-1, 1)), if_direct=2
|
||||
)
|
||||
if if_compalsive == 1:
|
||||
sq_cos = cos_h_input
|
||||
sq_dis = dis_h_input
|
||||
else:
|
||||
sq_cos = h_cos / (h_cos + h_dis)
|
||||
sq_dis = h_dis / (h_cos + h_dis)
|
||||
if cos_sensor_num == 2 and sq_cos > 0.5:
|
||||
sq_cos = 0.5
|
||||
sq_dis = 0.5
|
||||
sq_cos, sq_dis, sq_f = adjust_ratio("CAD_new_gy", sq_cos, sq_dis, 0)
|
||||
sq_f = 0
|
||||
# similarity
|
||||
similarity_all = similarity_cos * sq_cos + similarity_dis * sq_dis
|
||||
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
|
||||
output_similarity_pd = pd.DataFrame(
|
||||
output_similarity, index=timestep_list_spc, columns=candidate_pipe
|
||||
)
|
||||
elif if_only_cos == 1:
|
||||
if if_flow == 1:
|
||||
# standerdize
|
||||
similarity_cos, h_cos = cal_sq_single_array(
|
||||
similarity_cos.values.reshape((-1, 1)), if_direct=1
|
||||
)
|
||||
similarity_f, h_f = cal_sq_single_array(
|
||||
similarity_f.values.reshape((-1, 1)), if_direct=2
|
||||
)
|
||||
if if_compalsive == 1:
|
||||
sq_cos = cos_h_input
|
||||
sq_f = dis_f_h_input
|
||||
else:
|
||||
sq_cos = h_cos / (h_cos + h_f)
|
||||
sq_f = h_f / (h_cos + h_f)
|
||||
sq_cos, sq_dis, sq_f = adjust_ratio("CAF", sq_cos, 0, sq_f)
|
||||
sq_dis = 0
|
||||
# similarity
|
||||
similarity_all = similarity_cos * sq_cos + similarity_f * sq_f
|
||||
output_similarity = similarity_all.reshape((-1, len(candidate_pipe)))
|
||||
output_similarity_pd = pd.DataFrame(
|
||||
output_similarity, index=timestep_list_spc, columns=candidate_pipe
|
||||
)
|
||||
|
||||
else:
|
||||
sq_cos = cos_h_input
|
||||
sq_dis = dis_h_input
|
||||
sq_f = dis_f_h_input
|
||||
output_similarity_pd = similarity_cos
|
||||
else:
|
||||
sq_cos = cos_h_input
|
||||
sq_dis = dis_h_input
|
||||
sq_f = dis_f_h_input
|
||||
output_similarity_pd = 1 / (similarity_dis + 1)
|
||||
return output_similarity_pd, sq_cos, sq_dis, sq_f
|
||||
|
||||
|
||||
def add_weight_for_SQ(h_cos, h_dis, h_f, sensor_cos_num, sensor_f_num):
|
||||
h_f_new = h_f * sensor_f_num
|
||||
if sensor_cos_num <= 1:
|
||||
h_cos_new = 0
|
||||
h_dis_new = h_dis * sensor_cos_num
|
||||
else:
|
||||
h_cos_new = h_cos * sensor_cos_num # / 2
|
||||
h_dis_new = h_dis * sensor_cos_num # / 2
|
||||
cos_sq = h_cos_new / (h_cos_new + h_dis_new + h_f_new)
|
||||
dis_sq = h_dis_new / (h_cos_new + h_dis_new + h_f_new)
|
||||
f_sq = h_f_new / (h_cos_new + h_dis_new + h_f_new)
|
||||
|
||||
if sensor_cos_num == 2 and cos_sq > 0.2:
|
||||
cos_sq = 0.2
|
||||
dis_sq = 0.8 * h_dis_new / (h_dis_new + h_f_new)
|
||||
f_sq = 0.8 * h_f_new / (h_dis_new + h_f_new)
|
||||
"""if sensor_cos_num == 1:
|
||||
if dis_sq / f_sq > sensor_cos_num/sensor_f_num:
|
||||
dis_sq = sensor_cos_num/sensor_f_num
|
||||
f_sq=1-dis_sq"""
|
||||
# if h_dis_new/h_f_new > sensor_cos_num/sensor_f_num
|
||||
return cos_sq, dis_sq, f_sq
|
||||
|
||||
|
||||
def cal_sq_single_array(similarity_pre, if_direct):
|
||||
if similarity_pre.max() - similarity_pre.min() == 0:
|
||||
similarity_pre = np.ones(similarity_pre.shape)
|
||||
else:
|
||||
if if_direct == 1:
|
||||
similarity_pre = (
|
||||
0.998
|
||||
* (similarity_pre - similarity_pre.min())
|
||||
/ (similarity_pre.max() - similarity_pre.min())
|
||||
+ 0.002
|
||||
)
|
||||
else:
|
||||
similarity_pre = (
|
||||
0.998
|
||||
* (similarity_pre.max() - similarity_pre)
|
||||
/ (similarity_pre.max() - similarity_pre.min())
|
||||
+ 0.002
|
||||
)
|
||||
# calculate pij
|
||||
similarity_p = similarity_pre / similarity_pre.sum()
|
||||
# cal xinxishang
|
||||
similarity_lnp = np.zeros((len(similarity_pre), 1))
|
||||
for j in range(len(similarity_p)):
|
||||
similarity_lnp[j] = -similarity_p[j] * math.log(similarity_p[j], math.e)
|
||||
h = 1 - 1 / math.log(len(similarity_pre), math.e) * similarity_lnp.sum()
|
||||
return similarity_pre, h
|
||||
|
||||
|
||||
def cal_similarity_all_multi_new_sq_improve_double_lzr(
|
||||
candidate_pipe,
|
||||
similarity_mode,
|
||||
pressure_leak,
|
||||
monitor_p,
|
||||
predict_p,
|
||||
normal_p,
|
||||
if_flow,
|
||||
if_only_cos,
|
||||
if_only_flow,
|
||||
flow_leak,
|
||||
monitor_f,
|
||||
predict_f,
|
||||
normal_f,
|
||||
timestep_list,
|
||||
Top_sensor_num,
|
||||
if_gy,
|
||||
effective_sensor,
|
||||
cos_h,
|
||||
dis_h,
|
||||
dis_f_h,
|
||||
if_compalsive,
|
||||
max_flow,
|
||||
):
|
||||
similarity = pd.Series(dtype=float, index=candidate_pipe)
|
||||
similarity_detail: pd.DataFrame | None = None
|
||||
important_p_sensor = cal_top_sensors(monitor_p, predict_p, Top_sensor_num)
|
||||
# important_f_sensor, basic_f = cal_top_f_sensor(normal_f)
|
||||
important_f_sensor = monitor_f.columns
|
||||
if (
|
||||
len(important_p_sensor) > 0 or len(important_f_sensor) > 0
|
||||
): # if len(important_p_sensor) > 0
|
||||
break_flag = 0
|
||||
pressure_leak_new = pressure_leak.swaplevel()
|
||||
# flow_leak_new = flow_leak.swaplevel()
|
||||
if isinstance(flow_leak, pd.DataFrame) and len(flow_leak) > 0:
|
||||
flow_leak_new = flow_leak.swaplevel()
|
||||
else:
|
||||
flow_leak_new = None
|
||||
total_similarity_cos = pd.DataFrame(index=timestep_list, columns=candidate_pipe)
|
||||
total_similarity_dis = pd.DataFrame(index=timestep_list, columns=candidate_pipe)
|
||||
total_similarity_dis_f = pd.DataFrame(
|
||||
index=timestep_list, columns=candidate_pipe
|
||||
)
|
||||
|
||||
for timestep in timestep_list:
|
||||
# cal p_cos, p_dis, f_dis
|
||||
if if_only_flow != 1:
|
||||
pressure_leak_temp = pressure_leak_new.loc[timestep].loc[
|
||||
:, effective_sensor
|
||||
]
|
||||
monitor_p_temp = monitor_p.loc[timestep, effective_sensor]
|
||||
predict_p_temp = predict_p.loc[timestep, effective_sensor]
|
||||
normal_p_temp = normal_p.loc[timestep, effective_sensor]
|
||||
|
||||
(
|
||||
total_similarity_cos.loc[timestep, :],
|
||||
total_similarity_dis.loc[timestep, :],
|
||||
) = cal_similarity_all_cos_dis(
|
||||
candidate_pipe,
|
||||
pressure_leak_temp,
|
||||
similarity_mode,
|
||||
monitor_p_temp,
|
||||
predict_p_temp,
|
||||
normal_p_temp,
|
||||
pressure_leak_new.loc[timestep].loc[:, monitor_p.columns],
|
||||
monitor_p.loc[timestep, :],
|
||||
predict_p.loc[timestep, :],
|
||||
normal_p.loc[timestep, :],
|
||||
important_p_sensor,
|
||||
if_gy,
|
||||
cos_or_flow=1,
|
||||
)
|
||||
|
||||
if if_flow == 1:
|
||||
if len(timestep_list) == 1:
|
||||
leak_f_temp = flow_leak_new.loc[timestep].loc[:, important_f_sensor]
|
||||
monitor_f_temp = monitor_f.loc[timestep, important_f_sensor]
|
||||
predict_f_temp = predict_f.loc[timestep, important_f_sensor]
|
||||
normal_f_temp = normal_f.loc[timestep, important_f_sensor]
|
||||
basic_normal_f_temp = abs(max_flow.loc[important_f_sensor])
|
||||
|
||||
leak_f_temp = leak_f_temp / basic_normal_f_temp
|
||||
monitor_f_temp = monitor_f_temp / basic_normal_f_temp
|
||||
predict_f_temp = predict_f_temp / basic_normal_f_temp
|
||||
normal_f_temp = normal_f_temp / basic_normal_f_temp
|
||||
|
||||
else:
|
||||
basic_f = abs(max_flow.loc[important_f_sensor])
|
||||
leak_f_temp = (
|
||||
flow_leak_new.loc[timestep].loc[:, important_f_sensor] / basic_f
|
||||
)
|
||||
monitor_f_temp = (
|
||||
monitor_f.loc[timestep, important_f_sensor] / basic_f
|
||||
)
|
||||
predict_f_temp = (
|
||||
predict_f.loc[timestep, important_f_sensor] / basic_f
|
||||
)
|
||||
normal_f_temp = normal_f.loc[timestep, important_f_sensor] / basic_f
|
||||
_, total_similarity_dis_f.loc[timestep, :] = cal_similarity_all_cos_dis(
|
||||
candidate_pipe,
|
||||
leak_f_temp,
|
||||
similarity_mode,
|
||||
monitor_f_temp,
|
||||
predict_f_temp,
|
||||
normal_f_temp,
|
||||
flow_leak_new.loc[timestep].loc[:, monitor_f.columns],
|
||||
monitor_f.loc[timestep, :],
|
||||
predict_f.loc[timestep, :],
|
||||
normal_f.loc[timestep, :],
|
||||
important_f_sensor,
|
||||
if_gy,
|
||||
cos_or_flow=2,
|
||||
)
|
||||
else:
|
||||
total_similarity_dis_f = []
|
||||
similarity_all, cos_h, dis_h, dis_f_h = cal_sq_all_multi(
|
||||
total_similarity_cos,
|
||||
total_similarity_dis,
|
||||
total_similarity_dis_f,
|
||||
candidate_pipe,
|
||||
timestep_list,
|
||||
if_flow,
|
||||
if_only_cos,
|
||||
if_only_flow,
|
||||
cos_h,
|
||||
dis_h,
|
||||
dis_f_h,
|
||||
if_compalsive,
|
||||
len(important_p_sensor),
|
||||
len(important_f_sensor),
|
||||
)
|
||||
if len(timestep_list) == 1:
|
||||
similarity = similarity_all.iloc[0]
|
||||
elif len(timestep_list) > 3:
|
||||
for each_candidate in candidate_pipe:
|
||||
similarity[each_candidate] = remove_3_sigma(
|
||||
similarity_all.loc[:, each_candidate]
|
||||
)
|
||||
else:
|
||||
for each_candidate in candidate_pipe:
|
||||
similarity[each_candidate] = similarity_all.loc[
|
||||
:, each_candidate
|
||||
].mean()
|
||||
similarity = similarity.sort_values(ascending=False, kind="mergesort")
|
||||
detail_index = [str(pipe) for pipe in candidate_pipe]
|
||||
similarity_detail = pd.DataFrame(index=detail_index)
|
||||
similarity_detail.index.name = "pipe_id"
|
||||
if isinstance(total_similarity_cos, pd.DataFrame) and len(total_similarity_cos) > 0:
|
||||
pressure_cos_mean = (
|
||||
total_similarity_cos.mean(axis=0)
|
||||
.reindex(candidate_pipe)
|
||||
.to_numpy(dtype=float)
|
||||
)
|
||||
else:
|
||||
pressure_cos_mean = np.full(len(candidate_pipe), np.nan)
|
||||
if isinstance(total_similarity_dis, pd.DataFrame) and len(total_similarity_dis) > 0:
|
||||
pressure_dis_mean = (
|
||||
total_similarity_dis.mean(axis=0)
|
||||
.reindex(candidate_pipe)
|
||||
.to_numpy(dtype=float)
|
||||
)
|
||||
else:
|
||||
pressure_dis_mean = np.full(len(candidate_pipe), np.nan)
|
||||
if isinstance(total_similarity_dis_f, pd.DataFrame) and len(total_similarity_dis_f) > 0:
|
||||
flow_dis_mean = (
|
||||
total_similarity_dis_f.mean(axis=0)
|
||||
.reindex(candidate_pipe)
|
||||
.to_numpy(dtype=float)
|
||||
)
|
||||
else:
|
||||
flow_dis_mean = np.full(len(candidate_pipe), np.nan)
|
||||
similarity_detail["pressure_cos_mean"] = pressure_cos_mean
|
||||
similarity_detail["pressure_dis_mean"] = pressure_dis_mean
|
||||
similarity_detail["flow_dis_mean"] = flow_dis_mean
|
||||
similarity_detail["weight_cos"] = float(cos_h)
|
||||
similarity_detail["weight_dis"] = float(dis_h)
|
||||
similarity_detail["weight_flow"] = float(dis_f_h)
|
||||
similarity_detail["final_similarity"] = (
|
||||
similarity.reindex(candidate_pipe).to_numpy(dtype=float)
|
||||
)
|
||||
similarity_detail["similarity_rank"] = (
|
||||
similarity_detail["final_similarity"].rank(method="dense", ascending=False)
|
||||
).astype(int)
|
||||
similarity_detail["pressure_sensor_count"] = int(len(important_p_sensor))
|
||||
similarity_detail["flow_sensor_count"] = int(len(important_f_sensor))
|
||||
similarity_detail = similarity_detail.sort_values(
|
||||
by="final_similarity", ascending=False, kind="mergesort"
|
||||
)
|
||||
else:
|
||||
break_flag = 1
|
||||
similarity = 0
|
||||
cos_h = 0
|
||||
dis_h = 0
|
||||
dis_f_h = 0
|
||||
return similarity, cos_h, dis_h, dis_f_h, break_flag, similarity_detail
|
||||
|
||||
|
||||
def cal_similarity_all_cos_dis(
|
||||
candidate_pipe,
|
||||
pressure_leak,
|
||||
similarity_mode,
|
||||
monitor_p,
|
||||
predict_p,
|
||||
normal_p,
|
||||
pressure_leak_all,
|
||||
monitor_p_all,
|
||||
predict_p_all,
|
||||
normal_p_all,
|
||||
important_sensor,
|
||||
if_gy,
|
||||
cos_or_flow,
|
||||
):
|
||||
similarity_cos = pd.Series(dtype=float, index=candidate_pipe)
|
||||
similarity_dis = pd.Series(dtype=float, index=candidate_pipe)
|
||||
dpressure = normal_p - pressure_leak
|
||||
|
||||
# 无用 ----------------------------------------------
|
||||
mean_dpressure = dpressure.mean()
|
||||
|
||||
monitor_new = pd.DataFrame(index=["monitor"], columns=monitor_p.index)
|
||||
monitor_new.iloc[0] = monitor_p
|
||||
add_m_leak_pressure = [pressure_leak, monitor_p]
|
||||
add_m_leak_pressure = pd.concat(add_m_leak_pressure)
|
||||
pressure_leak_std = add_m_leak_pressure.std(axis=0, ddof=1)
|
||||
pressure_leak_std = pd.Series(pressure_leak_std, index=pressure_leak.columns)
|
||||
|
||||
add_m_leak_pressure_all = [pressure_leak_all, monitor_p_all]
|
||||
add_m_leak_pressure_all = pd.concat(add_m_leak_pressure_all)
|
||||
pressure_leak_std_all = add_m_leak_pressure_all.std(axis=0, ddof=1)
|
||||
pressure_leak_std_all = pd.Series(
|
||||
pressure_leak_std_all, index=pressure_leak.columns
|
||||
)
|
||||
# 无用 ----------------------------------------------
|
||||
|
||||
monitor_p_temp = monitor_p
|
||||
predict_p_temp = predict_p
|
||||
normal_p_temp = normal_p
|
||||
monitor_p_temp_all = monitor_p_all
|
||||
predict_p_temp_all = predict_p_all
|
||||
normal_p_temp_all = normal_p_all
|
||||
record_success_candidate = []
|
||||
record_success_no_candidate = []
|
||||
for i in range(len(candidate_pipe)):
|
||||
leak_p = pressure_leak.iloc[i, :]
|
||||
leak_p_all = pressure_leak_all.iloc[i, :]
|
||||
similarity_cos.iloc[i], similarity_dis.iloc[i], none_flag = (
|
||||
cal_similarity_simple_return_dd(
|
||||
similarity_mode,
|
||||
monitor_p_temp,
|
||||
predict_p_temp,
|
||||
normal_p_temp,
|
||||
leak_p,
|
||||
monitor_p_temp_all,
|
||||
predict_p_temp_all,
|
||||
normal_p_temp_all,
|
||||
leak_p_all,
|
||||
important_sensor,
|
||||
mean_dpressure,
|
||||
pressure_leak_std,
|
||||
pressure_leak_std_all,
|
||||
if_gy,
|
||||
cos_or_flow,
|
||||
)
|
||||
)
|
||||
if none_flag == 0:
|
||||
record_success_candidate.append(candidate_pipe[i])
|
||||
else:
|
||||
record_success_no_candidate.append(candidate_pipe[i])
|
||||
similarity_cos, similarity_dis = adjust(
|
||||
similarity_cos,
|
||||
similarity_dis,
|
||||
record_success_candidate,
|
||||
record_success_no_candidate,
|
||||
)
|
||||
return similarity_cos, similarity_dis
|
||||
|
||||
|
||||
def cal_top_f_sensor(normal_f):
|
||||
if type(normal_f) == pd.core.frame.DataFrame:
|
||||
mean_f = normal_f.mean()
|
||||
else:
|
||||
mean_f = normal_f
|
||||
output_sensor = []
|
||||
output_normal_f = pd.Series(dtype=object)
|
||||
for i in range(len(mean_f.index)):
|
||||
if abs(mean_f.iloc[i]) > 0.01 / 3600:
|
||||
output_sensor.append(mean_f.index[i])
|
||||
output_normal_f[mean_f.index[i]] = mean_f.iloc[i]
|
||||
return output_sensor, output_normal_f
|
||||
|
||||
|
||||
def cal_top_sensors(monitor_p, predict_p, Top_sensor_num):
|
||||
dpressure = abs(predict_p - monitor_p)
|
||||
if type(dpressure) == pd.core.frame.DataFrame:
|
||||
dpressure = dpressure.mean()
|
||||
dpressure_rank = dpressure.sort_values(ascending=False, kind="mergesort")
|
||||
return list(dpressure_rank.index[:Top_sensor_num])
|
||||
|
||||
|
||||
def remove_3_sigma(similarity_t):
|
||||
all_sample = len(similarity_t.index)
|
||||
apart_sample = math.ceil(all_sample * 0.6)
|
||||
similarity = similarity_t.astype("float")
|
||||
mean_t = similarity.mean()
|
||||
std_t = similarity.std()
|
||||
new_similarity = similarity[
|
||||
(similarity <= mean_t + 3 * std_t) & (similarity >= mean_t - 3 * std_t)
|
||||
]
|
||||
mean_t_new = new_similarity.mean()
|
||||
return mean_t_new
|
||||
|
||||
|
||||
def update_similarity(leak_candidate_center, similarity, leak_center_dict):
|
||||
similarity_new = pd.Series(dtype=float)
|
||||
for each_center in leak_candidate_center:
|
||||
houxuan_center = leak_center_dict[each_center]
|
||||
if len(houxuan_center) > 1:
|
||||
temp_similarity = similarity[houxuan_center]
|
||||
similarity_new[each_center] = temp_similarity.max()
|
||||
else:
|
||||
if type(similarity[each_center]) == pd.core.series.Series:
|
||||
similarity_new[each_center] = similarity[each_center].mean()
|
||||
else:
|
||||
similarity_new[each_center] = similarity[each_center]
|
||||
similarity_new = similarity_new.sort_values(ascending=False, kind="mergesort")
|
||||
return similarity_new
|
||||
|
||||
|
||||
def extra_judge(
|
||||
similarity, min_candidates_to_prune: int = 200, std_relax_factor: float = 0.5
|
||||
):
|
||||
if len(similarity.index) == 0:
|
||||
return 1.0, similarity
|
||||
if len(similarity.index) < int(min_candidates_to_prune):
|
||||
return 1.0, similarity
|
||||
|
||||
mean_similarity = float(similarity.mean())
|
||||
std_similarity = float(similarity.std())
|
||||
if not math.isfinite(std_similarity):
|
||||
std_similarity = 0.0
|
||||
|
||||
threshold = mean_similarity - float(std_relax_factor) * std_similarity
|
||||
out_put_similarity = similarity[similarity >= threshold - 1e-10]
|
||||
if len(out_put_similarity.index) == 0:
|
||||
out_put_similarity = similarity.iloc[:1]
|
||||
|
||||
cut_ratio = len(out_put_similarity.index) / len(similarity.index)
|
||||
return cut_ratio, out_put_similarity
|
||||
|
||||
|
||||
def adjust_ratio(similarity_mode, cos_h, dis_h, dis_f_h, low_limit=0.1):
|
||||
if similarity_mode == "CAF":
|
||||
if cos_h < low_limit:
|
||||
cos_h = low_limit
|
||||
dis_f_h = 1 - cos_h
|
||||
elif dis_f_h < low_limit:
|
||||
dis_f_h = low_limit
|
||||
cos_h = 1 - dis_f_h
|
||||
elif similarity_mode == "CAD_new_gy":
|
||||
if dis_h < low_limit:
|
||||
dis_h = low_limit
|
||||
cos_h = 1 - dis_h
|
||||
elif cos_h < low_limit:
|
||||
cos_h = low_limit
|
||||
dis_h = 1 - cos_h
|
||||
elif similarity_mode == "CDF":
|
||||
normal_index = [0, 1, 2]
|
||||
h_list = [cos_h, dis_h, dis_f_h]
|
||||
if cos_h < low_limit:
|
||||
h_list[0] = low_limit
|
||||
normal_index.remove(0)
|
||||
if dis_h < low_limit:
|
||||
h_list[1] = low_limit
|
||||
normal_index.remove(1)
|
||||
if dis_f_h < low_limit:
|
||||
h_list[2] = low_limit
|
||||
normal_index.remove(2)
|
||||
|
||||
if len(normal_index) == 1:
|
||||
h_list[normal_index[0]] = h_list[normal_index[0]] - (sum(h_list) - 1)
|
||||
elif len(normal_index) == 2:
|
||||
sum_list = sum(h_list)
|
||||
multiper = 1 - (sum_list - 1) / (
|
||||
h_list[normal_index[0]] + h_list[normal_index[1]]
|
||||
)
|
||||
h_list[normal_index[0]] = h_list[normal_index[0]] * multiper
|
||||
h_list[normal_index[1]] = h_list[normal_index[1]] * multiper
|
||||
|
||||
cos_h, dis_h, dis_f_h = h_list[0], h_list[1], h_list[2]
|
||||
|
||||
return cos_h, dis_h, dis_f_h
|
||||
|
||||
|
||||
# 返回相似性计算的模式(不同权重),是否计算流量相似性,是否只计算cos相似性,是否只计算流量相似性。
|
||||
def decode_mode(similarity_mode):
|
||||
if similarity_mode == "COS":
|
||||
if_flow = 0
|
||||
if_only_cos = 1
|
||||
if_only_flow = 0
|
||||
elif similarity_mode == "CAD_new_gy":
|
||||
if_flow = 0
|
||||
if_only_cos = 0
|
||||
if_only_flow = 0
|
||||
elif similarity_mode == "CDF":
|
||||
if_flow = 1
|
||||
if_only_cos = 0
|
||||
if_only_flow = 0
|
||||
elif similarity_mode == "CAF":
|
||||
if_flow = 1
|
||||
if_only_cos = 1
|
||||
if_only_flow = 0
|
||||
elif similarity_mode == "DIS":
|
||||
if_flow = 1
|
||||
if_only_cos = 2
|
||||
if_only_flow = 0
|
||||
elif similarity_mode == "OF":
|
||||
if_flow = 1
|
||||
if_only_cos = 0
|
||||
if_only_flow = 1
|
||||
return if_flow, if_only_cos, if_only_flow
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
|
||||
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
|
||||
from app.algorithms.cleaning import flow as _flow_module
|
||||
from app.algorithms.cleaning import pressure as _pressure_module
|
||||
|
||||
|
||||
############################################################
|
||||
@@ -19,14 +19,15 @@ def flow_data_clean(input_csv_file: str) -> str:
|
||||
"""
|
||||
|
||||
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 使用 algorithms 根目录保持与原 data_cleaning.py 一致的行为
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
input_csv_path = os.path.join(script_dir, input_csv_file)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(input_csv_path):
|
||||
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
|
||||
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
|
||||
out_xlsx_path = flow_data_clean.clean_flow_data_kf(input_csv_path)
|
||||
# 调用 clean_flow_data_kf 函数进行数据清洗
|
||||
out_xlsx_path = _flow_module.clean_flow_data_kf(input_csv_path)
|
||||
print("清洗后的数据已保存到:", out_xlsx_path)
|
||||
|
||||
|
||||
@@ -46,12 +47,13 @@ def pressure_data_clean(input_csv_file: str) -> str:
|
||||
"""
|
||||
|
||||
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# 使用 algorithms 根目录保持与原 data_cleaning.py 一致的行为
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
input_csv_path = os.path.join(script_dir, input_csv_file)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(input_csv_path):
|
||||
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
|
||||
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
|
||||
out_xlsx_path = pressure_data_clean.clean_pressure_data_km(input_csv_path)
|
||||
# 调用 clean_pressure_data_km 函数进行数据清洗
|
||||
out_xlsx_path = _pressure_module.clean_pressure_data_km(input_csv_path)
|
||||
print("清洗后的数据已保存到:", out_xlsx_path)
|
||||
@@ -1,83 +1,10 @@
|
||||
# ...existing code...
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pykalman import KalmanFilter
|
||||
import os
|
||||
|
||||
|
||||
def fill_time_gaps(
|
||||
data: pd.DataFrame,
|
||||
time_col: str = "time",
|
||||
freq: str = "1min",
|
||||
short_gap_threshold: int = 10,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
补齐缺失时间戳并填补数据缺口。
|
||||
|
||||
Args:
|
||||
data: 包含时间列的 DataFrame
|
||||
time_col: 时间列名(默认 'time')
|
||||
freq: 重采样频率(默认 '1min')
|
||||
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
|
||||
|
||||
Returns:
|
||||
补齐时间后的 DataFrame(保留原时间列格式)
|
||||
"""
|
||||
if time_col not in data.columns:
|
||||
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
|
||||
|
||||
# 解析时间列并设为索引
|
||||
data = data.copy()
|
||||
data[time_col] = pd.to_datetime(data[time_col], utc=True)
|
||||
data_indexed = data.set_index(time_col)
|
||||
|
||||
# 生成完整时间范围
|
||||
full_range = pd.date_range(
|
||||
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
|
||||
)
|
||||
|
||||
# 重索引以补齐缺失时间点,同时保留原始时间戳
|
||||
combined_index = data_indexed.index.union(full_range).sort_values().unique()
|
||||
data_reindexed = data_indexed.reindex(combined_index)
|
||||
|
||||
# 按列处理缺口
|
||||
for col in data_reindexed.columns:
|
||||
# 识别缺失值位置
|
||||
is_missing = data_reindexed[col].isna()
|
||||
|
||||
# 计算连续缺失的长度
|
||||
missing_groups = (is_missing != is_missing.shift()).cumsum()
|
||||
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
|
||||
|
||||
# 短缺口:时间插值
|
||||
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
|
||||
if short_gap_mask.any():
|
||||
data_reindexed.loc[short_gap_mask, col] = (
|
||||
data_reindexed[col]
|
||||
.interpolate(method="time", limit_area="inside")
|
||||
.loc[short_gap_mask]
|
||||
)
|
||||
|
||||
# 长缺口:前向填充
|
||||
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
|
||||
if long_gap_mask.any():
|
||||
data_reindexed.loc[long_gap_mask, col] = (
|
||||
data_reindexed[col].ffill().loc[long_gap_mask]
|
||||
)
|
||||
|
||||
# 重置索引并恢复时间列(保留原格式)
|
||||
data_result = data_reindexed.reset_index()
|
||||
data_result.rename(columns={"index": time_col}, inplace=True)
|
||||
|
||||
# 保留时区信息
|
||||
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
# 修正时区格式(Python的%z输出为+0000,需转为+00:00)
|
||||
data_result[time_col] = data_result[time_col].str.replace(
|
||||
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
|
||||
)
|
||||
|
||||
return data_result
|
||||
from app.algorithms._utils import fill_time_gaps
|
||||
|
||||
|
||||
def clean_flow_data_kf(
|
||||
@@ -0,0 +1,579 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
ID_LIKE_COLUMNS = {
|
||||
"id",
|
||||
"device_id",
|
||||
"node_id",
|
||||
"sensor_id",
|
||||
"monitor_id",
|
||||
"junction_id",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_time_frame(data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""返回按时间排序的副本,并尽量将 time 列解析为时间类型。"""
|
||||
data = data.copy()
|
||||
if "time" in data.columns:
|
||||
data["time"] = pd.to_datetime(data["time"], errors="coerce")
|
||||
data = data.sort_values(["time"]).reset_index(drop=True)
|
||||
return data
|
||||
|
||||
|
||||
def _select_pressure_columns(data: pd.DataFrame) -> tuple[list[str], list[str]]:
|
||||
"""区分需要清洗的数值列与需要原样保留的列。"""
|
||||
value_cols: list[str] = []
|
||||
keep_cols: list[str] = []
|
||||
for col in data.columns:
|
||||
if col == "time":
|
||||
continue
|
||||
col_key = col.lower()
|
||||
if col_key in ID_LIKE_COLUMNS or col_key.endswith("_id"):
|
||||
keep_cols.append(col)
|
||||
continue
|
||||
numeric = pd.to_numeric(data[col], errors="coerce")
|
||||
if numeric.notna().sum() == 0 or numeric.nunique(dropna=True) <= 1:
|
||||
keep_cols.append(col)
|
||||
else:
|
||||
value_cols.append(col)
|
||||
return value_cols, keep_cols
|
||||
|
||||
|
||||
def _robust_scale(values: pd.Series) -> float:
|
||||
"""基于 MAD 计算稳健尺度。"""
|
||||
series = pd.to_numeric(values, errors="coerce").dropna()
|
||||
if series.empty:
|
||||
return 1.0
|
||||
median = series.median()
|
||||
mad = (series - median).abs().median()
|
||||
if pd.notna(mad) and mad > 0:
|
||||
return float(1.4826 * mad)
|
||||
iqr = series.quantile(0.75) - series.quantile(0.25)
|
||||
if pd.notna(iqr) and iqr > 0:
|
||||
return float(iqr / 1.349)
|
||||
std = series.std()
|
||||
if pd.notna(std) and std > 0:
|
||||
return float(std)
|
||||
return 1.0
|
||||
|
||||
|
||||
def _shrink_toward_baseline(observed: float, baseline: float, scale: float) -> float:
|
||||
"""把观测值向基线值收缩,scale 越小,修复越强。"""
|
||||
if pd.isna(observed):
|
||||
return baseline
|
||||
if pd.isna(baseline):
|
||||
return observed
|
||||
diff = observed - baseline
|
||||
weight = scale / (abs(diff) + scale)
|
||||
return float(baseline + diff * weight)
|
||||
|
||||
|
||||
def _infer_time_frequency(time_values: pd.Series | pd.Index) -> pd.Timedelta:
|
||||
"""从时间序列中推断采样频率,失败时默认 15 分钟。"""
|
||||
parsed = pd.to_datetime(pd.Series(time_values), errors="coerce").dropna().sort_values()
|
||||
if len(parsed) < 2:
|
||||
return pd.Timedelta(minutes=15)
|
||||
|
||||
diffs = parsed.diff().dropna()
|
||||
diffs = diffs[diffs > pd.Timedelta(0)]
|
||||
if diffs.empty:
|
||||
return pd.Timedelta(minutes=15)
|
||||
|
||||
mode = diffs.mode()
|
||||
return mode.iloc[0] if not mode.empty else diffs.median()
|
||||
|
||||
|
||||
def _build_local_pressure_baseline(series: pd.Series) -> pd.Series:
|
||||
"""基于局部插值与中值滤波构造平滑基线。"""
|
||||
baseline = _safe_time_interpolate(series)
|
||||
baseline = baseline.rolling(window=5, center=True, min_periods=1).median()
|
||||
baseline = _safe_time_interpolate(baseline)
|
||||
return baseline.ffill().bfill()
|
||||
|
||||
|
||||
def _build_seasonal_pressure_baseline(series: pd.Series) -> pd.Series:
|
||||
"""按一天内的同一时刻构造季节性基线,适合日周期压力数据。"""
|
||||
if not isinstance(series.index, pd.DatetimeIndex):
|
||||
return pd.Series(np.nan, index=series.index, dtype=float)
|
||||
|
||||
slot_labels = pd.Series(series.index.strftime("%H:%M:%S"), index=series.index)
|
||||
return series.groupby(slot_labels).transform("median")
|
||||
|
||||
|
||||
def _detect_pressure_spikes(series: pd.Series, local_baseline: pd.Series) -> pd.Series:
|
||||
"""识别单点异常上升/下降尖峰,避免过度修正正常波动。"""
|
||||
residual = series - local_baseline
|
||||
neighbor_center = (series.shift(1) + series.shift(-1)) / 2
|
||||
curvature = series - neighbor_center
|
||||
|
||||
residual_scale = max(_robust_scale(residual), 1e-6)
|
||||
curvature_scale = max(_robust_scale(curvature), 1e-6)
|
||||
direction_flip = ((series - series.shift(1)) * (series.shift(-1) - series) < 0).fillna(False)
|
||||
|
||||
return (
|
||||
residual.abs() > 3.5 * residual_scale
|
||||
) & (
|
||||
curvature.abs() > 3.0 * curvature_scale
|
||||
) & direction_flip
|
||||
|
||||
|
||||
def _fill_pressure_gaps(
|
||||
original: pd.Series,
|
||||
repaired: pd.Series,
|
||||
local_baseline: pd.Series,
|
||||
seasonal_baseline: pd.Series,
|
||||
) -> pd.Series:
|
||||
"""短缺口用局部插值,长缺口优先使用同一时刻的季节性轨迹。"""
|
||||
missing_mask = original.isna()
|
||||
if not missing_mask.any():
|
||||
return repaired
|
||||
|
||||
gap_groups = (missing_mask != missing_mask.shift(fill_value=False)).cumsum()
|
||||
gap_lengths = missing_mask.groupby(gap_groups).transform("sum").where(missing_mask, 0)
|
||||
|
||||
filled = repaired.copy()
|
||||
short_gap_mask = missing_mask & (gap_lengths < 4)
|
||||
long_gap_mask = missing_mask & ~short_gap_mask
|
||||
|
||||
filled[short_gap_mask] = local_baseline[short_gap_mask]
|
||||
long_gap_fill = seasonal_baseline.where(seasonal_baseline.notna(), local_baseline)
|
||||
filled[long_gap_mask] = long_gap_fill[long_gap_mask]
|
||||
return filled
|
||||
|
||||
|
||||
def _clean_pressure_series(series: pd.Series) -> pd.Series:
|
||||
"""清洗单个压力时间序列。"""
|
||||
series = pd.to_numeric(series, errors="coerce").astype(float)
|
||||
local_baseline = _build_local_pressure_baseline(series)
|
||||
spike_mask = _detect_pressure_spikes(series, local_baseline)
|
||||
|
||||
repaired = series.copy()
|
||||
repaired[spike_mask] = local_baseline[spike_mask]
|
||||
|
||||
seasonal_baseline = _build_seasonal_pressure_baseline(repaired)
|
||||
repaired = _fill_pressure_gaps(series, repaired, local_baseline, seasonal_baseline)
|
||||
|
||||
if repaired.isna().any():
|
||||
repaired = repaired.where(repaired.notna(), local_baseline)
|
||||
return repaired.ffill().bfill()
|
||||
|
||||
|
||||
def _format_time_column(data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""统一输出时间格式,方便下游直接按 ISO 字符串解析。"""
|
||||
if "time" not in data.columns:
|
||||
return data
|
||||
|
||||
formatted = data.copy()
|
||||
time_values = pd.to_datetime(formatted["time"], errors="coerce")
|
||||
if time_values.isna().all():
|
||||
return formatted
|
||||
|
||||
if time_values.dt.tz is not None:
|
||||
time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
time_strings = time_strings.str.replace(
|
||||
r"([+-]\d{2})(\d{2})$",
|
||||
r"\1:\2",
|
||||
regex=True,
|
||||
)
|
||||
else:
|
||||
time_strings = time_values.dt.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
|
||||
formatted["time"] = time_strings.where(time_values.notna(), formatted["time"])
|
||||
return formatted
|
||||
|
||||
|
||||
def _expand_snapshot_time_grid(data: pd.DataFrame, freq: pd.Timedelta) -> pd.DataFrame:
|
||||
"""仅补齐时间轴,不提前填充值,避免长缺口丢失原始形状特征。"""
|
||||
expanded = data.copy()
|
||||
expanded["time"] = pd.to_datetime(expanded["time"], errors="coerce")
|
||||
expanded = expanded.dropna(subset=["time"]).sort_values("time")
|
||||
if expanded.empty:
|
||||
return data
|
||||
|
||||
indexed = expanded.set_index("time")
|
||||
full_index = pd.date_range(indexed.index.min(), indexed.index.max(), freq=freq)
|
||||
indexed = indexed.reindex(full_index)
|
||||
indexed.index.name = "time"
|
||||
return indexed.reset_index()
|
||||
|
||||
|
||||
def _safe_datetime_index(values: pd.Series | pd.Index | list[object]) -> pd.DatetimeIndex | None:
|
||||
"""尽量把时间值标准化为 DatetimeIndex;失败则返回 None。"""
|
||||
parsed = pd.to_datetime(values, errors="coerce")
|
||||
try:
|
||||
datetime_index = pd.DatetimeIndex(parsed)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if datetime_index.isna().all():
|
||||
return None
|
||||
return datetime_index
|
||||
|
||||
|
||||
def _safe_time_interpolate(series: pd.Series) -> pd.Series:
|
||||
"""仅在索引确实是 DatetimeIndex 时使用 time interpolation。"""
|
||||
if isinstance(series.index, pd.DatetimeIndex):
|
||||
return series.interpolate(method="time", limit_direction="both")
|
||||
return series.interpolate(limit_direction="both")
|
||||
|
||||
|
||||
def _detect_long_form_identifier(data: pd.DataFrame, value_cols: list[str], keep_cols: list[str]) -> str | None:
|
||||
"""识别 time/id/value 长表结构。"""
|
||||
if "time" not in data.columns or len(value_cols) != 1:
|
||||
return None
|
||||
|
||||
identifier_candidates = [
|
||||
col
|
||||
for col in keep_cols
|
||||
if col.lower() in ID_LIKE_COLUMNS or col.lower().endswith("_id")
|
||||
]
|
||||
if len(identifier_candidates) != 1:
|
||||
return None
|
||||
if not data["time"].duplicated().any():
|
||||
return None
|
||||
return identifier_candidates[0]
|
||||
|
||||
|
||||
def _clean_long_form_pressure(
|
||||
data: pd.DataFrame,
|
||||
value_col: str,
|
||||
identifier_col: str,
|
||||
keep_cols: list[str],
|
||||
fill_gaps: bool,
|
||||
) -> pd.DataFrame:
|
||||
"""按测点拆分 long-form 压力数据,再逐列清洗后恢复原结构。"""
|
||||
data = _normalize_time_frame(data)
|
||||
wide_df = (
|
||||
data[[identifier_col, "time", value_col]]
|
||||
.pivot(index="time", columns=identifier_col, values=value_col)
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
sensor_cols = [col for col in wide_df.columns if col != "time"]
|
||||
cleaned_wide = _clean_snapshot_pressure(wide_df, sensor_cols, keep_cols=[], fill_gaps=fill_gaps)
|
||||
|
||||
cleaned_long = cleaned_wide.melt(
|
||||
id_vars="time",
|
||||
var_name=identifier_col,
|
||||
value_name=value_col,
|
||||
)
|
||||
|
||||
passthrough_cols = [col for col in keep_cols if col != identifier_col]
|
||||
if passthrough_cols:
|
||||
metadata = data[[identifier_col] + passthrough_cols].drop_duplicates(subset=[identifier_col])
|
||||
cleaned_long = cleaned_long.merge(metadata, on=identifier_col, how="left")
|
||||
|
||||
try:
|
||||
cleaned_long[identifier_col] = cleaned_long[identifier_col].astype(data[identifier_col].dtype)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
cleaned_long = cleaned_long.sort_values(["time", identifier_col]).reset_index(drop=True)
|
||||
ordered_cols = ["time", identifier_col] + passthrough_cols + [value_col]
|
||||
cleaned_long = cleaned_long[[col for col in ordered_cols if col in cleaned_long.columns]]
|
||||
return cleaned_long
|
||||
|
||||
|
||||
def _build_time_slot_frame(
|
||||
data: pd.DataFrame, value_col: str, expected_slots: int
|
||||
) -> pd.DataFrame:
|
||||
"""把重复时间点整理成 time x slot 的矩阵。"""
|
||||
grouped = data.groupby("time", sort=True)
|
||||
times = list(grouped.groups.keys())
|
||||
slot_frame = pd.DataFrame(index=pd.Index(times, name="time"), columns=range(expected_slots), dtype=float)
|
||||
|
||||
for time_value, group in grouped:
|
||||
values = pd.to_numeric(group[value_col], errors="coerce").tolist()
|
||||
for slot_idx, value in enumerate(values[:expected_slots]):
|
||||
slot_frame.loc[time_value, slot_idx] = value
|
||||
return slot_frame
|
||||
|
||||
|
||||
def _slot_baseline(slot_frame: pd.DataFrame) -> pd.DataFrame:
|
||||
"""对每个槽位做时间插值和平滑,得到基线轨迹。"""
|
||||
baseline = pd.DataFrame(index=slot_frame.index, columns=slot_frame.columns, dtype=float)
|
||||
for col in slot_frame.columns:
|
||||
series = slot_frame[col].astype(float)
|
||||
series = _safe_time_interpolate(series)
|
||||
series = series.rolling(window=5, center=True, min_periods=1).median()
|
||||
series = _safe_time_interpolate(series).ffill().bfill()
|
||||
baseline[col] = series
|
||||
return baseline
|
||||
|
||||
|
||||
def _choose_insertion_position(
|
||||
observed: list[float], baseline_row: pd.Series, expected_slots: int
|
||||
) -> int:
|
||||
"""为少一个观测值的时间组选择最合理的插入位置。"""
|
||||
missing_count = expected_slots - len(observed)
|
||||
if missing_count <= 0:
|
||||
return 0
|
||||
|
||||
best_pos = 0
|
||||
best_cost = float("inf")
|
||||
for insert_pos in range(expected_slots):
|
||||
cost = 0.0
|
||||
obs_idx = 0
|
||||
for slot_idx in range(expected_slots):
|
||||
if slot_idx == insert_pos:
|
||||
continue
|
||||
obs_value = observed[obs_idx]
|
||||
base_value = float(baseline_row.iloc[slot_idx])
|
||||
if pd.notna(obs_value) and pd.notna(base_value):
|
||||
cost += abs(obs_value - base_value)
|
||||
obs_idx += 1
|
||||
if cost < best_cost:
|
||||
best_cost = cost
|
||||
best_pos = insert_pos
|
||||
return best_pos
|
||||
|
||||
|
||||
def _clean_repeated_timestamp_pressure(
|
||||
data: pd.DataFrame, value_col: str, keep_cols: list[str]
|
||||
) -> pd.DataFrame:
|
||||
"""针对同一时间点重复采样的压力数据进行修复。"""
|
||||
data = _normalize_time_frame(data)
|
||||
grouped_sizes = data.groupby("time").size()
|
||||
if grouped_sizes.empty:
|
||||
return data
|
||||
|
||||
expected_slots = int(grouped_sizes.mode().iloc[0]) if not grouped_sizes.mode().empty else int(grouped_sizes.max())
|
||||
expected_slots = max(expected_slots, int(grouped_sizes.max()))
|
||||
slot_frame = _build_time_slot_frame(data, value_col, expected_slots)
|
||||
baseline_frame = _slot_baseline(slot_frame)
|
||||
|
||||
residuals = slot_frame - baseline_frame
|
||||
slot_scales = {
|
||||
col: max(_robust_scale(residuals[col]), 1e-6) for col in residuals.columns
|
||||
}
|
||||
|
||||
cleaned_rows: list[dict[str, object]] = []
|
||||
grouped = data.groupby("time", sort=True)
|
||||
for time_value, group in grouped:
|
||||
observed_values = pd.to_numeric(group[value_col], errors="coerce").tolist()
|
||||
baseline_row = baseline_frame.loc[time_value]
|
||||
insert_pos = _choose_insertion_position(observed_values, baseline_row, expected_slots)
|
||||
|
||||
cleaned_values: list[float] = []
|
||||
obs_idx = 0
|
||||
for slot_idx in range(expected_slots):
|
||||
if slot_idx == insert_pos and len(observed_values) < expected_slots:
|
||||
cleaned_values.append(float(baseline_row.iloc[slot_idx]))
|
||||
continue
|
||||
|
||||
if obs_idx >= len(observed_values):
|
||||
cleaned_values.append(float(baseline_row.iloc[slot_idx]))
|
||||
continue
|
||||
|
||||
observed = observed_values[obs_idx]
|
||||
baseline = float(baseline_row.iloc[slot_idx])
|
||||
cleaned_values.append(
|
||||
_shrink_toward_baseline(observed, baseline, slot_scales.get(slot_idx, 1.0))
|
||||
)
|
||||
obs_idx += 1
|
||||
|
||||
# 其余字段原样保留;常量列(如 id)直接复制第一条记录即可
|
||||
template_row = group.iloc[0].to_dict()
|
||||
for slot_idx, cleaned_value in enumerate(cleaned_values):
|
||||
row = dict(template_row)
|
||||
row["time"] = time_value
|
||||
row[value_col] = cleaned_value
|
||||
cleaned_rows.append(row)
|
||||
|
||||
cleaned_df = pd.DataFrame(cleaned_rows)
|
||||
cleaned_df = cleaned_df.sort_values(["time"]).reset_index(drop=True)
|
||||
ordered_cols = ["time"] + keep_cols + [value_col]
|
||||
ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns]
|
||||
remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols]
|
||||
cleaned_df = cleaned_df[ordered_cols + remaining_cols]
|
||||
return _format_time_column(cleaned_df)
|
||||
|
||||
|
||||
def _clean_snapshot_pressure(
|
||||
data: pd.DataFrame, value_cols: list[str], keep_cols: list[str], fill_gaps: bool
|
||||
) -> pd.DataFrame:
|
||||
"""针对单条时间序列或多列快照数据进行稳健修复。"""
|
||||
data = _normalize_time_frame(data)
|
||||
if fill_gaps and "time" in data.columns:
|
||||
freq = _infer_time_frequency(data["time"])
|
||||
data = _expand_snapshot_time_grid(data, freq)
|
||||
data["time"] = pd.to_datetime(data["time"], errors="coerce")
|
||||
data = data.sort_values(["time"]).reset_index(drop=True)
|
||||
|
||||
cleaned_df = data.copy()
|
||||
time_index = (
|
||||
_safe_datetime_index(cleaned_df["time"])
|
||||
if "time" in cleaned_df.columns
|
||||
else None
|
||||
)
|
||||
if time_index is None:
|
||||
time_index = pd.RangeIndex(start=0, stop=len(cleaned_df))
|
||||
for col in value_cols:
|
||||
series = pd.Series(
|
||||
pd.to_numeric(cleaned_df[col], errors="coerce").to_numpy(),
|
||||
index=time_index,
|
||||
dtype=float,
|
||||
)
|
||||
cleaned_df[col] = _clean_pressure_series(series).to_numpy()
|
||||
|
||||
ordered_cols = ["time"] + keep_cols + value_cols
|
||||
ordered_cols = [col for col in ordered_cols if col in cleaned_df.columns]
|
||||
remaining_cols = [col for col in cleaned_df.columns if col not in ordered_cols]
|
||||
cleaned_df = cleaned_df[ordered_cols + remaining_cols]
|
||||
return _format_time_column(cleaned_df)
|
||||
|
||||
|
||||
def clean_pressure_data_km(
|
||||
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
读取输入 CSV,基于时间结构进行稳健修复。输出为 <input_basename>_cleaned.xlsx(同目录)。
|
||||
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'。
|
||||
返回输出文件的绝对路径。
|
||||
|
||||
Args:
|
||||
input_csv_path: CSV 文件路径
|
||||
show_plot: 是否显示可视化
|
||||
fill_gaps: 是否先补齐时间缺口(默认 True)
|
||||
"""
|
||||
# 读取 CSV
|
||||
input_csv_path = os.path.abspath(input_csv_path)
|
||||
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
data = _normalize_time_frame(data)
|
||||
value_cols, keep_cols = _select_pressure_columns(data)
|
||||
has_repeated_time = "time" in data.columns and data["time"].duplicated().any()
|
||||
identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols)
|
||||
|
||||
if identifier_col is not None:
|
||||
data_repaired = _clean_long_form_pressure(
|
||||
data,
|
||||
value_cols[0],
|
||||
identifier_col,
|
||||
keep_cols,
|
||||
fill_gaps,
|
||||
)
|
||||
elif has_repeated_time and len(value_cols) == 1:
|
||||
data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols)
|
||||
else:
|
||||
data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps)
|
||||
|
||||
# 可选可视化(只展示首个数值列)
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
if show_plot and value_cols:
|
||||
plot_col = value_cols[0]
|
||||
if "time" in data_repaired.columns:
|
||||
x = pd.to_datetime(data_repaired["time"], errors="coerce")
|
||||
else:
|
||||
x = np.arange(len(data_repaired))
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned")
|
||||
plt.xlabel("时间" if "time" in data_repaired.columns else "序号")
|
||||
plt.ylabel("压力监测值")
|
||||
plt.title(f"{plot_col} 清洗结果")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
# 保存到 Excel:两个 sheet
|
||||
input_dir = os.path.dirname(os.path.abspath(input_csv_path))
|
||||
input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
|
||||
output_filename = f"{input_base}_cleaned.xlsx"
|
||||
output_path = os.path.join(input_dir, output_filename)
|
||||
|
||||
# 如果原始数据包含时间列,将其添加回结果
|
||||
data_for_save = data.copy()
|
||||
data_repaired_for_save = data_repaired.copy()
|
||||
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path) # 覆盖同名文件
|
||||
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
|
||||
data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False)
|
||||
data_repaired_for_save.to_excel(
|
||||
writer, sheet_name="cleaned_pressusre_data", index=False
|
||||
)
|
||||
|
||||
# 返回输出文件的绝对路径
|
||||
return os.path.abspath(output_path)
|
||||
|
||||
|
||||
def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
|
||||
"""
|
||||
接收一个 DataFrame 数据结构,使用时间感知的稳健修复方法清洗压力数据。
|
||||
返回清洗后的 DataFrame。
|
||||
|
||||
Args:
|
||||
data: 输入 DataFrame(可包含 time 列)
|
||||
show_plot: 是否显示可视化
|
||||
"""
|
||||
# 使用传入的 DataFrame
|
||||
data = data.copy()
|
||||
data = _normalize_time_frame(data)
|
||||
value_cols, keep_cols = _select_pressure_columns(data)
|
||||
has_repeated_time = "time" in data.columns and data["time"].duplicated().any()
|
||||
identifier_col = _detect_long_form_identifier(data, value_cols, keep_cols)
|
||||
|
||||
if identifier_col is not None:
|
||||
data_repaired = _clean_long_form_pressure(
|
||||
data,
|
||||
value_cols[0],
|
||||
identifier_col,
|
||||
keep_cols,
|
||||
fill_gaps=True,
|
||||
)
|
||||
elif has_repeated_time and len(value_cols) == 1:
|
||||
data_repaired = _clean_repeated_timestamp_pressure(data, value_cols[0], keep_cols)
|
||||
else:
|
||||
data_repaired = _clean_snapshot_pressure(data, value_cols, keep_cols, fill_gaps=True)
|
||||
|
||||
if show_plot and value_cols:
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
plot_col = value_cols[0]
|
||||
x = pd.to_datetime(data_repaired["time"], errors="coerce") if "time" in data_repaired.columns else np.arange(len(data_repaired))
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(x, pd.to_numeric(data_repaired[plot_col], errors="coerce"), label="cleaned")
|
||||
plt.xlabel("时间" if "time" in data_repaired.columns else "序号")
|
||||
plt.ylabel("压力监测值")
|
||||
plt.title(f"{plot_col} 清洗结果")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
return data_repaired
|
||||
|
||||
|
||||
# 测试
|
||||
# if __name__ == "__main__":
|
||||
# # 默认使用脚本目录下的 pressure_raw_data.csv
|
||||
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
|
||||
# out_path = clean_pressure_data_km(default_csv, show_plot=False)
|
||||
# print("保存路径:", out_path)
|
||||
|
||||
# 测试 clean_pressure_data_dict_km 函数
|
||||
if __name__ == "__main__":
|
||||
import random
|
||||
|
||||
# 读取 szh_pressure_scada.csv 文件
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
|
||||
data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
|
||||
# 排除 Time 列,随机选择 5 列
|
||||
columns_to_exclude = ["Time"]
|
||||
available_columns = [col for col in data.columns if col not in columns_to_exclude]
|
||||
selected_columns = random.sample(available_columns, 5)
|
||||
|
||||
# 将选中的列转换为字典
|
||||
data_dict = {col: data[col].tolist() for col in selected_columns}
|
||||
|
||||
print("选中的列:", selected_columns)
|
||||
print("原始数据长度:", len(data_dict[selected_columns[0]]))
|
||||
|
||||
# 调用函数进行清洗
|
||||
cleaned_dict = clean_pressure_data_df_km(data_dict, show_plot=True)
|
||||
|
||||
print("清洗后的字典键:", list(cleaned_dict.keys()))
|
||||
print("清洗后的数据长度:", len(cleaned_dict[selected_columns[0]]))
|
||||
print("测试完成:函数运行正常")
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.algorithms.health.analyzer import PipelineHealthAnalyzer
|
||||
|
||||
__all__ = ["PipelineHealthAnalyzer"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.algorithms.isolation.valve import valve_isolation_analysis
|
||||
|
||||
__all__ = ["valve_isolation_analysis"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.algorithms.leakage.identifier import LeakageIdentifier
|
||||
|
||||
__all__ = ["LeakageIdentifier"]
|
||||
@@ -0,0 +1,653 @@
|
||||
import wntr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from pymoo.core.problem import Problem
|
||||
from pymoo.core.callback import Callback
|
||||
from pymoo.algorithms.soo.nonconvex.ga import GA
|
||||
from pymoo.operators.crossover.sbx import SBX
|
||||
from pymoo.operators.mutation.pm import PM
|
||||
from pymoo.optimize import minimize as pymoo_minimize
|
||||
from pymoo.termination.default import DefaultSingleObjectiveTermination
|
||||
|
||||
|
||||
from app.algorithms._utils import _cleanup_temp_files
|
||||
|
||||
_worker_data: dict[str, Any] = {}
|
||||
DEFAULT_N_WORKERS = max(1, min(cpu_count() - 1, 4))
|
||||
|
||||
|
||||
def _worker_init(
|
||||
inp_path: str,
|
||||
sensor_nodes: list[str],
|
||||
area_ids: list[str],
|
||||
nodes_by_area: dict[str, list[str]],
|
||||
obs_matrix: np.ndarray,
|
||||
q_sum: float,
|
||||
duration_sec: float,
|
||||
timestep_sec: float,
|
||||
) -> None:
|
||||
global _worker_data
|
||||
wn = wntr.network.WaterNetworkModel(inp_path)
|
||||
wn.options.hydraulic.demand_model = "DD"
|
||||
wn.options.time.duration = duration_sec
|
||||
wn.options.time.hydraulic_timestep = timestep_sec
|
||||
wn.options.time.pattern_timestep = timestep_sec
|
||||
wn.options.time.report_timestep = timestep_sec
|
||||
|
||||
demand_objs_by_area = {}
|
||||
allocatable_counts = {}
|
||||
for area_id in area_ids:
|
||||
demand_objs = []
|
||||
for node_name in nodes_by_area.get(area_id, []):
|
||||
if node_name not in wn.node_name_list:
|
||||
continue
|
||||
node = wn.get_node(node_name)
|
||||
if (
|
||||
hasattr(node, "demand_timeseries_list")
|
||||
and len(node.demand_timeseries_list) > 0
|
||||
):
|
||||
demand_objs.append(node.demand_timeseries_list[0])
|
||||
demand_objs_by_area[area_id] = demand_objs
|
||||
allocatable_counts[area_id] = len(demand_objs)
|
||||
|
||||
_worker_data = {
|
||||
"wn": wn,
|
||||
"sensor_nodes": sensor_nodes,
|
||||
"area_ids": area_ids,
|
||||
"nodes_by_area": nodes_by_area,
|
||||
"demand_objs_by_area": demand_objs_by_area,
|
||||
"allocatable_counts": allocatable_counts,
|
||||
"obs_matrix": obs_matrix,
|
||||
"q_sum": q_sum,
|
||||
}
|
||||
|
||||
|
||||
def _worker_evaluate(raw_ratios: np.ndarray) -> float:
|
||||
d = _worker_data
|
||||
effective_ratio_map = LeakageIdentifier._effective_area_ratios(
|
||||
raw_ratios,
|
||||
d["area_ids"],
|
||||
d["nodes_by_area"],
|
||||
allocatable_counts=d["allocatable_counts"],
|
||||
)
|
||||
|
||||
modifications = []
|
||||
for area_id in d["area_ids"]:
|
||||
ratio = effective_ratio_map.get(area_id, 0.0)
|
||||
if ratio <= 0:
|
||||
continue
|
||||
|
||||
demand_objs = d["demand_objs_by_area"].get(area_id, [])
|
||||
if not demand_objs:
|
||||
continue
|
||||
|
||||
per_node_leak = d["q_sum"] * ratio / len(demand_objs)
|
||||
for demand_obj in demand_objs:
|
||||
original_val = demand_obj.base_value
|
||||
demand_obj.base_value = original_val + per_node_leak
|
||||
modifications.append((demand_obj, original_val))
|
||||
|
||||
temp_dir = os.path.abspath(os.path.join("temp", "leakage"))
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
prefix = os.path.join(temp_dir, f"temp_{os.getpid()}")
|
||||
|
||||
try:
|
||||
sim = wntr.sim.EpanetSimulator(d["wn"])
|
||||
results = sim.run_sim(file_prefix=prefix)
|
||||
sim_pressure = results.node["pressure"].loc[:, d["sensor_nodes"]]
|
||||
|
||||
n_steps = min(sim_pressure.shape[0], d["obs_matrix"].shape[0])
|
||||
sim_vals = sim_pressure.values[:n_steps, :]
|
||||
obs_vals = d["obs_matrix"][:n_steps, :]
|
||||
diff = sim_vals - obs_vals
|
||||
|
||||
row_max = np.max(np.abs(diff), axis=1, keepdims=True)
|
||||
row_max[row_max == 0] = 1.0
|
||||
normalized_diff = diff / row_max
|
||||
return float(np.linalg.norm(normalized_diff))
|
||||
|
||||
except Exception:
|
||||
return 1e9
|
||||
|
||||
finally:
|
||||
for demand_obj, original_val in modifications:
|
||||
demand_obj.base_value = original_val
|
||||
_cleanup_temp_files(prefix)
|
||||
|
||||
|
||||
class LeakageIdentifier:
|
||||
FLOW_UNIT_TO_M3S = {
|
||||
"m3/s": 1.0,
|
||||
"m3/h": 1.0 / 3600.0,
|
||||
"L/s": 1.0 / 1000.0,
|
||||
"L/min": 1.0 / 60000.0,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _flow_to_m3s(cls, value: float, unit: str) -> float:
|
||||
if unit not in cls.FLOW_UNIT_TO_M3S:
|
||||
raise ValueError(f"不支持的流量单位: {unit}")
|
||||
return float(value) * cls.FLOW_UNIT_TO_M3S[unit]
|
||||
|
||||
@classmethod
|
||||
def _flow_from_m3s(cls, value_m3s: float, unit: str) -> float:
|
||||
if unit not in cls.FLOW_UNIT_TO_M3S:
|
||||
raise ValueError(f"不支持的流量单位: {unit}")
|
||||
return float(value_m3s) / cls.FLOW_UNIT_TO_M3S[unit]
|
||||
|
||||
@staticmethod
|
||||
def _effective_area_ratios(
|
||||
raw_ratios: Union["np.ndarray", Dict[str, float]],
|
||||
area_ids: List[str],
|
||||
nodes_by_area: Dict[str, List[str]],
|
||||
allocatable_counts: Union[Dict[str, int], None] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""将输入比例转换为有效区域比例,确保有效区域比例和为 1。"""
|
||||
area_count = len(area_ids)
|
||||
if area_count == 0:
|
||||
return {}
|
||||
|
||||
if isinstance(raw_ratios, dict):
|
||||
ratios = np.array(
|
||||
[float(raw_ratios.get(area_id, 0.0)) for area_id in area_ids],
|
||||
dtype=float,
|
||||
)
|
||||
else:
|
||||
arr = np.asarray(raw_ratios, dtype=float).reshape(-1)
|
||||
ratios = np.zeros(area_count, dtype=float)
|
||||
fill_len = min(area_count, arr.shape[0])
|
||||
if fill_len > 0:
|
||||
ratios[:fill_len] = arr[:fill_len]
|
||||
|
||||
# 仅保留非负比例,负值按 0 处理
|
||||
ratios = np.clip(ratios, a_min=0.0, a_max=None)
|
||||
|
||||
# 仅在有效区域(存在可分配节点)内归一化
|
||||
if allocatable_counts is not None:
|
||||
valid_mask = np.array(
|
||||
[int(allocatable_counts.get(area_id, 0)) > 0 for area_id in area_ids],
|
||||
dtype=bool,
|
||||
)
|
||||
else:
|
||||
valid_mask = np.array(
|
||||
[len(nodes_by_area.get(area_id, [])) > 0 for area_id in area_ids],
|
||||
dtype=bool,
|
||||
)
|
||||
if not np.any(valid_mask):
|
||||
raise ValueError("没有可分配漏损的有效分区,无法满足漏损总量约束。")
|
||||
|
||||
effective = np.zeros(area_count, dtype=float)
|
||||
valid_sum = float(np.sum(ratios[valid_mask]))
|
||||
|
||||
if valid_sum > 0:
|
||||
effective[valid_mask] = ratios[valid_mask] / valid_sum
|
||||
else:
|
||||
# 若输入全为 0,则在有效区域内均分,保证总和仍为 1
|
||||
valid_count = int(np.sum(valid_mask))
|
||||
effective[valid_mask] = 1.0 / valid_count
|
||||
|
||||
return {area_id: float(effective[idx]) for idx, area_id in enumerate(area_ids)}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_area_map_df(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""标准化区域映射列名为 ID 和 Area。"""
|
||||
if "ID" in df.columns and "Area" in df.columns:
|
||||
return df
|
||||
|
||||
if "ID" in df.columns and "now" in df.columns:
|
||||
df = df.rename(columns={"now": "Area"})
|
||||
return df
|
||||
|
||||
df = df.copy()
|
||||
df.columns = ["ID", "Area"] + list(df.columns[2:])
|
||||
return df
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inp_path: str,
|
||||
sensor_nodes: List[str],
|
||||
area_map: Union[str, Dict[str, str]],
|
||||
start_time: float = 0,
|
||||
duration: float = 24,
|
||||
timestep: float = 5,
|
||||
q_sum: float = 0.2,
|
||||
):
|
||||
"""
|
||||
初始化漏损识别器。
|
||||
|
||||
参数:
|
||||
inp_path: EPANET .inp 文件路径。
|
||||
sensor_nodes: 用作压力传感器的节点 ID 列表。
|
||||
area_map: 节点到区域的映射。可以是 CSV 文件路径(列:ID, Area),也可以是字典 {NodeID: AreaID}。
|
||||
start_time: 模拟开始时间(小时)。
|
||||
duration: 模拟持续时间(小时)。
|
||||
timestep: 模拟时间步长(分钟)。
|
||||
q_sum: 假设的总漏损流量 (m3/s)。
|
||||
"""
|
||||
self.inp_path = inp_path
|
||||
self.sensor_nodes = sensor_nodes
|
||||
self.start_time = start_time
|
||||
self.duration = duration
|
||||
self.timestep = timestep
|
||||
self.q_sum = q_sum
|
||||
|
||||
# 加载管网模型(仅一次)
|
||||
self.wn = wntr.network.WaterNetworkModel(self.inp_path)
|
||||
|
||||
# 优化 WNTR 设置以提高速度
|
||||
self.wn.options.hydraulic.demand_model = "DD"
|
||||
self.wn.options.time.duration = float(self.duration) * 3600
|
||||
self.wn.options.time.hydraulic_timestep = float(self.timestep) * 60
|
||||
self.wn.options.time.pattern_timestep = float(self.timestep) * 60
|
||||
self.wn.options.time.report_timestep = float(self.timestep) * 60
|
||||
|
||||
# 加载区域映射
|
||||
if isinstance(area_map, str):
|
||||
self.area_map_df = self._load_area_map(area_map)
|
||||
elif isinstance(area_map, dict):
|
||||
self.area_map_df = self._normalize_area_map_df(
|
||||
pd.DataFrame(list(area_map.items()), columns=["ID", "Area"])
|
||||
)
|
||||
else:
|
||||
raise ValueError("area_map 必须是 CSV 文件路径或字典。")
|
||||
|
||||
self.area_ids = sorted(self.area_map_df["Area"].unique())
|
||||
self.num_areas = len(self.area_ids)
|
||||
|
||||
# 按区域对节点进行预分类,以便更快查找
|
||||
self.nodes_by_area = {
|
||||
area: self.area_map_df[self.area_map_df["Area"] == area]["ID"].tolist()
|
||||
for area in self.area_ids
|
||||
}
|
||||
|
||||
def _load_area_map(self, path: str) -> pd.DataFrame:
|
||||
"""加载并验证节点-区域映射文件。"""
|
||||
df = pd.read_csv(path, dtype={"ID": str, "Area": str})
|
||||
return self._normalize_area_map_df(df)
|
||||
|
||||
def run_identification(
|
||||
self,
|
||||
observed_pressure_data: Union[
|
||||
str, pd.DataFrame, Dict[str, List[Any]], List[Dict[str, Any]]
|
||||
],
|
||||
output_dir: str = "Results",
|
||||
pop_size: int = 50,
|
||||
max_gen: int = 100,
|
||||
output_flow_unit: str = "m3/s",
|
||||
save_result: bool = True,
|
||||
ftol: float = 1e-3,
|
||||
ftol_period: int = 15,
|
||||
n_workers: int = DEFAULT_N_WORKERS,
|
||||
):
|
||||
"""
|
||||
运行遗传算法以识别漏损分布。
|
||||
|
||||
参数:
|
||||
observed_pressure_data: 包含 SCADA 压力数据的 CSV 文件路径或 DataFrame/字典列表数据。
|
||||
output_dir: 结果保存目录。
|
||||
pop_size: GA 的种群大小。
|
||||
max_gen: GA 的最大代数。
|
||||
output_flow_unit: 输出漏损流量的单位。
|
||||
save_result: 是否保存识别结果到本地 CSV。
|
||||
ftol: 目标值收敛容差(连续 ftol_period 代改善 < ftol 则停止)。
|
||||
ftol_period: 收敛检测的窗口代数。
|
||||
n_workers: 并行工作进程数(1=串行,>1=并行评估)。
|
||||
"""
|
||||
if save_result:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 加载观测数据
|
||||
if isinstance(observed_pressure_data, str):
|
||||
obs_df = pd.read_csv(observed_pressure_data)
|
||||
observed_name = os.path.basename(observed_pressure_data)
|
||||
elif isinstance(observed_pressure_data, pd.DataFrame):
|
||||
obs_df = observed_pressure_data.copy()
|
||||
observed_name = "observed_pressure.csv"
|
||||
else:
|
||||
obs_df = pd.DataFrame(observed_pressure_data)
|
||||
observed_name = "observed_pressure.csv"
|
||||
|
||||
# 准备 pymoo 问题实例
|
||||
problem = LeakageProblem(
|
||||
self.wn,
|
||||
self.nodes_by_area,
|
||||
self.area_ids,
|
||||
self.sensor_nodes,
|
||||
obs_df,
|
||||
q_sum=self.q_sum,
|
||||
n_workers=n_workers,
|
||||
inp_path=os.path.abspath(self.inp_path),
|
||||
)
|
||||
|
||||
# 配置 pymoo GA 算法
|
||||
n_var = self.num_areas
|
||||
algorithm = GA(
|
||||
pop_size=pop_size,
|
||||
crossover=SBX(prob=0.9, eta=15),
|
||||
mutation=PM(prob=1.0 / max(1, n_var), eta=20),
|
||||
eliminate_duplicates=False,
|
||||
)
|
||||
|
||||
# 终止条件:收敛检测 + 最大代数
|
||||
termination = DefaultSingleObjectiveTermination(
|
||||
ftol=ftol,
|
||||
period=ftol_period,
|
||||
n_max_gen=max_gen,
|
||||
)
|
||||
|
||||
# 回调:记录每代信息
|
||||
callback = _ProgressCallback()
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
res = pymoo_minimize(
|
||||
problem,
|
||||
algorithm,
|
||||
termination,
|
||||
seed=42,
|
||||
verbose=True,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
problem.close()
|
||||
elapsed = time.time() - t0
|
||||
|
||||
# 提取最优解
|
||||
best_ind = res.X # 最优个体(漏损比例原始值)
|
||||
best_obj = float(res.F[0])
|
||||
|
||||
# 输出终止信息
|
||||
print(f"\n优化完成。耗时: {elapsed:.1f}s")
|
||||
print(f"总代数: {res.algorithm.n_gen}, 总评估次数: {problem._eval_count}")
|
||||
print(f"最佳目标值: {best_obj:.6f}")
|
||||
|
||||
# 保存到文件
|
||||
effective_ratio_map = self._effective_area_ratios(
|
||||
best_ind,
|
||||
self.area_ids,
|
||||
self.nodes_by_area,
|
||||
allocatable_counts=problem.allocatable_counts,
|
||||
)
|
||||
normalized_ratios = [
|
||||
effective_ratio_map.get(area_id, 0.0) for area_id in self.area_ids
|
||||
]
|
||||
leakage_flow_m3s = [ratio * self.q_sum for ratio in normalized_ratios]
|
||||
leakage_flow_output = [
|
||||
self._flow_from_m3s(value_m3s, output_flow_unit)
|
||||
for value_m3s in leakage_flow_m3s
|
||||
]
|
||||
|
||||
result_df = pd.DataFrame(
|
||||
{
|
||||
"Area": self.area_ids,
|
||||
"LeakageRatioRaw": best_ind,
|
||||
"LeakageRatio": normalized_ratios,
|
||||
"LeakageFlow_m3_per_s": leakage_flow_m3s,
|
||||
f"LeakageFlow_{output_flow_unit.replace('/', '_per_')}": leakage_flow_output,
|
||||
}
|
||||
)
|
||||
result_path = None
|
||||
if save_result:
|
||||
result_path = os.path.join(
|
||||
output_dir, f"identified_leakage_{observed_name}"
|
||||
)
|
||||
result_df.to_csv(result_path, index=False)
|
||||
print(f"结果已保存至 {result_path}")
|
||||
result_df.attrs["result_path"] = result_path
|
||||
return result_df
|
||||
|
||||
|
||||
class _ProgressCallback(Callback):
|
||||
"""每代回调:记录进度。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gen_times = []
|
||||
self._t_last = None
|
||||
|
||||
def notify(self, algorithm):
|
||||
now = time.time()
|
||||
if self._t_last is not None:
|
||||
self.gen_times.append(now - self._t_last)
|
||||
self._t_last = now
|
||||
|
||||
|
||||
class LeakageProblem(Problem):
|
||||
"""pymoo 批量评估问题定义。
|
||||
|
||||
搜索空间:n 维 [0, 1] 实数 -> 通过 _effective_area_ratios 归一化到单纯形。
|
||||
目标:模拟压力与观测压力之间的归一化误差范数。
|
||||
无显式约束(sum=1 由归一化自动保证)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wn,
|
||||
nodes_by_area,
|
||||
area_ids,
|
||||
sensor_nodes,
|
||||
observed_data,
|
||||
q_sum: float = 0.2,
|
||||
n_workers: int = DEFAULT_N_WORKERS,
|
||||
inp_path: str | None = None,
|
||||
):
|
||||
n_var = len(area_ids)
|
||||
|
||||
super().__init__(
|
||||
n_var=n_var,
|
||||
n_obj=1,
|
||||
n_ieq_constr=0,
|
||||
xl=np.zeros(n_var),
|
||||
xu=np.ones(n_var),
|
||||
)
|
||||
|
||||
self.wn = wn
|
||||
self.nodes_by_area = nodes_by_area
|
||||
self.area_ids = area_ids
|
||||
self.sensor_nodes = sensor_nodes
|
||||
self.q_sum = q_sum
|
||||
self.n_workers = max(1, int(n_workers))
|
||||
self.inp_path = inp_path
|
||||
|
||||
# 预处理观测数据以匹配模拟格式
|
||||
try:
|
||||
missing_sensors = [
|
||||
s for s in self.sensor_nodes if s not in observed_data.columns
|
||||
]
|
||||
if not missing_sensors:
|
||||
self.obs_matrix = observed_data[self.sensor_nodes].values
|
||||
else:
|
||||
self.obs_matrix = observed_data.values[:, : len(self.sensor_nodes)]
|
||||
except Exception:
|
||||
self.obs_matrix = observed_data.values[:, : len(self.sensor_nodes)]
|
||||
|
||||
duration_sec = float(self.wn.options.time.duration)
|
||||
step_sec = float(self.wn.options.time.hydraulic_timestep)
|
||||
if step_sec > 0:
|
||||
max_steps = int(duration_sec / step_sec) + 1
|
||||
self.obs_matrix = self.obs_matrix[:max_steps, :]
|
||||
|
||||
# 预先缓存每个区域的需水对象,减少每次适应度计算的节点查找
|
||||
self.demand_objs_by_area = {}
|
||||
for area_id in self.area_ids:
|
||||
demand_objs = []
|
||||
for node_name in self.nodes_by_area.get(area_id, []):
|
||||
if node_name not in self.wn.node_name_list:
|
||||
continue
|
||||
node = self.wn.get_node(node_name)
|
||||
if (
|
||||
hasattr(node, "demand_timeseries_list")
|
||||
and len(node.demand_timeseries_list) > 0
|
||||
):
|
||||
demand_objs.append(node.demand_timeseries_list[0])
|
||||
self.demand_objs_by_area[area_id] = demand_objs
|
||||
|
||||
self.allocatable_counts = {
|
||||
area_id: len(self.demand_objs_by_area.get(area_id, []))
|
||||
for area_id in self.area_ids
|
||||
}
|
||||
if not any(count > 0 for count in self.allocatable_counts.values()):
|
||||
raise ValueError("没有可分配漏损的有效分区,无法满足漏损总量约束。")
|
||||
|
||||
# 评估计数器(诊断用)
|
||||
self._eval_count = 0
|
||||
self._pool = None
|
||||
if self.n_workers > 1:
|
||||
if not self.inp_path:
|
||||
raise ValueError("并行评估需要提供 inp_path。")
|
||||
duration_sec = float(self.wn.options.time.duration)
|
||||
timestep_sec = float(self.wn.options.time.hydraulic_timestep)
|
||||
self._pool = Pool(
|
||||
processes=self.n_workers,
|
||||
initializer=_worker_init,
|
||||
initargs=(
|
||||
self.inp_path,
|
||||
list(self.sensor_nodes),
|
||||
list(self.area_ids),
|
||||
{k: list(v) for k, v in self.nodes_by_area.items()},
|
||||
self.obs_matrix.copy(),
|
||||
self.q_sum,
|
||||
duration_sec,
|
||||
timestep_sec,
|
||||
),
|
||||
)
|
||||
|
||||
def _evaluate(self, X, out, *args, **kwargs):
|
||||
"""批量评估种群。
|
||||
|
||||
X: 形状 (pop_size, n_var) 的决策变量矩阵。
|
||||
"""
|
||||
n_pop = X.shape[0]
|
||||
self._eval_count += n_pop
|
||||
|
||||
if self._pool is not None:
|
||||
results = self._pool.map(_worker_evaluate, [X[i] for i in range(n_pop)])
|
||||
out["F"] = np.array(results, dtype=float).reshape(-1, 1)
|
||||
return
|
||||
|
||||
F = np.zeros((n_pop, 1))
|
||||
for i in range(n_pop):
|
||||
F[i, 0] = self._evaluate_single(X[i])
|
||||
out["F"] = F
|
||||
|
||||
def _evaluate_single(self, x):
|
||||
"""评估单个个体,返回归一化误差范数。"""
|
||||
leak_ratios = x
|
||||
|
||||
# 将漏损分布归一化
|
||||
effective_ratio_map = LeakageIdentifier._effective_area_ratios(
|
||||
leak_ratios,
|
||||
self.area_ids,
|
||||
self.nodes_by_area,
|
||||
allocatable_counts=self.allocatable_counts,
|
||||
)
|
||||
|
||||
# 跟踪修改以便稍后恢复
|
||||
modifications = []
|
||||
|
||||
for j, area_id in enumerate(self.area_ids):
|
||||
ratio = effective_ratio_map.get(area_id, 0.0)
|
||||
if ratio <= 0:
|
||||
continue
|
||||
|
||||
demand_objs = self.demand_objs_by_area.get(area_id, [])
|
||||
if not demand_objs:
|
||||
continue
|
||||
|
||||
per_node_leak = self.q_sum * ratio / len(demand_objs)
|
||||
|
||||
for demand_obj in demand_objs:
|
||||
original_val = demand_obj.base_value
|
||||
demand_obj.base_value = original_val + per_node_leak
|
||||
modifications.append((demand_obj, original_val))
|
||||
|
||||
# 结果保存在根目录的temp/leakage文件夹中
|
||||
temp_dir = os.path.abspath(os.path.join("temp", "leakage"))
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
prefix = os.path.join(temp_dir, f"temp_{os.getpid()}")
|
||||
|
||||
try:
|
||||
sim = wntr.sim.EpanetSimulator(self.wn)
|
||||
results = sim.run_sim(file_prefix=prefix)
|
||||
sim_pressure = results.node["pressure"].loc[:, self.sensor_nodes]
|
||||
|
||||
n_steps = min(sim_pressure.shape[0], self.obs_matrix.shape[0])
|
||||
sim_vals = sim_pressure.values[:n_steps, :]
|
||||
obs_vals = self.obs_matrix[:n_steps, :]
|
||||
|
||||
diff = sim_vals - obs_vals
|
||||
|
||||
# 按行最大值归一化
|
||||
row_max = np.max(np.abs(diff), axis=1, keepdims=True)
|
||||
row_max[row_max == 0] = 1.0 # 防止除以零
|
||||
|
||||
normalized_diff = diff / row_max
|
||||
|
||||
# 目标:归一化差值矩阵的 2-范数
|
||||
return float(np.linalg.norm(normalized_diff))
|
||||
|
||||
except Exception:
|
||||
return 1e9
|
||||
|
||||
finally:
|
||||
for demand_obj, original_val in modifications:
|
||||
demand_obj.base_value = original_val
|
||||
|
||||
_cleanup_temp_files(prefix)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._pool is not None:
|
||||
self._pool.close()
|
||||
self._pool.join()
|
||||
self._pool = None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="漏损区域识别")
|
||||
parser.add_argument("--inp", required=True, help=".inp 文件路径")
|
||||
parser.add_argument("--map", help="节点-区域映射 CSV 路径")
|
||||
parser.add_argument("--scada", help="SCADA 压力 CSV 路径 (观测数据)")
|
||||
parser.add_argument("--sensors", help="传感器节点 ID 列表 (逗号分隔)")
|
||||
parser.add_argument("--output", default="Results", help="输出目录")
|
||||
|
||||
parser.add_argument("--pop_size", type=int, default=50, help="种群大小")
|
||||
parser.add_argument("--max_gen", type=int, default=100, help="最大代数")
|
||||
parser.add_argument("--duration", type=float, default=24, help="模拟时长(小时)")
|
||||
parser.add_argument("--q_sum", type=float, default=0.241, help="总漏损流量")
|
||||
parser.add_argument(
|
||||
"--q_sum_unit",
|
||||
default="m3/s",
|
||||
choices=list(LeakageIdentifier.FLOW_UNIT_TO_M3S.keys()),
|
||||
help="q_sum 输入单位(建议与现场习惯一致,内部统一换算为 m3/s)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.map or not args.scada or not args.sensors:
|
||||
parser.error("--map、--scada、--sensors 为必填")
|
||||
|
||||
q_sum_m3s = LeakageIdentifier._flow_to_m3s(args.q_sum, args.q_sum_unit)
|
||||
|
||||
sensors = [sensor.strip() for sensor in args.sensors.split(",") if sensor.strip()]
|
||||
identifier = LeakageIdentifier(
|
||||
args.inp, sensors, args.map, duration=args.duration, q_sum=q_sum_m3s
|
||||
)
|
||||
|
||||
identifier.run_identification(
|
||||
args.scada,
|
||||
args.output,
|
||||
pop_size=args.pop_size,
|
||||
max_gen=args.max_gen,
|
||||
output_flow_unit=args.q_sum_unit,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,8 +1,8 @@
|
||||
import psycopg
|
||||
|
||||
import app.algorithms.api_ex.kmeans_sensor as kmeans_sensor
|
||||
import app.algorithms.api_ex.sensitivity as sensitivity
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
from app.algorithms.sensor import kmeans as kmeans_sensor
|
||||
from app.algorithms.sensor import sensitivity
|
||||
from app.core.config import get_pgconn_string
|
||||
from app.services.tjnetwork import dump_inp
|
||||
|
||||
|
||||
@@ -6,104 +6,66 @@ import sklearn.cluster
|
||||
import os
|
||||
|
||||
|
||||
|
||||
class QD_KMeans(object):
|
||||
def __init__(self, wn, num_monitors):
|
||||
# self.inp = inp
|
||||
self.cluster_num = num_monitors # 聚类中心个数,也即测压点个数
|
||||
self.wn=wn
|
||||
self.cluster_num = num_monitors # 聚类中心个数,也即测压点个数
|
||||
self.wn = wn
|
||||
self.monitor_nodes = []
|
||||
self.coords = []
|
||||
self.junction_nodes = {} # Added missing initialization
|
||||
|
||||
|
||||
def get_junctions_coordinates(self):
|
||||
|
||||
for junction_name in self.wn.junction_name_list:
|
||||
|
||||
for junction_name in self.wn.junction_name_list:
|
||||
junction = self.wn.get_node(junction_name)
|
||||
self.junction_nodes[junction_name] = junction.coordinates
|
||||
self.coords.append(junction.coordinates )
|
||||
self.coords.append(junction.coordinates)
|
||||
|
||||
# print(f"Total junctions: {self.junction_coordinates}")
|
||||
# print(f"Total junctions: {self.junction_coordinates}")
|
||||
|
||||
def select_monitoring_points(self):
|
||||
if not self.coords: # Add check if coordinates are collected
|
||||
self.get_junctions_coordinates()
|
||||
coords = np.array(self.coords)
|
||||
coords_normalized = (coords - coords.min(axis=0)) / (coords.max(axis=0) - coords.min(axis=0))
|
||||
kmeans = sklearn.cluster.KMeans(n_clusters= self.cluster_num, random_state=42)
|
||||
kmeans.fit(coords_normalized)
|
||||
coords_normalized = (coords - coords.min(axis=0)) / (
|
||||
coords.max(axis=0) - coords.min(axis=0)
|
||||
)
|
||||
kmeans = sklearn.cluster.KMeans(n_clusters=self.cluster_num, random_state=42)
|
||||
kmeans.fit(coords_normalized)
|
||||
|
||||
for center in kmeans.cluster_centers_:
|
||||
distances = np.sum((coords_normalized - center) ** 2, axis=1)
|
||||
nearest_node = self.wn.junction_name_list[np.argmin(distances)]
|
||||
self.monitor_nodes.append(nearest_node)
|
||||
self.monitor_nodes.append(nearest_node)
|
||||
|
||||
return self.monitor_nodes
|
||||
|
||||
|
||||
def visualize_network(self):
|
||||
"""Visualize network with monitoring points"""
|
||||
ax=wntr.graphics.plot_network(self.wn,
|
||||
node_attribute=self.monitor_nodes,
|
||||
node_size=30,
|
||||
title='Optimal sensor')
|
||||
plt.show()
|
||||
ax = wntr.graphics.plot_network(
|
||||
self.wn,
|
||||
node_attribute=self.monitor_nodes,
|
||||
node_size=30,
|
||||
title="Optimal sensor",
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
def kmeans_sensor_placement(name: str, sensor_num: int, min_diameter: int) -> list:
|
||||
inp_name = f'./db_inp/{name}.db.inp'
|
||||
wn= wntr.network.WaterNetworkModel(inp_name)
|
||||
wn_cluster=QD_KMeans(wn, sensor_num)
|
||||
inp_name = f"./db_inp/{name}.db.inp"
|
||||
wn = wntr.network.WaterNetworkModel(inp_name)
|
||||
wn_cluster = QD_KMeans(wn, sensor_num)
|
||||
|
||||
# Select monitoring pointse
|
||||
sensor_ids= wn_cluster.select_monitoring_points()
|
||||
sensor_ids = wn_cluster.select_monitoring_points()
|
||||
|
||||
# wn_cluster.visualize_network()
|
||||
|
||||
return sensor_ids
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
|
||||
sensorindex = kmeans_sensor_placement(name='szh', sensor_num=50, min_diameter=300)
|
||||
# sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
|
||||
sensorindex = kmeans_sensor_placement(name="szh", sensor_num=50, min_diameter=300)
|
||||
print(sensorindex)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from app.services.tjnetwork import *
|
||||
from matplotlib.lines import Line2D
|
||||
from sklearn.cluster import SpectralClustering
|
||||
import libpysal as ps
|
||||
@@ -21,6 +20,7 @@ import geopandas as gpd
|
||||
from sklearn.metrics import pairwise_distances
|
||||
import app.services.project_info as project_info
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
@@ -32,7 +32,7 @@ def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
# site: pandas.Series
|
||||
# index:节点名称(wn.node_name_list)
|
||||
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z))
|
||||
site = wn.query_node_attribute('coordinates')
|
||||
site = wn.query_node_attribute("coordinates")
|
||||
# Coor: pandas.Series
|
||||
# index:与site相同(节点名称)。
|
||||
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3]))
|
||||
@@ -44,9 +44,9 @@ def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
|
||||
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
|
||||
# xy: dict[str, list], x、y 坐标的字典
|
||||
xy = {'x': x, 'y': y}
|
||||
xy = {"x": x, "y": y}
|
||||
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
|
||||
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
|
||||
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=["x", "y"])
|
||||
return Coor_node
|
||||
|
||||
|
||||
@@ -88,23 +88,23 @@ def skater_partition(G, n_clusters):
|
||||
字典形式的聚类结果,键为区域编号,值为该区域内的节点列表。
|
||||
"""
|
||||
# 1. 获取所有节点坐标,假设每个节点都有 'pos' 属性
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
pos = nx.get_node_attributes(G, "pos")
|
||||
nodes = list(G.nodes())
|
||||
# 构造坐标数组:每行为 [x, y]
|
||||
coords = np.array([pos[node] for node in nodes])
|
||||
|
||||
# 2. 构造 GeoDataFrame:创建 DataFrame 并生成 geometry 列
|
||||
df = pd.DataFrame(coords, columns=['x', 'y'], index=nodes)
|
||||
df = pd.DataFrame(coords, columns=["x", "y"], index=nodes)
|
||||
# 利用 shapely 的 Point 构造空间位置
|
||||
df['geometry'] = df.apply(lambda row: Point(row['x'], row['y']), axis=1)
|
||||
gdf = gpd.GeoDataFrame(df, geometry='geometry')
|
||||
df["geometry"] = df.apply(lambda row: Point(row["x"], row["y"]), axis=1)
|
||||
gdf = gpd.GeoDataFrame(df, geometry="geometry")
|
||||
|
||||
# 3. 构造空间权重矩阵,使用 4 近邻方法(k=4,可根据实际情况调整)
|
||||
w = ps.weights.KNN.from_array(coords, k=4)
|
||||
w.transform = 'R'
|
||||
w.transform = "R"
|
||||
|
||||
# 4. 调用 SKATER:新版本 API 要求传入 gdf, w 以及 attrs_name(这里使用 'x' 和 'y' 作为属性)
|
||||
skater = Skater(gdf, w, attrs_name=['x', 'y'], n_clusters=n_clusters)
|
||||
skater = Skater(gdf, w, attrs_name=["x", "y"], n_clusters=n_clusters)
|
||||
skater.solve()
|
||||
|
||||
# 5. 获取聚类标签,构造成字典格式
|
||||
@@ -134,24 +134,24 @@ def spectral_partition(G, n_clusters):
|
||||
键为聚类标签,值为该聚类对应的节点列表。
|
||||
"""
|
||||
# 1. 获取节点空间坐标,注意保证每个节点都有 'pos' 属性
|
||||
pos_dict = nx.get_node_attributes(G, 'pos')
|
||||
pos_dict = nx.get_node_attributes(G, "pos")
|
||||
nodes = list(G.nodes())
|
||||
coords = np.array([pos_dict[node] for node in nodes])
|
||||
|
||||
# 2. 计算节点之间的欧氏距离矩阵
|
||||
D = pairwise_distances(coords, metric='euclidean')
|
||||
D = pairwise_distances(coords, metric="euclidean")
|
||||
|
||||
# 3. 计算 sigma 值:这里取所有距离的均值,当然也可以根据实际情况调整
|
||||
sigma = np.mean(D)
|
||||
|
||||
# 4. 构造相似度矩阵:使用高斯核函数
|
||||
# A(i, j) = exp( -d(i,j)^2 / (2*sigma^2) )
|
||||
A = np.exp(- (D ** 2) / (2 * sigma ** 2))
|
||||
A = np.exp(-(D**2) / (2 * sigma**2))
|
||||
|
||||
# 5. 使用谱聚类进行图分区
|
||||
clustering = SpectralClustering(n_clusters=n_clusters,
|
||||
affinity='precomputed',
|
||||
random_state=0)
|
||||
clustering = SpectralClustering(
|
||||
n_clusters=n_clusters, affinity="precomputed", random_state=0
|
||||
)
|
||||
labels = clustering.fit_predict(A)
|
||||
|
||||
# 6. 构造字典形式的分区结果
|
||||
@@ -161,6 +161,7 @@ def spectral_partition(G, n_clusters):
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step3: wn_func类,水力计算
|
||||
# wn_func 主要用于计算:
|
||||
@@ -182,7 +183,7 @@ class wn_func(object):
|
||||
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
|
||||
self.wn = wn
|
||||
# self.q:pandas.DataFrame,管道流量,索引为时间步长,列为管道名称
|
||||
self.q = self.results.link['flowrate']
|
||||
self.q = self.results.link["flowrate"]
|
||||
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
|
||||
ReservoirIndex = wn.reservoir_name_list
|
||||
Tankindex = wn.tank_name_list
|
||||
@@ -192,7 +193,7 @@ class wn_func(object):
|
||||
# self.nodes: list[str],所有节点的名称
|
||||
self.nodes = wn.node_name_list
|
||||
# self.coordinates:pandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
|
||||
self.coordinates = wn.query_node_attribute('coordinates')
|
||||
self.coordinates = wn.query_node_attribute("coordinates")
|
||||
# allpumps / allvalves: list[str],所有泵/阀门名称列表
|
||||
allpumps = wn.pump_name_list
|
||||
allvalves = wn.valve_name_list
|
||||
@@ -223,17 +224,27 @@ class wn_func(object):
|
||||
# 泵的起终点、tank、reservoir
|
||||
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
|
||||
self.delnodes = list(
|
||||
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
|
||||
set(ReservoirIndex).union(
|
||||
Tankindex,
|
||||
pumpstnode,
|
||||
pumpednode,
|
||||
valvestnode,
|
||||
valveednode,
|
||||
Reservoirednode,
|
||||
)
|
||||
)
|
||||
# 泵、起终点为tank、reservoir的管道
|
||||
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
|
||||
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
|
||||
self.delpipes = list(
|
||||
set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe)
|
||||
)
|
||||
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
|
||||
# self.L: list[float],所有管道的长度(以米为单位)
|
||||
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
|
||||
self.L = wn.query_link_attribute("length")[self.pipes].tolist()
|
||||
self.n = len(self.nodes)
|
||||
self.m = len(self.pipes)
|
||||
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
|
||||
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
|
||||
self.unit_headloss = self.results.link["headloss"].iloc[0, :].tolist()
|
||||
##
|
||||
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
|
||||
|
||||
@@ -246,7 +257,9 @@ class wn_func(object):
|
||||
end_node = wn.links[pipe].end_node.name
|
||||
self.less_than_min_diameter_junction_list.extend([start_node, end_node])
|
||||
# 去重
|
||||
self.less_than_min_diameter_junction_list = list(set(self.less_than_min_diameter_junction_list))
|
||||
self.less_than_min_diameter_junction_list = list(
|
||||
set(self.less_than_min_diameter_junction_list)
|
||||
)
|
||||
|
||||
# Step3.2: 计算水力距离
|
||||
def CtoS(self):
|
||||
@@ -267,7 +280,7 @@ class wn_func(object):
|
||||
q = self.q
|
||||
L = self.L
|
||||
# H1:pandas.DataFrame,水头数据,索引为时间步长,列为节点名
|
||||
H1 = self.results.node['head'].T
|
||||
H1 = self.results.node["head"].T
|
||||
# hh:list[float],计算管道两端水头之差
|
||||
hh = []
|
||||
# 水头损失
|
||||
@@ -281,8 +294,18 @@ class wn_func(object):
|
||||
# headloss:pandas.DataFrame,管道水头损失矩阵
|
||||
headloss = pd.DataFrame(hh, index=pipes).T
|
||||
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
|
||||
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
hf = pd.DataFrame(
|
||||
np.array([0] * (n**2)).reshape(n, n),
|
||||
index=nodes,
|
||||
columns=nodes,
|
||||
dtype=float,
|
||||
)
|
||||
weightL = pd.DataFrame(
|
||||
np.array([0] * (n**2)).reshape(n, n),
|
||||
index=nodes,
|
||||
columns=nodes,
|
||||
dtype=float,
|
||||
)
|
||||
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
|
||||
G = nx.DiGraph()
|
||||
for i in range(0, m):
|
||||
@@ -299,11 +322,16 @@ class wn_func(object):
|
||||
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
|
||||
|
||||
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
hydraulicL = pd.DataFrame(
|
||||
np.array([0] * (n**2)).reshape(n, n),
|
||||
index=nodes,
|
||||
columns=nodes,
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
for a in nodes:
|
||||
if a in G.nodes:
|
||||
d = nx.shortest_path_length(G, source=a, weight='weight')
|
||||
d = nx.shortest_path_length(G, source=a, weight="weight")
|
||||
for b in list(d.keys()):
|
||||
hydraulicL.loc[a, b] = d[b]
|
||||
|
||||
@@ -332,11 +360,17 @@ class wn_func(object):
|
||||
for t in self.wn.tanks():
|
||||
self.nonjunc_index.append(t[0])
|
||||
# Conn:numpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
|
||||
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
|
||||
Conn = np.mat(
|
||||
np.zeros([n, m - p - v])
|
||||
) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
|
||||
# NConn:numpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
|
||||
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
|
||||
# pipes:list[str],去除泵和阀门的管道列表
|
||||
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
|
||||
pipes = [
|
||||
pipe
|
||||
for pipe in self.wn.pipes()
|
||||
if pipe not in self.wn.pumps() and pipe not in self.wn.valves()
|
||||
]
|
||||
for pipe_name, pipe in pipes:
|
||||
start = self.wn.node_name_list.index(pipe.start_node_name)
|
||||
end = self.wn.node_name_list.index(pipe.end_node_name)
|
||||
@@ -346,12 +380,21 @@ class wn_func(object):
|
||||
NConn[start, end] = 1
|
||||
NConn[end, start] = 1
|
||||
self.A = Conn
|
||||
link_name_list = [link for link in self.wn.link_name_list if
|
||||
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
|
||||
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
|
||||
link_name_list = [
|
||||
link
|
||||
for link in self.wn.link_name_list
|
||||
if link not in self.wn.pump_name_list
|
||||
and link not in self.wn.valve_name_list
|
||||
]
|
||||
self.A2 = pd.DataFrame(
|
||||
self.A, index=self.wn.node_name_list, columns=link_name_list
|
||||
)
|
||||
self.A2 = self.A2.drop(self.delnodes)
|
||||
for pipe in self.delpipes:
|
||||
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
|
||||
if (
|
||||
pipe not in self.wn.pump_name_list
|
||||
and pipe not in self.wn.valve_name_list
|
||||
):
|
||||
self.A2 = self.A2.drop(columns=pipe)
|
||||
self.junc_list = self.A2.index
|
||||
self.A2 = np.mat(self.A2) # 节点管道关系
|
||||
@@ -373,10 +416,10 @@ class wn_func(object):
|
||||
except EpanetException:
|
||||
pass
|
||||
finally:
|
||||
h = result.link['headloss'][self.pipes].values[0]
|
||||
q = result.link['flowrate'][self.pipes].values[0]
|
||||
l = self.wn.query_link_attribute('length')[self.pipes]
|
||||
C = self.wn.query_link_attribute('roughness')[self.pipes]
|
||||
h = result.link["headloss"][self.pipes].values[0]
|
||||
q = result.link["flowrate"][self.pipes].values[0]
|
||||
l = self.wn.query_link_attribute("length")[self.pipes]
|
||||
C = self.wn.query_link_attribute("roughness")[self.pipes]
|
||||
# headloss:numpy.ndarray,水头损失数组
|
||||
headloss = np.array(h)
|
||||
# 调整流量方向
|
||||
@@ -394,7 +437,7 @@ class wn_func(object):
|
||||
try:
|
||||
det = np.linalg.det(X)
|
||||
except RuntimeError as e:
|
||||
sign, logdet = slogdet(X) # 防止溢出
|
||||
sign, logdet = slogdet(X) # 防止溢出
|
||||
det = sign * np.exp(logdet)
|
||||
if det != 0:
|
||||
J_H_Cw = X.I * A * S
|
||||
@@ -431,7 +474,10 @@ class Sensorplacement(wn_func):
|
||||
"""
|
||||
Sensorplacement 类继承了 wn_func 类,并且用于计算和优化传感器布置的位置。
|
||||
"""
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int):
|
||||
|
||||
def __init__(
|
||||
self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int
|
||||
):
|
||||
"""
|
||||
|
||||
:param wn: 由wntr生成的模型
|
||||
@@ -443,7 +489,9 @@ class Sensorplacement(wn_func):
|
||||
|
||||
# 1.某个节点到所有节点的加权距离之和
|
||||
# 2.某个节点到该组内所有节点的加权距离之和
|
||||
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
|
||||
def sensor(
|
||||
self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]
|
||||
):
|
||||
"""
|
||||
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
|
||||
:param SS: 灵敏度矩阵,每个节点的行和列代表不同节点,矩阵元素表示节点间的灵敏度。SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
|
||||
@@ -528,7 +576,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
|
||||
:return: 测压点节点ID
|
||||
"""
|
||||
# inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
inp_file_real = f'./db_inp/{name}.db.inp'
|
||||
inp_file_real = f"./db_inp/{name}.db.inp"
|
||||
# sensornum:int,需要布置的传感器数量
|
||||
# sensornum = sensor_num
|
||||
# wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
@@ -539,7 +587,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
|
||||
results_real = sim_real.run_sim()
|
||||
|
||||
# real_C:list[float],包含所有管道粗糙度的列表
|
||||
real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
real_C = wn_real.query_link_attribute("roughness").tolist()
|
||||
# wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
wn_fun1 = wn_func(wn_real, min_diameter=min_diameter)
|
||||
# nodes:list[str],管网的节点名称列表
|
||||
@@ -599,7 +647,6 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
|
||||
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
|
||||
|
||||
|
||||
# 重新打开数据库
|
||||
# if is_project_open(name=name):
|
||||
# close_project(name=name)
|
||||
@@ -638,7 +685,7 @@ def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
|
||||
return sensorindex
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sensorindex = get_ID(name=project_info.name, sensor_num=20, min_diameter=300)
|
||||
|
||||
print(sensorindex)
|
||||
@@ -0,0 +1,19 @@
|
||||
from app.algorithms.simulation.scenarios import (
|
||||
convert_to_local_unit,
|
||||
burst_analysis,
|
||||
valve_close_analysis,
|
||||
flushing_analysis,
|
||||
contaminant_simulation,
|
||||
age_analysis,
|
||||
pressure_regulation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"convert_to_local_unit",
|
||||
"burst_analysis",
|
||||
"valve_close_analysis",
|
||||
"flushing_analysis",
|
||||
"contaminant_simulation",
|
||||
"age_analysis",
|
||||
"pressure_regulation",
|
||||
]
|
||||
@@ -1,6 +1,26 @@
|
||||
import numpy as np
|
||||
from app.services.tjnetwork import *
|
||||
from api.s36_wda_cal import *
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
close_project,
|
||||
copy_project,
|
||||
delete_project,
|
||||
get_pattern,
|
||||
get_patterns,
|
||||
get_pump,
|
||||
get_reservoir,
|
||||
get_status,
|
||||
get_tank,
|
||||
get_time,
|
||||
have_project,
|
||||
is_project_open,
|
||||
open_project,
|
||||
read_all,
|
||||
run_project,
|
||||
set_pattern,
|
||||
set_status,
|
||||
set_tank,
|
||||
set_time,
|
||||
)
|
||||
# from get_real_status import *
|
||||
from datetime import datetime,timedelta
|
||||
from math import modf
|
||||
@@ -5,14 +5,39 @@ from math import pi, sqrt
|
||||
import pytz
|
||||
|
||||
import app.services.simulation as simulation
|
||||
from app.algorithms.api_ex.run_simulation import (
|
||||
from app.algorithms.simulation.runner import (
|
||||
run_simulation_ex,
|
||||
from_clock_to_seconds_2,
|
||||
)
|
||||
from app.native.api.project import copy_project
|
||||
from app.services.epanet.epanet import Output
|
||||
from app.services.scheme_management import store_scheme_info
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
OPTION_DEMAND_MODEL_PDA,
|
||||
OPTION_QUALITY_CHEMICAL,
|
||||
SOURCE_TYPE_SETPOINT,
|
||||
add_pattern,
|
||||
add_source,
|
||||
close_project,
|
||||
copy_project,
|
||||
delete_project,
|
||||
get_demand,
|
||||
get_emitter,
|
||||
get_node_links,
|
||||
get_option,
|
||||
get_pattern,
|
||||
get_pipe,
|
||||
get_source,
|
||||
get_time,
|
||||
have_project,
|
||||
is_junction,
|
||||
is_project_open,
|
||||
open_project,
|
||||
set_demand,
|
||||
set_emitter,
|
||||
set_option,
|
||||
set_source,
|
||||
set_time,
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
@@ -637,24 +662,24 @@ def age_analysis(
|
||||
new_name,
|
||||
"realtime",
|
||||
modify_pattern_start_time,
|
||||
modify_total_duration,
|
||||
duration=modify_total_duration,
|
||||
downloading_prohibition=True,
|
||||
)
|
||||
simulation_result = json.loads(result)
|
||||
output_data = simulation_result.get("output")
|
||||
if not isinstance(output_data, dict):
|
||||
raise RuntimeError("run_simulation_ex did not return JSON output content")
|
||||
# step 2. restore the base model status
|
||||
# execute_undo(name) #有疑惑
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
output = Output("./temp/{}.db.out".format(new_name))
|
||||
# element_name = output.element_name()
|
||||
# node_name = element_name['nodes']
|
||||
# link_name = element_name['links']
|
||||
nodes_age = []
|
||||
node_result = output.node_results()
|
||||
node_result = output_data.get("node_results") or []
|
||||
for node in node_result:
|
||||
nodes_age.append(node["result"][-1]["quality"])
|
||||
links_age = []
|
||||
link_result = output.link_results()
|
||||
link_result = output_data.get("link_results") or []
|
||||
for link in link_result:
|
||||
links_age.append(link["result"][-1]["quality"])
|
||||
age_result = {"nodes": nodes_age, "links": links_age}
|
||||
@@ -3,28 +3,36 @@
|
||||
|
||||
仅管理员可访问
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, Query, Path
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.db.metadb.database import get_metadata_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def get_audit_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> AuditRepository:
|
||||
"""获取审计日志仓储"""
|
||||
return AuditRepository(session)
|
||||
|
||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||
|
||||
@router.get(
|
||||
"/logs",
|
||||
summary="查询审计日志",
|
||||
description="查询审计日志(仅管理员)",
|
||||
response_model=List[AuditLogResponse],
|
||||
)
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
@@ -38,9 +46,9 @@ async def get_audit_logs(
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志(仅管理员)
|
||||
|
||||
支持按用户、时间、操作类型等条件过滤
|
||||
查询审计日志
|
||||
|
||||
支持按用户、时间、操作类型等条件过滤,仅管理员可访问
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
user_id=user_id,
|
||||
@@ -50,11 +58,16 @@ async def get_audit_logs(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
)
|
||||
return logs
|
||||
|
||||
@router.get("/logs/count")
|
||||
|
||||
@router.get(
|
||||
"/logs/count",
|
||||
summary="获取审计日志总数",
|
||||
description="获取审计日志总数(仅管理员)",
|
||||
)
|
||||
async def get_audit_logs_count(
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
@@ -66,7 +79,9 @@ async def get_audit_logs_count(
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> dict:
|
||||
"""
|
||||
获取审计日志总数(仅管理员)
|
||||
获取审计日志总数
|
||||
|
||||
获取符合条件的审计日志的总数,仅管理员可访问
|
||||
"""
|
||||
count = await audit_repo.get_log_count(
|
||||
user_id=user_id,
|
||||
@@ -74,23 +89,29 @@ async def get_audit_logs_count(
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
end_time=end_time,
|
||||
)
|
||||
return {"count": count}
|
||||
|
||||
@router.get("/logs/my", response_model=List[AuditLogResponse])
|
||||
|
||||
@router.get(
|
||||
"/logs/my",
|
||||
summary="查询我的审计日志",
|
||||
description="查询当前用户的审计日志",
|
||||
response_model=List[AuditLogResponse],
|
||||
)
|
||||
async def get_my_audit_logs(
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询当前用户的审计日志
|
||||
|
||||
|
||||
普通用户只能查看自己的操作记录
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
@@ -99,6 +120,6 @@ async def get_my_audit_logs(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
)
|
||||
return logs
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, create_refresh_token, verify_password
|
||||
from app.domain.schemas.user import UserCreate, UserResponse, UserLogin, Token
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.infra.db.metadb.repositories.user_repository import UserRepository
|
||||
from app.auth.dependencies import get_user_repository, get_current_active_user
|
||||
from app.domain.schemas.user import UserInDB
|
||||
import logging
|
||||
@@ -14,53 +14,55 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_data: UserCreate, user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
用户注册
|
||||
|
||||
|
||||
创建新用户账号
|
||||
"""
|
||||
# 检查用户名和邮箱是否已存在
|
||||
if await user_repo.user_exists(username=user_data.username):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
detail="Username already registered",
|
||||
)
|
||||
|
||||
|
||||
if await user_repo.user_exists(email=user_data.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered"
|
||||
)
|
||||
|
||||
|
||||
# 创建用户
|
||||
try:
|
||||
user = await user_repo.create_user(user_data)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user"
|
||||
detail="Failed to create user",
|
||||
)
|
||||
return UserResponse.model_validate(user)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during user registration: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed"
|
||||
detail="Registration failed",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> Token:
|
||||
"""
|
||||
用户登录(OAuth2 标准格式)
|
||||
|
||||
|
||||
返回 JWT Access Token 和 Refresh Token
|
||||
"""
|
||||
# 验证用户(支持用户名或邮箱登录)
|
||||
@@ -68,119 +70,121 @@ async def login(
|
||||
if not user:
|
||||
# 尝试用邮箱登录
|
||||
user = await user_repo.get_user_by_email(form_data.username)
|
||||
|
||||
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user account"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account"
|
||||
)
|
||||
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(subject=user.username)
|
||||
refresh_token = create_refresh_token(subject=user.username)
|
||||
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/simple", response_model=Token)
|
||||
async def login_simple(
|
||||
username: str,
|
||||
password: str,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> Token:
|
||||
"""
|
||||
简化版登录接口(保持向后兼容)
|
||||
|
||||
|
||||
直接使用 username 和 password 参数
|
||||
"""
|
||||
# 验证用户
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
if not user:
|
||||
user = await user_repo.get_user_by_email(username)
|
||||
|
||||
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password"
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user account"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user account"
|
||||
)
|
||||
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(subject=user.username)
|
||||
refresh_token = create_refresh_token(subject=user.username)
|
||||
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
current_user: UserInDB = Depends(get_current_active_user),
|
||||
) -> UserResponse:
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
"""
|
||||
return UserResponse.model_validate(current_user)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
refresh_token: str,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
refresh_token: str, user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> Token:
|
||||
"""
|
||||
刷新 Access Token
|
||||
|
||||
|
||||
使用 Refresh Token 获取新的 Access Token
|
||||
"""
|
||||
from jose import jwt, JWTError
|
||||
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("type")
|
||||
|
||||
|
||||
if username is None or token_type != "refresh":
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# 验证用户仍然存在且激活
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
if not user or not user.is_active:
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# 生成新的 Access Token
|
||||
new_access_token = create_access_token(subject=user.username)
|
||||
|
||||
|
||||
return Token(
|
||||
access_token=new_access_token,
|
||||
refresh_token=refresh_token, # 保持原 refresh token
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_username
|
||||
from app.services.burst_detection import (
|
||||
get_burst_detection_scheme_detail,
|
||||
list_burst_detection_schemes,
|
||||
run_burst_detection,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class BurstDetectionRequest(BaseModel):
|
||||
"""爆管检测请求模型"""
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
observed_pressure_data: (
|
||||
dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] | None
|
||||
) = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"压力观测数据。支持列式字典 {sensor_id: [values,...]}、"
|
||||
"逐时刻对象数组 [{sensor_id: value,...}, ...]、"
|
||||
"或二维数组 [[t1_s1, t1_s2], [t2_s1, t2_s2], ...]。"
|
||||
),
|
||||
)
|
||||
points_per_day: int = Field(1440, description="每天的数据点数")
|
||||
mu: int = Field(100, description="异常值检测的参数")
|
||||
iforest_params: dict[str, Any] | None = Field(None, description="隔离森林算法参数")
|
||||
scada_start: datetime | None = Field(None, description="SCADA数据起始时间")
|
||||
scada_end: datetime | None = Field(None, description="SCADA数据结束时间")
|
||||
sensor_nodes: list[str] | None = Field(None, description="传感器节点列表")
|
||||
scheme_name: str | None = Field(None, description="方案名称")
|
||||
data_source: str = Field("monitoring", description="数据来源:monitoring(监测)或simulation(模拟)")
|
||||
simulation_scheme_name: str | None = Field(None, description="模拟方案名称")
|
||||
simulation_scheme_type: str | None = Field(None, description="模拟方案类型")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/detect/",
|
||||
summary="执行爆管检测",
|
||||
description="基于压力观测数据和其他参数执行爆管检测分析"
|
||||
)
|
||||
async def detect_burst(
|
||||
data: BurstDetectionRequest = Body(..., description="爆管检测请求数据"),
|
||||
username: str = Depends(get_current_keycloak_username),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
执行爆管检测分析。
|
||||
|
||||
使用异常检测算法(隔离森林)识别压力时间序列中的异常,
|
||||
将其作为潜在的爆管事件。
|
||||
|
||||
Args:
|
||||
data: 包含管网名称(或数据库名称)、压力数据及相关参数的请求体
|
||||
username: 当前认证用户名
|
||||
|
||||
Returns:
|
||||
包含检测结果的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当处理过程中发生错误时
|
||||
"""
|
||||
try:
|
||||
return run_burst_detection(**data.model_dump(), username=username)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/",
|
||||
summary="查询爆管检测方案列表",
|
||||
description="获取指定网络的所有爆管检测方案"
|
||||
)
|
||||
async def query_burst_detection_schemes(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
query_date: datetime | None = Query(None, description="查询日期(可选)"),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取爆管检测方案列表。
|
||||
|
||||
查询指定网络的所有已配置的爆管检测方案,
|
||||
可按日期进行筛选。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
query_date: 查询日期(可选)
|
||||
|
||||
Returns:
|
||||
爆管检测方案列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return list_burst_detection_schemes(network=network, query_date=query_date)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/{scheme_name}",
|
||||
summary="获取爆管检测方案详情",
|
||||
description="获取指定爆管检测方案的详细信息"
|
||||
)
|
||||
async def query_burst_detection_scheme_detail(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Path(..., description="爆管检测方案名称"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取爆管检测方案详情。
|
||||
|
||||
查询指定爆管检测方案的完整配置和参数信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
scheme_name: 爆管检测方案名称
|
||||
|
||||
Returns:
|
||||
包含方案详情的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return get_burst_detection_scheme_detail(network=network, scheme_name=scheme_name)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
@@ -0,0 +1,129 @@
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_username
|
||||
from app.services.burst_location import (
|
||||
get_burst_location_scheme_detail,
|
||||
list_burst_location_schemes,
|
||||
run_burst_location_by_network,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class BurstLocationRequest(BaseModel):
|
||||
"""爆管定位请求模型"""
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
data_source: Literal["monitoring", "simulation"] = Field("monitoring", description="数据来源:monitoring(监测)或simulation(模拟)")
|
||||
pressure_scada_ids: list[str] | None = Field(None, description="压力SCADA传感器ID列表")
|
||||
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="爆管时的压力数据")
|
||||
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="正常时的压力数据")
|
||||
burst_leakage: float = Field(..., description="爆管时的漏水量")
|
||||
flow_scada_ids: list[str] | None = Field(None, description="流量SCADA传感器ID列表")
|
||||
burst_flow: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="爆管时的流量数据")
|
||||
normal_flow: dict[str, float] | list[dict[str, Any]] | None = Field(None, description="正常时的流量数据")
|
||||
min_dpressure: float = Field(2.0, description="最小压力差(bar)")
|
||||
basic_pressure: float = Field(10.0, description="基准压力(bar)")
|
||||
scada_burst_start: datetime | None = Field(None, description="SCADA爆管开始时间")
|
||||
scada_burst_end: datetime | None = Field(None, description="SCADA爆管结束时间")
|
||||
use_scada_flow: bool = Field(False, description="是否使用SCADA流量数据")
|
||||
scheme_name: str | None = Field(None, description="方案名称")
|
||||
simulation_scheme_name: str | None = Field(None, description="模拟方案名称")
|
||||
simulation_scheme_type: str | None = Field(None, description="模拟方案类型")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/locate/",
|
||||
summary="执行爆管定位",
|
||||
description="基于压力和流量数据定位管网中的爆管位置"
|
||||
)
|
||||
async def locate_burst(
|
||||
data: BurstLocationRequest = Body(..., description="爆管定位请求数据"),
|
||||
username: str = Depends(get_current_keycloak_username),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
执行爆管定位分析。
|
||||
|
||||
使用压力和流量SCADA数据,通过对比爆管和正常状态下的数据差异,
|
||||
定位管网中的爆管位置。
|
||||
|
||||
Args:
|
||||
data: 包含管网名称(或数据库名称)、压力、流量数据及相关参数的请求体
|
||||
username: 当前认证用户名
|
||||
|
||||
Returns:
|
||||
包含定位结果的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当数据类型或值不正确时
|
||||
"""
|
||||
try:
|
||||
return run_burst_location_by_network(**data.model_dump(), username=username)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/",
|
||||
summary="查询爆管定位方案列表",
|
||||
description="获取指定网络的所有爆管定位方案"
|
||||
)
|
||||
async def query_burst_schemes(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
query_date: datetime | None = Query(None, description="查询日期(可选)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取爆管定位方案列表。
|
||||
|
||||
查询指定网络的所有已配置的爆管定位方案,
|
||||
可按日期进行筛选。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
query_date: 查询日期(可选)
|
||||
|
||||
Returns:
|
||||
爆管定位方案列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return list_burst_location_schemes(network=network, query_date=query_date)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/{scheme_name}",
|
||||
summary="获取爆管定位方案详情",
|
||||
description="获取指定爆管定位方案的详细信息"
|
||||
)
|
||||
async def query_burst_scheme_detail(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Path(..., description="爆管定位方案名称")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取爆管定位方案详情。
|
||||
|
||||
查询指定爆管定位方案的完整配置和参数信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
scheme_name: 爆管定位方案名称
|
||||
|
||||
Returns:
|
||||
包含方案详情的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return get_burst_location_scheme_detail(network=network, scheme_name=scheme_name)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
@@ -1,16 +1,26 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Query
|
||||
from app.infra.cache.redis_client import redis_client
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clearrediskey/")
|
||||
async def fastapi_clear_redis_key(key: str):
|
||||
@router.post("/clearrediskey/", summary="清除单个缓存键", description="根据键名清除单个Redis缓存")
|
||||
async def fastapi_clear_redis_key(key: str = Query(..., description="缓存键名")):
|
||||
"""
|
||||
清除单个缓存键
|
||||
|
||||
根据指定的键名删除Redis中对应的缓存
|
||||
"""
|
||||
redis_client.delete(key)
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/clearrediskeys/")
|
||||
async def fastapi_clear_redis_keys(keys: str):
|
||||
@router.post("/clearrediskeys/", summary="清除匹配的缓存键", description="根据模式清除匹配的Redis缓存键")
|
||||
async def fastapi_clear_redis_keys(keys: str = Query(..., description="缓存键模式(支持通配符)")):
|
||||
"""
|
||||
清除匹配的缓存键
|
||||
|
||||
根据指定的模式删除Redis中所有匹配的缓存键
|
||||
"""
|
||||
# delete keys contains the key
|
||||
matched_keys = redis_client.keys(f"*{keys}*")
|
||||
if matched_keys:
|
||||
@@ -19,14 +29,24 @@ async def fastapi_clear_redis_keys(keys: str):
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/clearallredis/")
|
||||
@router.post("/clearallredis/", summary="清除所有缓存", description="清空整个Redis数据库的所有缓存")
|
||||
async def fastapi_clear_all_redis():
|
||||
"""
|
||||
清除所有缓存
|
||||
|
||||
清空Redis数据库中的所有缓存键值对
|
||||
"""
|
||||
redis_client.flushdb()
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/queryredis/")
|
||||
@router.get("/queryredis/", summary="查询缓存键列表", description="获取Redis中所有的缓存键")
|
||||
async def fastapi_query_redis():
|
||||
"""
|
||||
查询缓存键列表
|
||||
|
||||
获取Redis数据库中所有的缓存键列表
|
||||
"""
|
||||
# Helper to decode bytes to str for JSON response if needed,
|
||||
# but original just returned keys (which might be bytes in redis-py unless decode_responses=True)
|
||||
# create_redis_client usually sets decode_responses=False by default.
|
||||
|
||||
@@ -1,31 +1,70 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
get_control,
|
||||
get_control_schema,
|
||||
get_rule,
|
||||
get_rule_schema,
|
||||
set_control,
|
||||
set_rule,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcontrolschema/")
|
||||
async def fastapi_get_control_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getcontrolschema/", summary="获取控制架构", description="获取网络中控制对象的架构定义")
|
||||
async def fastapi_get_control_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取控制架构。
|
||||
|
||||
返回指定网络中控制对象的属性架构定义。
|
||||
"""
|
||||
return get_control_schema(network)
|
||||
|
||||
@router.get("/getcontrolproperties/")
|
||||
async def fastapi_get_control_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/getcontrolproperties/", summary="获取控制属性", description="获取指定网络中的控制属性信息")
|
||||
async def fastapi_get_control_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取控制属性。
|
||||
|
||||
返回指定网络中的控制对象属性信息。
|
||||
"""
|
||||
return get_control(network)
|
||||
|
||||
@router.post("/setcontrolproperties/", response_model=None)
|
||||
async def fastapi_set_control_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setcontrolproperties/", response_model=None, summary="设置控制属性", description="更新指定网络中的控制属性")
|
||||
async def fastapi_set_control_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置控制属性。
|
||||
|
||||
更新指定网络中的控制属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_control(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getruleschema/")
|
||||
async def fastapi_get_rule_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getruleschema/", summary="获取规则架构", description="获取网络中规则对象的架构定义")
|
||||
async def fastapi_get_rule_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取规则架构。
|
||||
|
||||
返回指定网络中规则对象的属性架构定义。
|
||||
"""
|
||||
return get_rule_schema(network)
|
||||
|
||||
@router.get("/getruleproperties/")
|
||||
async def fastapi_get_rule_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/getruleproperties/", summary="获取规则属性", description="获取指定网络中的规则属性信息")
|
||||
async def fastapi_get_rule_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取规则属性。
|
||||
|
||||
返回指定网络中的规则对象属性信息。
|
||||
"""
|
||||
return get_rule(network)
|
||||
|
||||
@router.post("/setruleproperties/", response_model=None)
|
||||
async def fastapi_set_rule_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setruleproperties/", response_model=None, summary="设置规则属性", description="更新指定网络中的规则属性")
|
||||
async def fastapi_set_rule_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置规则属性。
|
||||
|
||||
更新指定网络中的规则属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_rule(network, ChangeSet(props))
|
||||
|
||||
@@ -1,42 +1,95 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_curve,
|
||||
delete_curve,
|
||||
get_curve,
|
||||
get_curve_schema,
|
||||
get_curves,
|
||||
is_curve,
|
||||
set_curve,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcurveschema")
|
||||
async def fastapi_get_curve_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getcurveschema", summary="获取曲线架构", description="获取网络中曲线对象的架构定义")
|
||||
async def fastapi_get_curve_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取曲线架构。
|
||||
|
||||
返回指定网络中曲线对象的属性架构定义。
|
||||
"""
|
||||
return get_curve_schema(network)
|
||||
|
||||
@router.post("/addcurve/", response_model=None)
|
||||
async def fastapi_add_curve(network: str, curve: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addcurve/", response_model=None, summary="添加曲线", description="在网络中添加一条新的曲线")
|
||||
async def fastapi_add_curve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
curve: str = Query(..., description="曲线ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加曲线。
|
||||
|
||||
在指定网络中创建一条新的曲线,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {
|
||||
"id": curve,
|
||||
} | props
|
||||
return add_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletecurve/", response_model=None)
|
||||
async def fastapi_delete_curve(network: str, curve: str) -> ChangeSet:
|
||||
@router.post("/deletecurve/", response_model=None, summary="删除曲线", description="从网络中删除指定的曲线")
|
||||
async def fastapi_delete_curve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
curve: str = Query(..., description="曲线ID")
|
||||
) -> ChangeSet:
|
||||
"""删除曲线。
|
||||
|
||||
从指定网络中删除指定的曲线及其相关数据。
|
||||
"""
|
||||
ps = {"id": curve}
|
||||
return delete_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getcurveproperties/")
|
||||
async def fastapi_get_curve_properties(network: str, curve: str) -> dict[str, Any]:
|
||||
@router.get("/getcurveproperties/", summary="获取曲线属性", description="获取指定曲线的属性信息")
|
||||
async def fastapi_get_curve_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
curve: str = Query(..., description="曲线ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取曲线属性。
|
||||
|
||||
返回指定曲线的所有属性信息。
|
||||
"""
|
||||
return get_curve(network, curve)
|
||||
|
||||
@router.post("/setcurveproperties/", response_model=None)
|
||||
@router.post("/setcurveproperties/", response_model=None, summary="设置曲线属性", description="更新指定曲线的属性")
|
||||
async def fastapi_set_curve_properties(
|
||||
network: str, curve: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
curve: str = Query(..., description="曲线ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置曲线属性。
|
||||
|
||||
更新指定曲线的属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": curve} | props
|
||||
return set_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getcurves/")
|
||||
async def fastapi_get_curves(network: str) -> list[str]:
|
||||
@router.get("/getcurves/", summary="获取所有曲线", description="获取网络中的所有曲线列表")
|
||||
async def fastapi_get_curves(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
|
||||
"""获取所有曲线。
|
||||
|
||||
返回指定网络中的所有曲线ID列表。
|
||||
"""
|
||||
return get_curves(network)
|
||||
|
||||
@router.get("/iscurve/")
|
||||
async def fastapi_is_curve(network: str, curve: str) -> bool:
|
||||
@router.get("/iscurve/", summary="检查曲线存在性", description="检查指定的曲线是否存在")
|
||||
async def fastapi_is_curve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
curve: str = Query(..., description="曲线ID")
|
||||
) -> bool:
|
||||
"""检查曲线是否存在。
|
||||
|
||||
判断指定的曲线是否在网络中存在。
|
||||
"""
|
||||
return is_curve(network, curve)
|
||||
|
||||
@@ -1,60 +1,137 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
get_energy,
|
||||
get_energy_schema,
|
||||
get_option_v3,
|
||||
get_option_v3_schema,
|
||||
get_pump_energy,
|
||||
get_pump_energy_schema,
|
||||
get_time,
|
||||
get_time_schema,
|
||||
set_energy,
|
||||
set_option_v3,
|
||||
set_pump_energy,
|
||||
set_time,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/gettimeschema")
|
||||
async def fastapi_get_time_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/gettimeschema", summary="获取时间选项架构", description="获取网络中时间选项的架构定义")
|
||||
async def fastapi_get_time_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取时间选项架构。
|
||||
|
||||
返回指定网络中时间相关选项的属性架构定义。
|
||||
"""
|
||||
return get_time_schema(network)
|
||||
|
||||
@router.get("/gettimeproperties/")
|
||||
async def fastapi_get_time_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/gettimeproperties/", summary="获取时间选项属性", description="获取指定网络中的时间选项属性信息")
|
||||
async def fastapi_get_time_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取时间选项属性。
|
||||
|
||||
返回指定网络中的时间相关选项属性。
|
||||
"""
|
||||
return get_time(network)
|
||||
|
||||
@router.post("/settimeproperties/", response_model=None)
|
||||
async def fastapi_set_time_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/settimeproperties/", response_model=None, summary="设置时间选项属性", description="更新指定网络中的时间选项属性")
|
||||
async def fastapi_set_time_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置时间选项属性。
|
||||
|
||||
更新指定网络中的时间相关选项属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_time(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getenergyschema/")
|
||||
async def fastapi_get_energy_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getenergyschema/", summary="获取能耗选项架构", description="获取网络中能耗选项的架构定义")
|
||||
async def fastapi_get_energy_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取能耗选项架构。
|
||||
|
||||
返回指定网络中能耗相关选项的属性架构定义。
|
||||
"""
|
||||
return get_energy_schema(network)
|
||||
|
||||
@router.get("/getenergyproperties/")
|
||||
async def fastapi_get_energy_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/getenergyproperties/", summary="获取能耗选项属性", description="获取指定网络中的能耗选项属性信息")
|
||||
async def fastapi_get_energy_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取能耗选项属性。
|
||||
|
||||
返回指定网络中的能耗相关选项属性。
|
||||
"""
|
||||
return get_energy(network)
|
||||
|
||||
@router.post("/setenergyproperties/", response_model=None)
|
||||
async def fastapi_set_energy_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setenergyproperties/", response_model=None, summary="设置能耗选项属性", description="更新指定网络中的能耗选项属性")
|
||||
async def fastapi_set_energy_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置能耗选项属性。
|
||||
|
||||
更新指定网络中的能耗相关选项属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_energy(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getpumpenergyschema/")
|
||||
async def fastapi_get_pump_energy_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getpumpenergyschema/", summary="获取泵能耗选项架构", description="获取网络中泵能耗选项的架构定义")
|
||||
async def fastapi_get_pump_energy_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取泵能耗选项架构。
|
||||
|
||||
返回指定网络中泵能耗相关选项的属性架构定义。
|
||||
"""
|
||||
return get_pump_energy_schema(network)
|
||||
|
||||
@router.get("/getpumpenergyproperties//")
|
||||
async def fastapi_get_pump_energy_proeprties(network: str, pump: str) -> dict[str, Any]:
|
||||
@router.get("/getpumpenergyproperties//", summary="获取泵能耗属性", description="获取指定泵的能耗属性信息")
|
||||
async def fastapi_get_pump_energy_proeprties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="泵ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取泵能耗属性。
|
||||
|
||||
返回指定泵的能耗相关属性。
|
||||
"""
|
||||
return get_pump_energy(network, pump)
|
||||
|
||||
@router.get("/setpumpenergyproperties//", response_model=None)
|
||||
@router.get("/setpumpenergyproperties//", response_model=None, summary="设置泵能耗属性", description="更新指定泵的能耗属性")
|
||||
async def fastapi_set_pump_energy_properties(
|
||||
network: str, pump: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="泵ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置泵能耗属性。
|
||||
|
||||
更新指定泵的能耗相关属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": pump} | props
|
||||
return set_pump_energy(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getoptionschema/")
|
||||
async def fastapi_get_option_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getoptionschema/", summary="获取选项架构", description="获取网络中选项对象的架构定义")
|
||||
async def fastapi_get_option_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取选项架构。
|
||||
|
||||
返回指定网络中选项对象的属性架构定义。
|
||||
"""
|
||||
return get_option_v3_schema(network)
|
||||
|
||||
@router.get("/getoptionproperties/")
|
||||
async def fastapi_get_option_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/getoptionproperties/", summary="获取选项属性", description="获取指定网络中的选项属性信息")
|
||||
async def fastapi_get_option_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取选项属性。
|
||||
|
||||
返回指定网络中的选项对象属性信息。
|
||||
"""
|
||||
return get_option_v3(network)
|
||||
|
||||
@router.post("/setoptionproperties/", response_model=None)
|
||||
async def fastapi_set_option_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setoptionproperties/", response_model=None, summary="设置选项属性", description="更新指定网络中的选项属性")
|
||||
async def fastapi_set_option_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置选项属性。
|
||||
|
||||
更新指定网络中的选项属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_option_v3(network, ChangeSet(props))
|
||||
|
||||
@@ -1,42 +1,95 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_pattern,
|
||||
delete_pattern,
|
||||
get_pattern,
|
||||
get_pattern_schema,
|
||||
get_patterns,
|
||||
is_pattern,
|
||||
set_pattern,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpatternschema")
|
||||
async def fastapi_get_pattern_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getpatternschema", summary="获取模式架构", description="获取网络中模式对象的架构定义")
|
||||
async def fastapi_get_pattern_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取模式架构。
|
||||
|
||||
返回指定网络中模式对象的属性架构定义。
|
||||
"""
|
||||
return get_pattern_schema(network)
|
||||
|
||||
@router.post("/addpattern/", response_model=None)
|
||||
async def fastapi_add_pattern(network: str, pattern: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addpattern/", response_model=None, summary="添加模式", description="在网络中添加一个新的模式")
|
||||
async def fastapi_add_pattern(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pattern: str = Query(..., description="模式ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加模式。
|
||||
|
||||
在指定网络中创建一个新的模式,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {
|
||||
"id": pattern,
|
||||
} | props
|
||||
return add_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepattern/", response_model=None)
|
||||
async def fastapi_delete_pattern(network: str, pattern: str) -> ChangeSet:
|
||||
@router.post("/deletepattern/", response_model=None, summary="删除模式", description="从网络中删除指定的模式")
|
||||
async def fastapi_delete_pattern(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pattern: str = Query(..., description="模式ID")
|
||||
) -> ChangeSet:
|
||||
"""删除模式。
|
||||
|
||||
从指定网络中删除指定的模式及其相关数据。
|
||||
"""
|
||||
ps = {"id": pattern}
|
||||
return delete_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpatternproperties/")
|
||||
async def fastapi_get_pattern_properties(network: str, pattern: str) -> dict[str, Any]:
|
||||
@router.get("/getpatternproperties/", summary="获取模式属性", description="获取指定模式的属性信息")
|
||||
async def fastapi_get_pattern_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pattern: str = Query(..., description="模式ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取模式属性。
|
||||
|
||||
返回指定模式的所有属性信息。
|
||||
"""
|
||||
return get_pattern(network, pattern)
|
||||
|
||||
@router.post("/setpatternproperties/", response_model=None)
|
||||
@router.post("/setpatternproperties/", response_model=None, summary="设置模式属性", description="更新指定模式的属性")
|
||||
async def fastapi_set_pattern_properties(
|
||||
network: str, pattern: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pattern: str = Query(..., description="模式ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置模式属性。
|
||||
|
||||
更新指定模式的属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": pattern} | props
|
||||
return set_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/ispattern/")
|
||||
async def fastapi_is_pattern(network: str, pattern: str) -> bool:
|
||||
@router.get("/ispattern/", summary="检查模式存在性", description="检查指定的模式是否存在")
|
||||
async def fastapi_is_pattern(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pattern: str = Query(..., description="模式ID")
|
||||
) -> bool:
|
||||
"""检查模式是否存在。
|
||||
|
||||
判断指定的模式是否在网络中存在。
|
||||
"""
|
||||
return is_pattern(network, pattern)
|
||||
|
||||
@router.get("/getpatterns/")
|
||||
async def fastapi_get_patterns(network: str) -> list[str]:
|
||||
@router.get("/getpatterns/", summary="获取所有模式", description="获取网络中的所有模式列表")
|
||||
async def fastapi_get_patterns(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
|
||||
"""获取所有模式。
|
||||
|
||||
返回指定网络中的所有模式ID列表。
|
||||
"""
|
||||
return get_patterns(network)
|
||||
|
||||
@@ -1,119 +1,297 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_mixing,
|
||||
add_source,
|
||||
api,
|
||||
delete_mixing,
|
||||
delete_source,
|
||||
get_emitter,
|
||||
get_emitter_schema,
|
||||
get_mixing,
|
||||
get_mixing_schema,
|
||||
get_pipe_reaction,
|
||||
get_pipe_reaction_schema,
|
||||
get_quality,
|
||||
get_quality_schema,
|
||||
get_reaction,
|
||||
get_reaction_schema,
|
||||
get_source,
|
||||
get_source_schema,
|
||||
get_tank_reaction,
|
||||
get_tank_reaction_schema,
|
||||
set_emitter,
|
||||
set_pipe_reaction,
|
||||
set_quality,
|
||||
set_reaction,
|
||||
set_source,
|
||||
set_tank_reaction,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getqualityschema/")
|
||||
async def fastapi_get_quality_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getqualityschema/", summary="获取水质架构", description="获取网络中水质对象的架构定义")
|
||||
async def fastapi_get_quality_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取水质架构。
|
||||
|
||||
返回指定网络中水质对象的属性架构定义。
|
||||
"""
|
||||
return get_quality_schema(network)
|
||||
|
||||
@router.get("/getqualityproperties/")
|
||||
async def fastapi_get_quality_properties(network: str, node: str) -> dict[str, Any]:
|
||||
@router.get("/getqualityproperties/", summary="获取水质属性", description="获取指定节点的水质属性信息")
|
||||
async def fastapi_get_quality_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取水质属性。
|
||||
|
||||
返回指定节点的水质属性信息。
|
||||
"""
|
||||
return get_quality(network, node)
|
||||
|
||||
@router.post("/setqualityproperties/", response_model=None)
|
||||
async def fastapi_set_quality_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setqualityproperties/", response_model=None, summary="设置水质属性", description="更新指定节点的水质属性")
|
||||
async def fastapi_set_quality_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置水质属性。
|
||||
|
||||
更新指定节点的水质属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_quality(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getemitterschema")
|
||||
async def fastapi_get_emitter_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getemitterschema", summary="获取发射器架构", description="获取网络中发射器对象的架构定义")
|
||||
async def fastapi_get_emitter_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取发射器架构。
|
||||
|
||||
返回指定网络中发射器对象的属性架构定义。
|
||||
"""
|
||||
return get_emitter_schema(network)
|
||||
|
||||
@router.get("/getemitterproperties/")
|
||||
async def fastapi_get_emitter_properties(network: str, junction: str) -> dict[str, Any]:
|
||||
@router.get("/getemitterproperties/", summary="获取发射器属性", description="获取指定连接点的发射器属性信息")
|
||||
async def fastapi_get_emitter_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="连接点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取发射器属性。
|
||||
|
||||
返回指定连接点的发射器属性信息。
|
||||
"""
|
||||
return get_emitter(network, junction)
|
||||
|
||||
@router.post("/setemitterproperties/", response_model=None)
|
||||
@router.post("/setemitterproperties/", response_model=None, summary="设置发射器属性", description="更新指定连接点的发射器属性")
|
||||
async def fastapi_set_emitter_properties(
|
||||
network: str, junction: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="连接点ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置发射器属性。
|
||||
|
||||
更新指定连接点的发射器属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"junction": junction} | props
|
||||
return set_emitter(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getsourcechema/")
|
||||
async def fastapi_get_source_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getsourcechema/", summary="获取水源架构", description="获取网络中水源对象的架构定义")
|
||||
async def fastapi_get_source_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取水源架构。
|
||||
|
||||
返回指定网络中水源对象的属性架构定义。
|
||||
"""
|
||||
return get_source_schema(network)
|
||||
|
||||
@router.get("/getsource/")
|
||||
async def fastapi_get_source(network: str, node: str) -> dict[str, Any]:
|
||||
@router.get("/getsource/", summary="获取水源属性", description="获取指定节点的水源属性信息")
|
||||
async def fastapi_get_source(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取水源属性。
|
||||
|
||||
返回指定节点的水源属性信息。
|
||||
"""
|
||||
return get_source(network, node)
|
||||
|
||||
@router.post("/setsource/", response_model=None)
|
||||
async def fastapi_set_source(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setsource/", response_model=None, summary="设置水源属性", description="更新指定节点的水源属性")
|
||||
async def fastapi_set_source(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置水源属性。
|
||||
|
||||
更新指定节点的水源属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_source(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addsource/", response_model=None)
|
||||
async def fastapi_add_source(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addsource/", response_model=None, summary="添加水源", description="在网络中添加一个新的水源")
|
||||
async def fastapi_add_source(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加水源。
|
||||
|
||||
在指定网络中创建一个新的水源,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_source(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletesource/", response_model=None)
|
||||
async def fastapi_delete_source(network: str, node: str) -> ChangeSet:
|
||||
@router.post("/deletesource/", response_model=None, summary="删除水源", description="从网络中删除指定节点的水源")
|
||||
async def fastapi_delete_source(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> ChangeSet:
|
||||
"""删除水源。
|
||||
|
||||
从指定网络中删除指定节点的水源。
|
||||
"""
|
||||
props = {"node": node}
|
||||
return delete_source(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getreactionschema/")
|
||||
async def fastapi_get_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getreactionschema/", summary="获取反应架构", description="获取网络中反应对象的架构定义")
|
||||
async def fastapi_get_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取反应架构。
|
||||
|
||||
返回指定网络中反应对象的属性架构定义。
|
||||
"""
|
||||
return get_reaction_schema(network)
|
||||
|
||||
@router.get("/getreaction/")
|
||||
async def fastapi_get_reaction(network: str) -> dict[str, Any]:
|
||||
@router.get("/getreaction/", summary="获取反应属性", description="获取指定网络中的反应属性信息")
|
||||
async def fastapi_get_reaction(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取反应属性。
|
||||
|
||||
返回指定网络中的反应属性信息。
|
||||
"""
|
||||
return get_reaction(network)
|
||||
|
||||
@router.post("/setreaction/", response_model=None)
|
||||
async def fastapi_set_reaction(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setreaction/", response_model=None, summary="设置反应属性", description="更新指定网络中的反应属性")
|
||||
async def fastapi_set_reaction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置反应属性。
|
||||
|
||||
更新指定网络中的反应属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getpipereactionschema/")
|
||||
async def fastapi_get_pipe_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getpipereactionschema/", summary="获取管道反应架构", description="获取网络中管道反应对象的架构定义")
|
||||
async def fastapi_get_pipe_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取管道反应架构。
|
||||
|
||||
返回指定网络中管道反应对象的属性架构定义。
|
||||
"""
|
||||
return get_pipe_reaction_schema(network)
|
||||
|
||||
@router.get("/getpipereaction/")
|
||||
async def fastapi_get_pipe_reaction(network: str, pipe: str) -> dict[str, Any]:
|
||||
@router.get("/getpipereaction/", summary="获取管道反应属性", description="获取指定管道的反应属性信息")
|
||||
async def fastapi_get_pipe_reaction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取管道反应属性。
|
||||
|
||||
返回指定管道的反应属性信息。
|
||||
"""
|
||||
return get_pipe_reaction(network, pipe)
|
||||
|
||||
@router.post("/setpipereaction/", response_model=None)
|
||||
async def fastapi_set_pipe_reaction(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setpipereaction/", response_model=None, summary="设置管道反应属性", description="更新指定管道的反应属性")
|
||||
async def fastapi_set_pipe_reaction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置管道反应属性。
|
||||
|
||||
更新指定管道的反应属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_pipe_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/gettankreactionschema/")
|
||||
async def fastapi_get_tank_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/gettankreactionschema/", summary="获取水池反应架构", description="获取网络中水池反应对象的架构定义")
|
||||
async def fastapi_get_tank_reaction_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取水池反应架构。
|
||||
|
||||
返回指定网络中水池反应对象的属性架构定义。
|
||||
"""
|
||||
return get_tank_reaction_schema(network)
|
||||
|
||||
@router.get("/gettankreaction/")
|
||||
async def fastapi_get_tank_reaction(network: str, tank: str) -> dict[str, Any]:
|
||||
@router.get("/gettankreaction/", summary="获取水池反应属性", description="获取指定水池的反应属性信息")
|
||||
async def fastapi_get_tank_reaction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水池ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取水池反应属性。
|
||||
|
||||
返回指定水池的反应属性信息。
|
||||
"""
|
||||
return get_tank_reaction(network, tank)
|
||||
|
||||
@router.post("/settankreaction/", response_model=None)
|
||||
async def fastapi_set_tank_reaction(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/settankreaction/", response_model=None, summary="设置水池反应属性", description="更新指定水池的反应属性")
|
||||
async def fastapi_set_tank_reaction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置水池反应属性。
|
||||
|
||||
更新指定水池的反应属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_tank_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getmixingschema/")
|
||||
async def fastapi_get_mixing_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getmixingschema/", summary="获取混合架构", description="获取网络中混合对象的架构定义")
|
||||
async def fastapi_get_mixing_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取混合架构。
|
||||
|
||||
返回指定网络中混合对象的属性架构定义。
|
||||
"""
|
||||
return get_mixing_schema(network)
|
||||
|
||||
@router.get("/getmixing/")
|
||||
async def fastapi_get_mixing(network: str, tank: str) -> dict[str, Any]:
|
||||
@router.get("/getmixing/", summary="获取混合属性", description="获取指定水池的混合属性信息")
|
||||
async def fastapi_get_mixing(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水池ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取混合属性。
|
||||
|
||||
返回指定水池的混合属性信息。
|
||||
"""
|
||||
return get_mixing(network, tank)
|
||||
|
||||
@router.post("/setmixing/", response_model=None)
|
||||
async def fastapi_set_mixing(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setmixing/", response_model=None, summary="设置混合属性", description="更新指定水池的混合属性")
|
||||
async def fastapi_set_mixing(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置混合属性。
|
||||
|
||||
更新指定水池的混合属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return api.set_mixing(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addmixing/", response_model=None)
|
||||
async def fastapi_add_mixing(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addmixing/", response_model=None, summary="添加混合", description="在网络中添加一个新的混合")
|
||||
async def fastapi_add_mixing(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加混合。
|
||||
|
||||
在指定网络中创建一个新的混合,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_mixing(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletemixing/", response_model=None)
|
||||
async def fastapi_delete_mixing(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletemixing/", response_model=None, summary="删除混合", description="从网络中删除指定的混合")
|
||||
async def fastapi_delete_mixing(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除混合。
|
||||
|
||||
从指定网络中删除指定的混合及其相关数据。
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_mixing(network, ChangeSet(props))
|
||||
|
||||
@@ -1,76 +1,180 @@
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi import APIRouter, Request, Query, Path, Body, Response
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_label,
|
||||
add_vertex,
|
||||
delete_label,
|
||||
delete_vertex,
|
||||
get_all_vertex_links,
|
||||
get_all_vertices,
|
||||
get_backdrop,
|
||||
get_backdrop_schema,
|
||||
get_label,
|
||||
get_label_schema,
|
||||
get_vertex,
|
||||
get_vertex_schema,
|
||||
set_backdrop,
|
||||
set_label,
|
||||
set_vertex,
|
||||
)
|
||||
from fastapi.responses import PlainTextResponse
|
||||
import json
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getvertexschema/")
|
||||
async def fastapi_get_vertex_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getvertexschema/", summary="获取图形元素架构", description="获取网络中图形元素对象的架构定义")
|
||||
async def fastapi_get_vertex_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取图形元素架构。
|
||||
|
||||
返回指定网络中图形元素对象的属性架构定义。
|
||||
"""
|
||||
return get_vertex_schema(network)
|
||||
|
||||
@router.get("/getvertexproperties/")
|
||||
async def fastapi_get_vertex_properties(network: str, link: str) -> dict[str, Any]:
|
||||
@router.get("/getvertexproperties/", summary="获取图形元素属性", description="获取指定图形元素的属性信息")
|
||||
async def fastapi_get_vertex_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="图形元素链接")
|
||||
) -> dict[str, Any]:
|
||||
"""获取图形元素属性。
|
||||
|
||||
返回指定图形元素的所有属性信息。
|
||||
"""
|
||||
return get_vertex(network, link)
|
||||
|
||||
@router.post("/setvertexproperties/", response_model=None)
|
||||
async def fastapi_set_vertex_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setvertexproperties/", response_model=None, summary="设置图形元素属性", description="更新指定图形元素的属性")
|
||||
async def fastapi_set_vertex_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置图形元素属性。
|
||||
|
||||
更新指定图形元素的属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addvertex/", response_model=None)
|
||||
async def fastapi_add_vertex(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addvertex/", response_model=None, summary="添加图形元素", description="在网络中添加一个新的图形元素")
|
||||
async def fastapi_add_vertex(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加图形元素。
|
||||
|
||||
在指定网络中创建一个新的图形元素,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletevertex/", response_model=None)
|
||||
async def fastapi_delete_vertex(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletevertex/", response_model=None, summary="删除图形元素", description="从网络中删除指定的图形元素")
|
||||
async def fastapi_delete_vertex(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除图形元素。
|
||||
|
||||
从指定网络中删除指定的图形元素及其相关数据。
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallvertexlinks/", response_class=PlainTextResponse)
|
||||
async def fastapi_get_all_vertex_links(network: str) -> list[str]:
|
||||
@router.get("/getallvertexlinks/", response_class=PlainTextResponse, summary="获取所有图形元素链接", description="获取网络中的所有图形元素链接列表")
|
||||
async def fastapi_get_all_vertex_links(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
|
||||
"""获取所有图形元素链接。
|
||||
|
||||
返回指定网络中的所有图形元素链接列表。
|
||||
"""
|
||||
return json.dumps(get_all_vertex_links(network))
|
||||
|
||||
@router.get("/getallvertices/", response_class=PlainTextResponse)
|
||||
async def fastapi_get_all_vertices(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallvertices/", response_class=PlainTextResponse, summary="获取所有图形元素", description="获取网络中的所有图形元素详细信息")
|
||||
async def fastapi_get_all_vertices(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[str, Any]]:
|
||||
"""获取所有图形元素。
|
||||
|
||||
返回指定网络中的所有图形元素详细信息。
|
||||
"""
|
||||
return json.dumps(get_all_vertices(network))
|
||||
|
||||
@router.get("/getlabelschema/")
|
||||
async def fastapi_get_label_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getlabelschema/", summary="获取标签架构", description="获取网络中标签对象的架构定义")
|
||||
async def fastapi_get_label_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取标签架构。
|
||||
|
||||
返回指定网络中标签对象的属性架构定义。
|
||||
"""
|
||||
return get_label_schema(network)
|
||||
|
||||
@router.get("/getlabelproperties/")
|
||||
@router.get("/getlabelproperties/", summary="获取标签属性", description="获取指定坐标处的标签属性信息")
|
||||
async def fastapi_get_label_properties(
|
||||
network: str, x: float, y: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
x: float = Query(..., description="X坐标"),
|
||||
y: float = Query(..., description="Y坐标")
|
||||
) -> dict[str, Any]:
|
||||
"""获取标签属性。
|
||||
|
||||
返回指定坐标处的标签属性信息。
|
||||
"""
|
||||
return get_label(network, x, y)
|
||||
|
||||
@router.post("/setlabelproperties/", response_model=None)
|
||||
async def fastapi_set_label_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setlabelproperties/", response_model=None, summary="设置标签属性", description="更新指定标签的属性")
|
||||
async def fastapi_set_label_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置标签属性。
|
||||
|
||||
更新指定标签的属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_label(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addlabel/", response_model=None)
|
||||
async def fastapi_add_label(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addlabel/", response_model=None, summary="添加标签", description="在网络中添加一个新的标签")
|
||||
async def fastapi_add_label(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加标签。
|
||||
|
||||
在指定网络中创建一个新的标签,并设置其初始属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_label(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletelabel/", response_model=None)
|
||||
async def fastapi_delete_label(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletelabel/", response_model=None, summary="删除标签", description="从网络中删除指定的标签")
|
||||
async def fastapi_delete_label(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除标签。
|
||||
|
||||
从指定网络中删除指定的标签及其相关数据。
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_label(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getbackdropschema/")
|
||||
async def fastapi_get_backdrop_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getbackdropschema/", summary="获取背景架构", description="获取网络中背景对象的架构定义")
|
||||
async def fastapi_get_backdrop_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""获取背景架构。
|
||||
|
||||
返回指定网络中背景对象的属性架构定义。
|
||||
"""
|
||||
return get_backdrop_schema(network)
|
||||
|
||||
@router.get("/getbackdropproperties/")
|
||||
async def fastapi_get_backdrop_properties(network: str) -> dict[str, Any]:
|
||||
@router.get("/getbackdropproperties/", summary="获取背景属性", description="获取指定网络的背景属性信息")
|
||||
async def fastapi_get_backdrop_properties(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取背景属性。
|
||||
|
||||
返回指定网络的背景属性信息。
|
||||
"""
|
||||
return get_backdrop(network)
|
||||
|
||||
@router.post("/setbackdropproperties/", response_model=None)
|
||||
async def fastapi_set_backdrop_properties(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setbackdropproperties/", response_model=None, summary="设置背景属性", description="更新指定网络的背景属性")
|
||||
async def fastapi_set_backdrop_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置背景属性。
|
||||
|
||||
更新指定网络的背景属性值。
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_backdrop(network, ChangeSet(props))
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
from typing import Any, List, Dict, Optional
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone, time as dt_time
|
||||
import msgpack
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from py_linq import Enumerable
|
||||
|
||||
import app.infra.db.influxdb.api as influxdb_api
|
||||
import app.services.time_api as time_api
|
||||
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Basic Node/Link Latest Record Queries
|
||||
|
||||
@router.get("/querynodelatestrecordbyid/")
|
||||
async def fastapi_query_node_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="node")
|
||||
|
||||
@router.get("/querylinklatestrecordbyid/")
|
||||
async def fastapi_query_link_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="link")
|
||||
|
||||
@router.get("/queryscadalatestrecordbyid/")
|
||||
async def fastapi_query_scada_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="scada")
|
||||
|
||||
# Time-based Queries
|
||||
|
||||
@router.get("/queryallrecordsbytime/")
|
||||
async def fastapi_query_all_records_by_time(querytime: str) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_all_records_by_time(query_time=querytime)
|
||||
return {"nodes": results[0], "links": results[1]}
|
||||
|
||||
@router.get("/queryallrecordsbytimeproperty/")
|
||||
async def fastapi_query_all_record_by_time_property(
|
||||
querytime: str, type: str, property: str, bucket: str = "realtime_simulation_result"
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_all_record_by_time_property(
|
||||
query_time=querytime, type=type, property=property, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/queryallschemerecordsbytimeproperty/")
|
||||
async def fastapi_query_all_scheme_record_by_time_property(
|
||||
querytime: str,
|
||||
type: str,
|
||||
property: str,
|
||||
schemename: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> dict[str, list]:
|
||||
"""
|
||||
查询指定方案某一时刻的所有记录,查询 'node' 或 'link' 的某一属性值
|
||||
"""
|
||||
results: list = influxdb_api.query_all_scheme_record_by_time_property(
|
||||
query_time=querytime,
|
||||
type=type,
|
||||
property=property,
|
||||
scheme_name=schemename,
|
||||
bucket=bucket,
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/querysimulationrecordsbyidtime/")
|
||||
async def fastapi_query_simulation_record_by_ids_time(
|
||||
id: str, querytime: str, type: str, bucket: str = "realtime_simulation_result"
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_simulation_result_by_ID_time(
|
||||
ID=id, type=type, query_time=querytime, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/queryschemesimulationrecordsbyidtime/")
|
||||
async def fastapi_query_scheme_simulation_record_by_ids_time(
|
||||
scheme_name: str,
|
||||
id: str,
|
||||
querytime: str,
|
||||
type: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_scheme_simulation_result_by_ID_time(
|
||||
scheme_name=scheme_name, ID=id, type=type, query_time=querytime, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
# Date-based Queries with Caching
|
||||
|
||||
@router.get("/queryallrecordsbydate/")
|
||||
async def fastapi_query_all_records_by_date(querydate: str) -> dict:
|
||||
is_today_or_future = time_api.is_today_or_future(querydate)
|
||||
logger.info(f"isToday or future: {is_today_or_future}")
|
||||
|
||||
cache_key = f"queryallrecordsbydate_{querydate}"
|
||||
|
||||
if not is_today_or_future:
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
logger.info("return from cache redis")
|
||||
return results
|
||||
|
||||
logger.info("query from influxdb")
|
||||
nodes_links: tuple = influxdb_api.query_all_records_by_date(query_date=querydate)
|
||||
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
|
||||
|
||||
if not is_today_or_future:
|
||||
logger.info("save to cache redis")
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
|
||||
logger.info("return results")
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbytimerange/")
|
||||
async def fastapi_query_all_records_by_time_range(
|
||||
starttime: str, endtime: str
|
||||
) -> dict[str, list]:
|
||||
cache_key = f"queryallrecordsbytimerange_{starttime}_{endtime}"
|
||||
|
||||
if not time_api.is_today_or_future(starttime):
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
nodes_links: tuple = influxdb_api.query_all_records_by_time_range(
|
||||
starttime=starttime, endtime=endtime
|
||||
)
|
||||
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
|
||||
|
||||
if not time_api.is_today_or_future(starttime):
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbydatewithtype/")
|
||||
async def fastapi_query_all_records_by_date_with_type(
|
||||
querydate: str, querytype: str
|
||||
) -> list:
|
||||
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_all_records_by_date_with_type(
|
||||
query_date=querydate, query_type=querytype
|
||||
)
|
||||
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbyidsdatetype/")
|
||||
async def fastapi_query_all_records_by_ids_date_type(
|
||||
ids: str, querydate: str, querytype: str
|
||||
) -> list:
|
||||
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
|
||||
data = redis_client.get(cache_key)
|
||||
|
||||
results = []
|
||||
if data:
|
||||
results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
results = influxdb_api.query_all_records_by_date_with_type(
|
||||
query_date=querydate, query_type=querytype
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
query_ids = ids.split(",")
|
||||
# Using Enumerable from py_linq as in original code
|
||||
e_results = Enumerable(results)
|
||||
lst_results = e_results.where(lambda x: x["ID"] in query_ids).to_list()
|
||||
|
||||
return lst_results
|
||||
|
||||
@router.get("/queryallrecordsbydateproperty/")
|
||||
async def fastapi_query_all_records_by_date_property(
|
||||
querydate: str, querytype: str, property: str
|
||||
) -> list[dict]:
|
||||
cache_key = f"queryallrecordsbydateproperty_{querydate}_{querytype}_{property}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
result_dict = influxdb_api.query_all_record_by_date_property(
|
||||
query_date=querydate, type=querytype, property=property
|
||||
)
|
||||
packed = msgpack.packb(result_dict, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return result_dict
|
||||
|
||||
# Curve Queries
|
||||
|
||||
@router.get("/querynodecurvebyidpropertydaterange/")
|
||||
async def fastapi_query_node_curve_by_id_property_daterange(
|
||||
id: str, prop: str, startdate: str, enddate: str
|
||||
):
|
||||
return influxdb_api.query_curve_by_ID_property_daterange(
|
||||
id, type="node", property=prop, start_date=startdate, end_date=enddate
|
||||
)
|
||||
|
||||
@router.get("/querylinkcurvebyidpropertydaterange/")
|
||||
async def fastapi_query_link_curve_by_id_property_daterange(
|
||||
id: str, prop: str, startdate: str, enddate: str
|
||||
):
|
||||
return influxdb_api.query_curve_by_ID_property_daterange(
|
||||
id, type="link", property=prop, start_date=startdate, end_date=enddate
|
||||
)
|
||||
|
||||
# SCADA Data Queries
|
||||
|
||||
@router.get("/queryscadadatabydeviceidandtime/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_time(ids: str, querytime: str):
|
||||
query_ids = ids.split(",")
|
||||
logger.info(querytime)
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_time(
|
||||
query_ids_list=query_ids, query_time=querytime
|
||||
)
|
||||
|
||||
@router.get("/queryscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/queryfillingscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_filling_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_filling_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querycleaningscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_cleaning_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_cleaning_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querysimulationscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_simulation_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_simulation_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querycleanedscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_cleaned_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_cleaned_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/queryscadadatabydeviceidanddate/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_date(ids: str, querydate: str):
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_date(
|
||||
query_ids_list=query_ids, query_date=querydate
|
||||
)
|
||||
|
||||
@router.get("/queryallscadarecordsbydate/")
|
||||
async def fastapi_query_all_scada_records_by_date(querydate: str):
|
||||
is_today_or_future = time_api.is_today_or_future(querydate)
|
||||
logger.info(f"isToday or future: {is_today_or_future}")
|
||||
|
||||
cache_key = f"queryallscadarecordsbydate_{querydate}"
|
||||
|
||||
if not is_today_or_future:
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
logger.info("return from cache redis")
|
||||
return loaded_dict
|
||||
|
||||
logger.info("query from influxdb")
|
||||
result_dict = influxdb_api.query_all_SCADA_records_by_date(query_date=querydate)
|
||||
|
||||
if not is_today_or_future:
|
||||
logger.info("save to cache redis")
|
||||
packed = msgpack.packb(result_dict, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
logger.info("return results")
|
||||
return result_dict
|
||||
|
||||
@router.get("/queryallschemeallrecords/")
|
||||
async def fastapi_query_all_scheme_all_records(
|
||||
schemetype: str, schemename: str, querydate: str
|
||||
) -> tuple:
|
||||
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_scheme_all_record(
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryschemeallrecordsproperty/")
|
||||
async def fastapi_query_all_scheme_all_records_property(
|
||||
schemetype: str, schemename: str, querydate: str, querytype: str, queryproperty: str
|
||||
) -> Optional[List]:
|
||||
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
|
||||
data = redis_client.get(cache_key)
|
||||
all_results = None
|
||||
if data:
|
||||
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
all_results = influxdb_api.query_scheme_all_record(
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(all_results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
results = None
|
||||
if querytype == "node":
|
||||
results = all_results[0]
|
||||
elif querytype == "link":
|
||||
results = all_results[1]
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryinfluxdbbuckets/")
|
||||
async def fastapi_query_influxdb_buckets():
|
||||
return influxdb_api.query_buckets()
|
||||
|
||||
@router.get("/queryinfluxdbbucketmeasurements/")
|
||||
async def fastapi_query_influxdb_bucket_measurements(bucket: str):
|
||||
return influxdb_api.query_measurements(bucket=bucket)
|
||||
|
||||
############################################################
|
||||
# download history data
|
||||
############################################################
|
||||
|
||||
class Download_History_Data_Manually(BaseModel):
|
||||
"""
|
||||
download_date:样式如 datetime(2025, 5, 4)
|
||||
"""
|
||||
|
||||
download_date: datetime
|
||||
|
||||
|
||||
@router.post("/download_history_data_manually/")
|
||||
async def fastapi_download_history_data_manually(
|
||||
data: Download_History_Data_Manually,
|
||||
) -> None:
|
||||
item = data.dict()
|
||||
tz = timezone(timedelta(hours=8))
|
||||
begin_dt = datetime.combine(item.get("download_date").date(), dt_time.min).replace(
|
||||
tzinfo=tz
|
||||
)
|
||||
end_dt = datetime.combine(item.get("download_date").date(), dt_time(23, 59, 59)).replace(
|
||||
tzinfo=tz
|
||||
)
|
||||
|
||||
begin_time = begin_dt.isoformat()
|
||||
end_time = end_dt.isoformat()
|
||||
|
||||
influxdb_api.download_history_data_manually(
|
||||
begin_time=begin_time, end_time=end_time
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import List, Any
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from app.native.api import ChangeSet
|
||||
from fastapi import APIRouter, Request, HTTPException, Query, Body
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
get_all_extension_data_keys,
|
||||
get_all_extension_data,
|
||||
get_extension_data,
|
||||
@@ -10,20 +10,93 @@ from app.services.tjnetwork import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getallextensiondatakeys/")
|
||||
async def get_all_extension_data_keys_endpoint(network: str) -> list[str]:
|
||||
@router.get(
|
||||
"/getallextensiondatakeys/",
|
||||
summary="获取所有扩展数据键",
|
||||
description="获取指定网络的所有扩展数据的键列表"
|
||||
)
|
||||
async def get_all_extension_data_keys_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[str]:
|
||||
"""
|
||||
获取所有扩展数据键。
|
||||
|
||||
返回指定网络中所有可用的扩展数据键。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
扩展数据键列表
|
||||
"""
|
||||
return get_all_extension_data_keys(network)
|
||||
|
||||
@router.get("/getallextensiondata/")
|
||||
async def get_all_extension_data_endpoint(network: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getallextensiondata/",
|
||||
summary="获取所有扩展数据",
|
||||
description="获取指定网络的所有扩展数据"
|
||||
)
|
||||
async def get_all_extension_data_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取所有扩展数据。
|
||||
|
||||
返回指定网络的所有扩展数据及其值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
扩展数据字典
|
||||
"""
|
||||
return get_all_extension_data(network)
|
||||
|
||||
@router.get("/getextensiondata/")
|
||||
async def get_extension_data_endpoint(network: str, key: str) -> str | None:
|
||||
@router.get(
|
||||
"/getextensiondata/",
|
||||
summary="获取指定扩展数据",
|
||||
description="获取指定网络中指定键的扩展数据值"
|
||||
)
|
||||
async def get_extension_data_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
key: str = Query(..., description="扩展数据键")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取指定扩展数据。
|
||||
|
||||
返回指定网络中指定键对应的扩展数据值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
key: 扩展数据键
|
||||
|
||||
Returns:
|
||||
扩展数据值,如果不存在返回None
|
||||
"""
|
||||
return get_extension_data(network, key)
|
||||
|
||||
@router.post("/setextensiondata/", response_model=None)
|
||||
async def set_extension_data_endpoint(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setextensiondata/",
|
||||
response_model=None,
|
||||
summary="设置扩展数据",
|
||||
description="设置指定网络中的扩展数据"
|
||||
)
|
||||
async def set_extension_data_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置扩展数据。
|
||||
|
||||
在指定网络中设置扩展数据,并返回变更集信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 包含扩展数据的请求体
|
||||
|
||||
Returns:
|
||||
变更集信息
|
||||
"""
|
||||
props = await req.json()
|
||||
print(props)
|
||||
cs = set_extension_data(network, ChangeSet(props))
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_username
|
||||
from app.services.leakage_identifier import (
|
||||
get_leakage_identify_scheme_detail,
|
||||
list_leakage_identify_schemes,
|
||||
run_leakage_identification,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
DEFAULT_N_WORKERS = max(1, min((os.cpu_count() or 1) - 1, 4))
|
||||
|
||||
|
||||
class LeakageIdentifyRequest(BaseModel):
|
||||
"""漏损识别请求模型"""
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
observed_pressure_data: str | dict[str, list[Any]] | list[dict[str, Any]] | None = Field(
|
||||
None, description="观测的压力数据"
|
||||
)
|
||||
start_time: float = Field(0, description="起始时间(小时)")
|
||||
duration: float = Field(24, description="持续时间(小时)")
|
||||
timestep: float = Field(5, description="时间步长(分钟)")
|
||||
q_sum: float = Field(0.2, description="总流量(m3/s)")
|
||||
q_sum_unit: str = Field("m3/s", description="流量单位")
|
||||
output_dir: str = Field("db_inp", description="输出目录")
|
||||
pop_size: int = Field(50, description="种群大小")
|
||||
max_gen: int = Field(100, description="最大代数")
|
||||
n_workers: int = Field(DEFAULT_N_WORKERS, description="工作线程数")
|
||||
output_flow_unit: str = Field("m3/s", description="输出流量单位")
|
||||
dma_count: int | None = Field(None, description="DMA区域数量")
|
||||
scada_start: datetime | None = Field(None, description="SCADA数据起始时间")
|
||||
scada_end: datetime | None = Field(None, description="SCADA数据结束时间")
|
||||
sensor_nodes: list[str] | None = Field(None, description="传感器节点列表")
|
||||
scheme_name: str | None = Field(None, description="方案名称")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/identify/",
|
||||
summary="执行漏损识别",
|
||||
description="基于压力观测数据和遗传算法识别管网中的漏损位置和大小"
|
||||
)
|
||||
async def identify_leakage(
|
||||
data: LeakageIdentifyRequest = Body(..., description="漏损识别请求数据"),
|
||||
username: str = Depends(get_current_keycloak_username),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
执行漏损识别分析。
|
||||
|
||||
使用遗传算法对比模型计算和实测压力数据,
|
||||
识别管网中的漏损节点和漏水量。
|
||||
|
||||
Args:
|
||||
data: 包含管网名称(或数据库名称)、压力数据及优化参数的请求体
|
||||
username: 当前认证用户名
|
||||
|
||||
Returns:
|
||||
包含识别结果的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当处理过程中发生错误时
|
||||
"""
|
||||
try:
|
||||
return run_leakage_identification(**data.model_dump(), username=username)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/",
|
||||
summary="查询漏损识别方案列表",
|
||||
description="获取指定网络的所有漏损识别方案"
|
||||
)
|
||||
async def query_leakage_schemes(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
query_date: datetime | None = Query(None, description="查询日期(可选)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取漏损识别方案列表。
|
||||
|
||||
查询指定网络的所有已配置的漏损识别方案,
|
||||
可按日期进行筛选。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
query_date: 查询日期(可选)
|
||||
|
||||
Returns:
|
||||
漏损识别方案列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return list_leakage_identify_schemes(network=network, query_date=query_date)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/schemes/{scheme_name}",
|
||||
summary="获取漏损识别方案详情",
|
||||
description="获取指定漏损识别方案的详细信息"
|
||||
)
|
||||
async def query_leakage_scheme_detail(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Path(..., description="漏损识别方案名称")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取漏损识别方案详情。
|
||||
|
||||
查询指定漏损识别方案的完整配置和参数信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
scheme_name: 漏损识别方案名称
|
||||
|
||||
Returns:
|
||||
包含方案详情的字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询失败时
|
||||
"""
|
||||
try:
|
||||
return get_leakage_identify_scheme_detail(
|
||||
network=network, scheme_name=scheme_name
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
|
||||
import psycopg
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -19,17 +20,22 @@ from app.domain.schemas.metadata import (
|
||||
ProjectMetaResponse,
|
||||
ProjectSummaryResponse,
|
||||
)
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||
@router.get("/meta/project", summary="获取项目元数据", description="获取当前项目的元数据和配置信息", response_model=ProjectMetaResponse)
|
||||
async def get_project_metadata(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
"""
|
||||
获取项目元数据
|
||||
|
||||
返回当前项目的完整元数据,包括项目基本信息和GeoServer配置
|
||||
"""
|
||||
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
@@ -53,17 +59,23 @@ async def get_project_metadata(
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
map_extent=project.map_extent,
|
||||
status=project.status,
|
||||
project_role=ctx.project_role,
|
||||
geoserver=geoserver_payload,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||
@router.get("/meta/projects", summary="列出用户项目", description="获取当前用户有权限的所有项目列表", response_model=list[ProjectSummaryResponse])
|
||||
async def list_user_projects(
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
"""
|
||||
列出用户的所有项目
|
||||
|
||||
返回当前用户有权限访问的项目摘要列表
|
||||
"""
|
||||
try:
|
||||
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||
except SQLAlchemyError as exc:
|
||||
@@ -90,12 +102,33 @@ async def list_user_projects(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/meta/db/health")
|
||||
@router.get("/meta/db/health", summary="检查数据库健康状态", description="检查项目数据库连接的健康状况")
|
||||
async def project_db_health(
|
||||
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
await pg_session.execute(text("SELECT 1"))
|
||||
async with ts_conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
"""
|
||||
检查数据库健康状态
|
||||
|
||||
检查PostgreSQL和TimescaleDB数据库的连接状态
|
||||
"""
|
||||
try:
|
||||
await pg_session.execute(text("SELECT 1"))
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error("Project PostgreSQL health check failed", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL health check failed: {exc}",
|
||||
) from exc
|
||||
|
||||
try:
|
||||
async with ts_conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
except psycopg.Error as exc:
|
||||
logger.error("Project TimescaleDB health check failed", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project TimescaleDB health check failed: {exc}",
|
||||
) from exc
|
||||
|
||||
return {"postgres": "ok", "timescale": "ok"}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
import random
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
@@ -12,8 +11,13 @@ from app.services.tjnetwork import (
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/getjson/")
|
||||
@router.get("/getjson/", summary="获取JSON示例", description="获取JSON格式响应示例")
|
||||
async def fastapi_get_json():
|
||||
"""
|
||||
获取JSON示例
|
||||
|
||||
返回示例JSON格式的响应
|
||||
"""
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
@@ -24,32 +28,37 @@ async def fastapi_get_json():
|
||||
)
|
||||
|
||||
|
||||
@router.get("/getallsensorplacements/")
|
||||
async def fastapi_get_all_sensor_placements(network: str) -> list[dict[Any, Any]]:
|
||||
@router.get("/getallsensorplacements/", summary="获取所有传感器位置", description="获取网络中所有传感器的放置位置信息")
|
||||
async def fastapi_get_all_sensor_placements(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
|
||||
"""
|
||||
获取所有传感器位置
|
||||
|
||||
返回网络中所有传感器的放置位置及其配置信息
|
||||
"""
|
||||
return get_all_sensor_placements(network)
|
||||
|
||||
|
||||
@router.get("/getallburstlocateresults/")
|
||||
async def fastapi_get_all_burst_locate_results(network: str) -> list[dict[Any, Any]]:
|
||||
@router.get("/getallburstlocateresults/", summary="获取所有爆管定位结果", description="获取网络中所有爆管定位的分析结果")
|
||||
async def fastapi_get_all_burst_locate_results(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
|
||||
"""
|
||||
获取所有爆管定位结果
|
||||
|
||||
返回网络中所有的爆管定位分析结果
|
||||
"""
|
||||
return get_all_burst_locate_results(network)
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
"""测试数据模型"""
|
||||
str_info: str
|
||||
|
||||
|
||||
@router.post("/test_dict/")
|
||||
@router.post("/test_dict/", summary="测试字典处理", description="测试处理字典类型数据")
|
||||
async def fastapi_test_dict(data: Item) -> dict[str, str]:
|
||||
"""
|
||||
测试字典处理
|
||||
|
||||
接收Item模型,返回其字典格式
|
||||
"""
|
||||
item = data.dict()
|
||||
return item
|
||||
|
||||
@router.get("/getrealtimedata/")
|
||||
async def fastapi_get_realtimedata():
|
||||
data = [random.randint(0, 100) for _ in range(100)]
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/getsimulationresult/")
|
||||
async def fastapi_get_simulationresult():
|
||||
data = [random.randint(0, 100) for _ in range(100)]
|
||||
return data
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
calculate_demand_to_network,
|
||||
calculate_demand_to_nodes,
|
||||
calculate_demand_to_region,
|
||||
get_demand,
|
||||
get_demand_schema,
|
||||
set_demand,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -8,21 +17,54 @@ router = APIRouter()
|
||||
# demand 9.[DEMANDS]
|
||||
############################################################
|
||||
|
||||
@router.get("/getdemandschema")
|
||||
async def fastapi_get_demand_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getdemandschema",
|
||||
summary="获取需水量属性架构",
|
||||
description="获取指定水网中需水量(Demand)的属性架构定义"
|
||||
)
|
||||
async def fastapi_get_demand_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取需水量属性架构。
|
||||
|
||||
返回指定水网的需水量属性架构,包括所有可配置的属性及其类型定义。
|
||||
"""
|
||||
return get_demand_schema(network)
|
||||
|
||||
|
||||
@router.get("/getdemandproperties/")
|
||||
async def fastapi_get_demand_properties(network: str, junction: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getdemandproperties/",
|
||||
summary="获取需水量属性",
|
||||
description="获取指定水网中节点的需水量属性信息"
|
||||
)
|
||||
async def fastapi_get_demand_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取节点的需水量属性。
|
||||
|
||||
返回指定节点的所有需水量信息,包括需水量值、水压等级等。
|
||||
"""
|
||||
return get_demand(network, junction)
|
||||
|
||||
|
||||
# example: set_demand(p, ChangeSet({'junction': 'j1', 'demands': [{'demand': 10.0, 'pattern': None, 'category': 'x'}, {'demand': 20.0, 'pattern': None, 'category': None}]}))
|
||||
@router.post("/setdemandproperties/", response_model=None)
|
||||
@router.post(
|
||||
"/setdemandproperties/",
|
||||
response_model=None,
|
||||
summary="设置需水量属性",
|
||||
description="设置指定水网中节点的需水量属性信息"
|
||||
)
|
||||
async def fastapi_set_demand_properties(
|
||||
network: str, junction: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的需水量属性。
|
||||
|
||||
修改指定节点的需水量信息。请求体应包含需水量值、水压等级等属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"junction": junction} | props
|
||||
return set_demand(network, ChangeSet(ps))
|
||||
@@ -30,26 +72,68 @@ async def fastapi_set_demand_properties(
|
||||
############################################################
|
||||
# water distribution 36.[Water Distribution]
|
||||
############################################################
|
||||
@router.get("/calculatedemandtonodes/")
|
||||
@router.get(
|
||||
"/calculatedemandtonodes/",
|
||||
summary="计算需水量到节点分配",
|
||||
description="将总需水量按指定方式分配到多个节点"
|
||||
)
|
||||
async def fastapi_calculate_demand_to_nodes(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
计算需水量到节点分配。
|
||||
|
||||
将指定的总需水量均匀或按比例分配到指定的节点列表中。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"demand": 需水量值(float),
|
||||
"nodes": 节点ID列表(list[str])
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
demand = props["demand"]
|
||||
nodes = props["nodes"]
|
||||
return calculate_demand_to_nodes(network, demand, nodes)
|
||||
|
||||
@router.get("/calculatedemandtoregion/")
|
||||
@router.get(
|
||||
"/calculatedemandtoregion/",
|
||||
summary="计算需水量到区域分配",
|
||||
description="将总需水量按区域特征分配到该区域内的节点"
|
||||
)
|
||||
async def fastapi_calculate_demand_to_region(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
计算需水量到区域分配。
|
||||
|
||||
根据区域内节点的特征(如面积、人口等)将总需水量分配到该区域的各个节点。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"demand": 需水量值(float),
|
||||
"region": 区域ID(str)
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
demand = props["demand"]
|
||||
region = props["region"]
|
||||
return calculate_demand_to_region(network, demand, region)
|
||||
|
||||
@router.get("/calculatedemandtonetwork/")
|
||||
@router.get(
|
||||
"/calculatedemandtonetwork/",
|
||||
summary="计算需水量到整网分配",
|
||||
description="将需水量均匀分配到整个水网的所有需水节点"
|
||||
)
|
||||
async def fastapi_calculate_demand_to_network(
|
||||
network: str, demand: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
demand: float = Query(..., description="总需水量(m³/h)", gt=0)
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
计算需水量到整网分配。
|
||||
|
||||
将指定的需水量均匀分配到整个水网的所有需水节点。
|
||||
"""
|
||||
return calculate_demand_to_network(network, demand)
|
||||
|
||||
@@ -1,6 +1,42 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
delete_junction,
|
||||
delete_pipe,
|
||||
delete_pump,
|
||||
delete_reservoir,
|
||||
delete_tank,
|
||||
delete_valve,
|
||||
get_all_scada_info,
|
||||
get_element_properties,
|
||||
get_element_properties_with_type,
|
||||
get_element_type,
|
||||
get_element_type_value,
|
||||
get_link_properties,
|
||||
get_link_type,
|
||||
get_links,
|
||||
get_node_links,
|
||||
get_node_properties,
|
||||
get_node_type,
|
||||
get_nodes,
|
||||
get_scada_info,
|
||||
get_status,
|
||||
get_status_schema,
|
||||
get_title,
|
||||
get_title_schema,
|
||||
is_junction,
|
||||
is_link,
|
||||
is_node,
|
||||
is_pipe,
|
||||
is_pump,
|
||||
is_reservoir,
|
||||
is_tank,
|
||||
is_valve,
|
||||
set_status,
|
||||
set_title,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -8,110 +44,291 @@ router = APIRouter()
|
||||
# type
|
||||
############################################################
|
||||
|
||||
@router.get("/isnode/")
|
||||
async def fastapi_is_node(network: str, node: str) -> bool:
|
||||
@router.get(
|
||||
"/isnode/",
|
||||
summary="检查节点有效性",
|
||||
description="检查指定ID是否为水网中的有效节点"
|
||||
)
|
||||
async def fastapi_is_node(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为节点。"""
|
||||
return is_node(network, node)
|
||||
|
||||
@router.get("/isjunction/")
|
||||
async def fastapi_is_junction(network: str, node: str) -> bool:
|
||||
@router.get(
|
||||
"/isjunction/",
|
||||
summary="检查是否为接点",
|
||||
description="检查指定ID是否为水网中的接点(需求点)"
|
||||
)
|
||||
async def fastapi_is_junction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为接点。"""
|
||||
return is_junction(network, node)
|
||||
|
||||
@router.get("/isreservoir/")
|
||||
async def fastapi_is_reservoir(network: str, node: str) -> bool:
|
||||
@router.get(
|
||||
"/isreservoir/",
|
||||
summary="检查是否为水源",
|
||||
description="检查指定ID是否为水网中的水源(水库/河流)"
|
||||
)
|
||||
async def fastapi_is_reservoir(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为水源。"""
|
||||
return is_reservoir(network, node)
|
||||
|
||||
@router.get("/istank/")
|
||||
async def fastapi_is_tank(network: str, node: str) -> bool:
|
||||
@router.get(
|
||||
"/istank/",
|
||||
summary="检查是否为蓄水池",
|
||||
description="检查指定ID是否为水网中的蓄水池"
|
||||
)
|
||||
async def fastapi_is_tank(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为蓄水池。"""
|
||||
return is_tank(network, node)
|
||||
|
||||
@router.get("/islink/")
|
||||
async def fastapi_is_link(network: str, link: str) -> bool:
|
||||
@router.get(
|
||||
"/islink/",
|
||||
summary="检查管线有效性",
|
||||
description="检查指定ID是否为水网中的有效管线"
|
||||
)
|
||||
async def fastapi_is_link(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为管线。"""
|
||||
return is_link(network, link)
|
||||
|
||||
@router.get("/ispipe/")
|
||||
async def fastapi_is_pipe(network: str, link: str) -> bool:
|
||||
@router.get(
|
||||
"/ispipe/",
|
||||
summary="检查是否为管道",
|
||||
description="检查指定ID是否为水网中的管道"
|
||||
)
|
||||
async def fastapi_is_pipe(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为管道。"""
|
||||
return is_pipe(network, link)
|
||||
|
||||
@router.get("/ispump/")
|
||||
async def fastapi_is_pump(network: str, link: str) -> bool:
|
||||
@router.get(
|
||||
"/ispump/",
|
||||
summary="检查是否为泵",
|
||||
description="检查指定ID是否为水网中的泵"
|
||||
)
|
||||
async def fastapi_is_pump(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为泵。"""
|
||||
return is_pump(network, link)
|
||||
|
||||
@router.get("/isvalve/")
|
||||
async def fastapi_is_valve(network: str, link: str) -> bool:
|
||||
@router.get(
|
||||
"/isvalve/",
|
||||
summary="检查是否为阀门",
|
||||
description="检查指定ID是否为水网中的阀门"
|
||||
)
|
||||
async def fastapi_is_valve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> bool:
|
||||
"""检查指定ID是否为阀门。"""
|
||||
return is_valve(network, link)
|
||||
|
||||
@router.get("/getnodetype/")
|
||||
async def fastapi_get_node_type(network: str, node: str) -> str:
|
||||
@router.get(
|
||||
"/getnodetype/",
|
||||
summary="获取节点类型",
|
||||
description="获取指定节点的类型(接点/水源/蓄水池)"
|
||||
)
|
||||
async def fastapi_get_node_type(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> str:
|
||||
"""获取节点的类型标识。"""
|
||||
return get_node_type(network, node)
|
||||
|
||||
@router.get("/getlinktype/")
|
||||
async def fastapi_get_link_type(network: str, link: str) -> str:
|
||||
@router.get(
|
||||
"/getlinktype/",
|
||||
summary="获取管线类型",
|
||||
description="获取指定管线的类型(管道/泵/阀门)"
|
||||
)
|
||||
async def fastapi_get_link_type(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> str:
|
||||
"""获取管线的类型标识。"""
|
||||
return get_link_type(network, link)
|
||||
|
||||
@router.get("/getelementtype/")
|
||||
async def fastapi_get_element_type(network: str, element: str) -> str:
|
||||
@router.get(
|
||||
"/getelementtype/",
|
||||
summary="获取元素类型",
|
||||
description="获取指定元素的类型(节点或管线)"
|
||||
)
|
||||
async def fastapi_get_element_type(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
element: str = Query(..., description="元素ID")
|
||||
) -> str:
|
||||
"""获取元素的类型标识。"""
|
||||
return get_element_type(network, element)
|
||||
|
||||
@router.get("/getelementtypevalue/")
|
||||
async def fastapi_get_element_type_value(network: str, element: str) -> int:
|
||||
@router.get(
|
||||
"/getelementtypevalue/",
|
||||
summary="获取元素类型值",
|
||||
description="获取指定元素的类型数值标识"
|
||||
)
|
||||
async def fastapi_get_element_type_value(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
element: str = Query(..., description="元素ID")
|
||||
) -> int:
|
||||
"""获取元素的类型数值。"""
|
||||
return get_element_type_value(network, element)
|
||||
|
||||
@router.get("/getnodes/")
|
||||
async def fastapi_get_nodes(network: str) -> list[str]:
|
||||
@router.get(
|
||||
"/getnodes/",
|
||||
summary="获取所有节点",
|
||||
description="获取指定水网中的所有节点ID列表"
|
||||
)
|
||||
async def fastapi_get_nodes(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
|
||||
"""获取水网中所有节点的ID列表。"""
|
||||
return get_nodes(network)
|
||||
|
||||
@router.get("/getlinks/")
|
||||
async def fastapi_get_links(network: str) -> list[str]:
|
||||
@router.get(
|
||||
"/getlinks/",
|
||||
summary="获取所有管线",
|
||||
description="获取指定水网中的所有管线ID列表"
|
||||
)
|
||||
async def fastapi_get_links(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[str]:
|
||||
"""获取水网中所有管线的ID列表。"""
|
||||
return get_links(network)
|
||||
|
||||
@router.get("/getnodelinks/")
|
||||
def get_node_links_endpoint(network: str, node: str) -> list[str]:
|
||||
@router.get(
|
||||
"/getnodelinks/",
|
||||
summary="获取节点的关联管线",
|
||||
description="获取指定节点连接的所有管线ID列表"
|
||||
)
|
||||
def get_node_links_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> list[str]:
|
||||
"""获取节点关联的所有管线。"""
|
||||
return get_node_links(network, node)
|
||||
|
||||
############################################################
|
||||
# Node & Link properties
|
||||
############################################################
|
||||
|
||||
@router.get("/getnodeproperties/")
|
||||
async def fast_get_node_properties(network: str, node: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getnodeproperties/",
|
||||
summary="获取节点属性",
|
||||
description="获取指定节点的所有属性信息"
|
||||
)
|
||||
async def fast_get_node_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取节点的完整属性信息。"""
|
||||
return get_node_properties(network, node)
|
||||
|
||||
@router.get("/getlinkproperties/")
|
||||
async def fast_get_link_properties(network: str, link: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getlinkproperties/",
|
||||
summary="获取管线属性",
|
||||
description="获取指定管线的所有属性信息"
|
||||
)
|
||||
async def fast_get_link_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取管线的完整属性信息。"""
|
||||
return get_link_properties(network, link)
|
||||
|
||||
@router.get("/getscadaproperties/")
|
||||
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getscadaproperties/",
|
||||
summary="获取SCADA点属性",
|
||||
description="获取指定SCADA点的属性信息"
|
||||
)
|
||||
async def fast_get_scada_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scada: str = Query(..., description="SCADA点ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取SCADA点的属性信息。"""
|
||||
return get_scada_info(network, scada)
|
||||
|
||||
@router.get("/getallscadaproperties/")
|
||||
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getallscadaproperties/",
|
||||
summary="获取所有SCADA点属性",
|
||||
description="获取指定水网中所有SCADA点的属性信息"
|
||||
)
|
||||
async def fast_get_all_scada_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取水网中所有SCADA点的属性列表。"""
|
||||
return get_all_scada_info(network)
|
||||
|
||||
@router.get("/getelementpropertieswithtype/")
|
||||
@router.get(
|
||||
"/getelementpropertieswithtype/",
|
||||
summary="获取指定类型元素属性",
|
||||
description="获取指定类型的元素属性信息"
|
||||
)
|
||||
async def fast_get_element_properties_with_type(
|
||||
network: str, elementtype: str, element: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
elementtype: str = Query(..., description="元素类型"),
|
||||
element: str = Query(..., description="元素ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取指定类型元素的属性。"""
|
||||
return get_element_properties_with_type(network, elementtype, element)
|
||||
|
||||
@router.get("/getelementproperties/")
|
||||
async def fast_get_element_properties(network: str, element: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getelementproperties/",
|
||||
summary="获取元素属性",
|
||||
description="获取指定元素的属性信息"
|
||||
)
|
||||
async def fast_get_element_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
element: str = Query(..., description="元素ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取元素的完整属性信息。"""
|
||||
return get_element_properties(network, element)
|
||||
|
||||
############################################################
|
||||
# title 1.[TITLE]
|
||||
############################################################
|
||||
|
||||
@router.get("/gettitleschema/")
|
||||
async def fast_get_title_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/gettitleschema/",
|
||||
summary="获取标题属性架构",
|
||||
description="获取指定水网的标题(标题)属性架构定义"
|
||||
)
|
||||
async def fast_get_title_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取水网标题的属性架构。"""
|
||||
return get_title_schema(network)
|
||||
|
||||
@router.get("/gettitle/")
|
||||
async def fast_get_title(network: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/gettitle/",
|
||||
summary="获取水网标题属性",
|
||||
description="获取指定水网的标题(Title)信息"
|
||||
)
|
||||
async def fast_get_title(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""获取水网的标题属性。"""
|
||||
return get_title(network)
|
||||
|
||||
@router.get("/settitle/", response_model=None)
|
||||
async def fastapi_set_title(network: str, req: Request) -> ChangeSet:
|
||||
@router.get(
|
||||
"/settitle/",
|
||||
response_model=None,
|
||||
summary="设置水网标题属性",
|
||||
description="设置指定水网的标题(Title)信息"
|
||||
)
|
||||
async def fastapi_set_title(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置水网的标题属性。"""
|
||||
props = await req.json()
|
||||
return set_title(network, ChangeSet(props))
|
||||
|
||||
@@ -119,18 +336,41 @@ async def fastapi_set_title(network: str, req: Request) -> ChangeSet:
|
||||
# status 10.[STATUS]
|
||||
############################################################
|
||||
|
||||
@router.get("/getstatusschema")
|
||||
async def fastapi_get_status_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getstatusschema",
|
||||
summary="获取状态属性架构",
|
||||
description="获取指定水网的状态(Status)属性架构定义"
|
||||
)
|
||||
async def fastapi_get_status_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取水网状态的属性架构。"""
|
||||
return get_status_schema(network)
|
||||
|
||||
@router.get("/getstatus/")
|
||||
async def fastapi_get_status(network: str, link: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getstatus/",
|
||||
summary="获取管线状态",
|
||||
description="获取指定管线的状态信息"
|
||||
)
|
||||
async def fastapi_get_status(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取管线的状态属性。"""
|
||||
return get_status(network, link)
|
||||
|
||||
@router.post("/setstatus/", response_model=None)
|
||||
@router.post(
|
||||
"/setstatus/",
|
||||
response_model=None,
|
||||
summary="设置管线状态",
|
||||
description="设置指定管线的状态信息"
|
||||
)
|
||||
async def fastapi_set_status_properties(
|
||||
network: str, link: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置管线的状态属性。"""
|
||||
props = await req.json()
|
||||
ps = {"link": link} | props
|
||||
return set_status(network, ChangeSet(ps))
|
||||
@@ -139,8 +379,17 @@ async def fastapi_set_status_properties(
|
||||
# General Deletion
|
||||
############################################################
|
||||
|
||||
@router.post("/deletenode/", response_model=None)
|
||||
async def fastapi_delete_node(network: str, node: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deletenode/",
|
||||
response_model=None,
|
||||
summary="删除节点",
|
||||
description="删除指定的节点(接点/水源/蓄水池)"
|
||||
)
|
||||
async def fastapi_delete_node(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> ChangeSet:
|
||||
"""删除指定的节点。自动识别节点类型并调用相应的删除操作。"""
|
||||
ps = {"id": node}
|
||||
if is_junction(network, node):
|
||||
return delete_junction(network, ChangeSet(ps))
|
||||
@@ -150,8 +399,17 @@ async def fastapi_delete_node(network: str, node: str) -> ChangeSet:
|
||||
return delete_tank(network, ChangeSet(ps))
|
||||
return ChangeSet() # Should probably raise error or return empty
|
||||
|
||||
@router.post("/deletelink/", response_model=None)
|
||||
async def fastapi_delete_link(network: str, link: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deletelink/",
|
||||
response_model=None,
|
||||
summary="删除管线",
|
||||
description="删除指定的管线(管道/泵/阀门)"
|
||||
)
|
||||
async def fastapi_delete_link(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
link: str = Query(..., description="管线ID")
|
||||
) -> ChangeSet:
|
||||
"""删除指定的管线。自动识别管线类型并调用相应的删除操作。"""
|
||||
ps = {"id": link}
|
||||
if is_pipe(network, link):
|
||||
return delete_pipe(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from fastapi import APIRouter, Request, Depends, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
get_all_scada_info,
|
||||
get_major_node_coords,
|
||||
get_major_pipe_nodes,
|
||||
get_network_in_extent,
|
||||
get_network_link_nodes,
|
||||
get_network_node_coords,
|
||||
get_node_coord,
|
||||
)
|
||||
from app.auth.dependencies import get_current_user as verify_token
|
||||
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
|
||||
import msgpack
|
||||
@@ -25,19 +34,44 @@ router = APIRouter()
|
||||
# props = await req.json()
|
||||
# return set_coord(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getnodecoord/")
|
||||
async def fastapi_get_node_coord(network: str, node: str) -> dict[str, float] | None:
|
||||
@router.get(
|
||||
"/getnodecoord/",
|
||||
summary="获取节点坐标",
|
||||
description="获取指定节点的地理坐标(X, Y)"
|
||||
)
|
||||
async def fastapi_get_node_coord(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
node: str = Query(..., description="节点ID")
|
||||
) -> dict[str, float] | None:
|
||||
"""获取节点的地理坐标信息。"""
|
||||
return get_node_coord(network, node)
|
||||
|
||||
# Additional geometry queries found in main.py logic (implicit or explicit)
|
||||
@router.get("/getnetworkinextent/")
|
||||
@router.get(
|
||||
"/getnetworkinextent/",
|
||||
summary="获取范围内的网络元素",
|
||||
description="获取指定地理范围内的网络节点和管线"
|
||||
)
|
||||
async def fastapi_get_network_in_extent(
|
||||
network: str, x1: float, y1: float, x2: float, y2: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
x1: float = Query(..., description="范围左下角X坐标", alias="x1"),
|
||||
y1: float = Query(..., description="范围左下角Y坐标", alias="y1"),
|
||||
x2: float = Query(..., description="范围右上角X坐标", alias="x2"),
|
||||
y2: float = Query(..., description="范围右上角Y坐标", alias="y2")
|
||||
) -> dict[str, Any]:
|
||||
"""获取地理范围内的网络几何信息。"""
|
||||
return get_network_in_extent(network, x1, y1, x2, y2)
|
||||
|
||||
@router.get("/getnetworkgeometries/", dependencies=[Depends(verify_token)])
|
||||
async def fastapi_get_network_geometries(network: str) -> dict[str, Any] | None:
|
||||
@router.get(
|
||||
"/getnetworkgeometries/",
|
||||
dependencies=[Depends(verify_token)],
|
||||
summary="获取完整网络几何信息",
|
||||
description="获取整个水网的所有节点、管线和SCADA点的几何信息(需要身份验证)"
|
||||
)
|
||||
async def fastapi_get_network_geometries(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, Any] | None:
|
||||
"""获取完整的网络几何信息,包括所有节点、管线和SCADA点。结果从缓存返回。"""
|
||||
cache_key = f"getnetworkgeometries_{network}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
@@ -55,18 +89,39 @@ async def fastapi_get_network_geometries(network: str) -> dict[str, Any] | None:
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
return results
|
||||
|
||||
@router.get("/getmajornodecoords/")
|
||||
@router.get(
|
||||
"/getmajornodecoords/",
|
||||
summary="获取主要节点坐标",
|
||||
description="获取直径大于等于指定值的节点坐标"
|
||||
)
|
||||
async def fastapi_get_majornode_coords(
|
||||
network: str, diameter: int
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
diameter: int = Query(..., description="最小直径(mm)", gt=0)
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""获取主要节点的坐标。只返回直径大于等于指定值的节点。"""
|
||||
return get_major_node_coords(network, diameter)
|
||||
|
||||
@router.get("/getmajorpipenodes/")
|
||||
async def fastapi_get_major_pipe_nodes(network: str, diameter: int) -> list[str] | None:
|
||||
@router.get(
|
||||
"/getmajorpipenodes/",
|
||||
summary="获取主要管道节点",
|
||||
description="获取直径大于等于指定值的管道的节点ID"
|
||||
)
|
||||
async def fastapi_get_major_pipe_nodes(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
diameter: int = Query(..., description="最小直径(mm)", gt=0)
|
||||
) -> list[str] | None:
|
||||
"""获取主要管道节点。只返回直径大于等于指定值的管道。"""
|
||||
return get_major_pipe_nodes(network, diameter)
|
||||
|
||||
@router.get("/getnetworklinknodes/")
|
||||
async def fastapi_get_network_link_nodes(network: str) -> list[str] | None:
|
||||
@router.get(
|
||||
"/getnetworklinknodes/",
|
||||
summary="获取网络管线节点",
|
||||
description="获取指定水网所有管线的起点和终点节点"
|
||||
)
|
||||
async def fastapi_get_network_link_nodes(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[str] | None:
|
||||
"""获取网络中所有管线的连接节点。"""
|
||||
return get_network_link_nodes(network)
|
||||
|
||||
# @router.get("/getallcoords/")
|
||||
|
||||
@@ -1,111 +1,361 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_junction,
|
||||
delete_junction,
|
||||
get_all_junctions,
|
||||
get_junction,
|
||||
get_junction_schema,
|
||||
set_junction,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getjunctionschema")
|
||||
async def fast_get_junction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getjunctionschema", summary="获取节点架构", description="获取指定项目的节点属性架构和数据类型定义。")
|
||||
async def fast_get_junction_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取节点架构信息。
|
||||
|
||||
返回指定项目的节点属性架构,包括所有属性的类型和约束信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
"""
|
||||
return get_junction_schema(network)
|
||||
|
||||
@router.post("/addjunction/", response_model=None)
|
||||
@router.post("/addjunction/", response_model=None, summary="添加节点", description="在供水网络中添加新的节点,指定节点ID和空间坐标。")
|
||||
async def fastapi_add_junction(
|
||||
network: str, junction: str, x: float, y: float, z: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
x: float = Query(..., description="X 坐标"),
|
||||
y: float = Query(..., description="Y 坐标"),
|
||||
z: float = Query(..., description="标高(海拔高度)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新节点到供水网络。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
x: X 坐标值
|
||||
y: Y 坐标值
|
||||
z: 标高(海拔高度)
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "x": x, "y": y, "elevation": z}
|
||||
return add_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletejunction/", response_model=None)
|
||||
async def fastapi_delete_junction(network: str, junction: str) -> ChangeSet:
|
||||
@router.post("/deletejunction/", response_model=None, summary="删除节点", description="从供水网络中删除指定的节点。")
|
||||
async def fastapi_delete_junction(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除指定的节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction}
|
||||
return delete_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getjunctionelevation/")
|
||||
async def fastapi_get_junction_elevation(network: str, junction: str) -> float:
|
||||
@router.get("/getjunctionelevation/", summary="获取节点标高", description="获取指定节点的标高(海拔高度)。")
|
||||
async def fastapi_get_junction_elevation(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取节点的标高值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
float: 节点标高值
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
return ps["elevation"]
|
||||
|
||||
@router.get("/getjunctionx/")
|
||||
async def fastapi_get_junction_x(network: str, junction: str) -> float:
|
||||
@router.get("/getjunctionx/", summary="获取节点 X 坐标", description="获取指定节点的 X 坐标值。")
|
||||
async def fastapi_get_junction_x(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取节点的 X 坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
float: 节点 X 坐标值
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/getjunctiony/")
|
||||
async def fastapi_get_junction_y(network: str, junction: str) -> float:
|
||||
@router.get("/getjunctiony/", summary="获取节点 Y 坐标", description="获取指定节点的 Y 坐标值。")
|
||||
async def fastapi_get_junction_y(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取节点的 Y 坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
float: 节点 Y 坐标值
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/getjunctioncoord/")
|
||||
async def fastapi_get_junction_coord(network: str, junction: str) -> dict[str, float]:
|
||||
@router.get("/getjunctioncoord/", summary="获取节点坐标", description="获取指定节点的 X 和 Y 坐标。")
|
||||
async def fastapi_get_junction_coord(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
获取节点的坐标信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
dict: 包含 x 和 y 坐标的字典
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
coord = {"x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.get("/getjunctiondemand/")
|
||||
async def fastapi_get_junction_demand(network: str, junction: str) -> float:
|
||||
@router.get("/getjunctiondemand/", summary="获取节点需水量", description="获取指定节点的需水量。")
|
||||
async def fastapi_get_junction_demand(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取节点的需水量。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
float: 节点的需水量值
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
return ps["demand"]
|
||||
|
||||
@router.get("/getjunctionpattern/")
|
||||
async def fastapi_get_junction_pattern(network: str, junction: str) -> str:
|
||||
@router.get("/getjunctionpattern/", summary="获取节点需水模式", description="获取指定节点的需水模式标识。")
|
||||
async def fastapi_get_junction_pattern(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> str:
|
||||
"""
|
||||
获取节点的需水模式。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
str: 节点的需水模式标识
|
||||
"""
|
||||
ps = get_junction(network, junction)
|
||||
return ps["pattern"]
|
||||
|
||||
@router.post("/setjunctionelevation/", response_model=None)
|
||||
@router.post("/setjunctionelevation/", response_model=None, summary="设置节点标高", description="设置指定节点的标高值。")
|
||||
async def fastapi_set_junction_elevation(
|
||||
network: str, junction: str, elevation: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
elevation: float = Query(..., description="标高(海拔高度)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的标高。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
elevation: 标高值
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "elevation": elevation}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctionx/", response_model=None)
|
||||
async def fastapi_set_junction_x(network: str, junction: str, x: float) -> ChangeSet:
|
||||
@router.post("/setjunctionx/", response_model=None, summary="设置节点 X 坐标", description="设置指定节点的 X 坐标值。")
|
||||
async def fastapi_set_junction_x(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
x: float = Query(..., description="X 坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的 X 坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
x: X 坐标值
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "x": x}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctiony/", response_model=None)
|
||||
async def fastapi_set_junction_y(network: str, junction: str, y: float) -> ChangeSet:
|
||||
@router.post("/setjunctiony/", response_model=None, summary="设置节点 Y 坐标", description="设置指定节点的 Y 坐标值。")
|
||||
async def fastapi_set_junction_y(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
y: float = Query(..., description="Y 坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的 Y 坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
y: Y 坐标值
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "y": y}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctioncoord/", response_model=None)
|
||||
@router.post("/setjunctioncoord/", response_model=None, summary="设置节点坐标", description="设置指定节点的 X 和 Y 坐标。")
|
||||
async def fastapi_set_junction_coord(
|
||||
network: str, junction: str, x: float, y: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
x: float = Query(..., description="X 坐标值"),
|
||||
y: float = Query(..., description="Y 坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
x: X 坐标值
|
||||
y: Y 坐标值
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "x": x, "y": y}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctiondemand/", response_model=None)
|
||||
@router.post("/setjunctiondemand/", response_model=None, summary="设置节点需水量", description="设置指定节点的需水量。")
|
||||
async def fastapi_set_junction_demand(
|
||||
network: str, junction: str, demand: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
demand: float = Query(..., description="需水量值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的需水量。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
demand: 需水量值
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "demand": demand}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctionpattern/", response_model=None)
|
||||
@router.post("/setjunctionpattern/", response_model=None, summary="设置节点需水模式", description="设置指定节点的需水模式标识。")
|
||||
async def fastapi_set_junction_pattern(
|
||||
network: str, junction: str, pattern: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
pattern: str = Query(..., description="需水模式标识")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置节点的需水模式。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
pattern: 需水模式标识
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
ps = {"id": junction, "pattern": pattern}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getjunctionproperties/")
|
||||
@router.get("/getjunctionproperties/", summary="获取节点属性", description="获取指定节点的所有属性信息。")
|
||||
async def fastapi_get_junction_properties(
|
||||
network: str, junction: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取节点的完整属性信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
|
||||
Returns:
|
||||
dict: 包含节点所有属性的字典
|
||||
"""
|
||||
return get_junction(network, junction)
|
||||
|
||||
@router.get("/getalljunctionproperties/")
|
||||
async def fastapi_get_all_junction_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getalljunctionproperties/", summary="获取所有节点属性", description="获取指定项目中所有节点的属性信息。")
|
||||
async def fastapi_get_all_junction_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取所有节点的属性信息列表。
|
||||
|
||||
此端点返回指定项目中所有节点的详细属性。缓存查询结果以提高性能。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
list: 包含所有节点属性的列表
|
||||
"""
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client # Redis logic removed for clean split, can be re-added if needed or imported
|
||||
results = get_all_junctions(network)
|
||||
return results
|
||||
|
||||
@router.post("/setjunctionproperties/", response_model=None)
|
||||
@router.post("/setjunctionproperties/", response_model=None, summary="批量设置节点属性", description="批量设置指定节点的多个属性。")
|
||||
async def fastapi_set_junction_properties(
|
||||
network: str, junction: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
junction: str = Query(..., description="节点 ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
批量设置节点属性。
|
||||
|
||||
允许一次性设置节点的多个属性,如坐标、标高、需水量等。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
junction: 节点 ID
|
||||
req: 包含属性和值的 JSON 请求体
|
||||
|
||||
Returns:
|
||||
ChangeSet: 包含变更信息的结果
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": junction} | props
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,25 +1,63 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
PIPE_STATUS_OPEN,
|
||||
add_pipe,
|
||||
delete_pipe,
|
||||
get_all_pipes,
|
||||
get_pipe,
|
||||
get_pipe_schema,
|
||||
set_pipe,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpipeschema")
|
||||
async def fastapi_get_pipe_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getpipeschema", summary="获取管道模式", description="获取管道对象的模式定义,包含所有可用字段及其类型")
|
||||
async def fastapi_get_pipe_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取管道数据模式定义。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含管道模式信息的字典
|
||||
"""
|
||||
return get_pipe_schema(network)
|
||||
|
||||
@router.post("/addpipe/", response_model=None)
|
||||
@router.post("/addpipe/", response_model=None, summary="添加管道", description="向网络中添加新的管道,需要提供管道的基本参数如长度、管径、粗糙度等")
|
||||
async def fastapi_add_pipe(
|
||||
network: str,
|
||||
pipe: str,
|
||||
node1: str,
|
||||
node2: str,
|
||||
length: float = 0,
|
||||
diameter: float = 0,
|
||||
roughness: float = 0,
|
||||
minor_loss: float = 0,
|
||||
status: str = PIPE_STATUS_OPEN,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道标识符"),
|
||||
node1: str = Query(..., description="管道起始节点ID"),
|
||||
node2: str = Query(..., description="管道终止节点ID"),
|
||||
length: float = Query(0, description="管道长度(单位:米)"),
|
||||
diameter: float = Query(0, description="管道管径(单位:毫米)"),
|
||||
roughness: float = Query(0, description="管道粗糙度"),
|
||||
minor_loss: float = Query(0, description="管道局部阻力系数"),
|
||||
status: str = Query(PIPE_STATUS_OPEN, description="管道状态(开启/关闭)"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新管道到网络。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
node1: 起始节点ID
|
||||
node2: 终止节点ID
|
||||
length: 管道长度
|
||||
diameter: 管道管径
|
||||
roughness: 管道粗糙度
|
||||
minor_loss: 局部阻力系数
|
||||
status: 管道状态
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次操作的变更信息
|
||||
"""
|
||||
ps = {
|
||||
"id": pipe,
|
||||
"node1": node1,
|
||||
@@ -32,102 +70,342 @@ async def fastapi_add_pipe(
|
||||
}
|
||||
return add_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepipe/", response_model=None)
|
||||
async def fastapi_delete_pipe(network: str, pipe: str) -> ChangeSet:
|
||||
@router.post("/deletepipe/", response_model=None, summary="删除管道", description="从网络中删除指定的管道")
|
||||
async def fastapi_delete_pipe(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="要删除的管道ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除管道。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次删除操作的变更信息
|
||||
"""
|
||||
ps = {"id": pipe}
|
||||
return delete_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpipenode1/")
|
||||
async def fastapi_get_pipe_node1(network: str, pipe: str) -> str | None:
|
||||
@router.get("/getpipenode1/", summary="获取管道起始节点", description="获取指定管道的起始节点ID")
|
||||
async def fastapi_get_pipe_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取管道的起始节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
起始节点ID,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getpipenode2/")
|
||||
async def fastapi_get_pipe_node2(network: str, pipe: str) -> str | None:
|
||||
@router.get("/getpipenode2/", summary="获取管道终止节点", description="获取指定管道的终止节点ID")
|
||||
async def fastapi_get_pipe_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取管道的终止节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
终止节点ID,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["node2"]
|
||||
|
||||
@router.get("/getpipelength/")
|
||||
async def fastapi_get_pipe_length(network: str, pipe: str) -> float | None:
|
||||
@router.get("/getpipelength/", summary="获取管道长度", description="获取指定管道的长度")
|
||||
async def fastapi_get_pipe_length(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取管道长度。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
管道长度(单位:米),如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["length"]
|
||||
|
||||
@router.get("/getpipediameter/")
|
||||
async def fastapi_get_pipe_diameter(network: str, pipe: str) -> float | None:
|
||||
@router.get("/getpipediameter/", summary="获取管道管径", description="获取指定管道的管径")
|
||||
async def fastapi_get_pipe_diameter(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取管道管径。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
管道管径(单位:毫米),如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/getpiperoughness/")
|
||||
async def fastapi_get_pipe_roughness(network: str, pipe: str) -> float | None:
|
||||
@router.get("/getpiperoughness/", summary="获取管道粗糙度", description="获取指定管道的粗糙度")
|
||||
async def fastapi_get_pipe_roughness(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取管道粗糙度。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
管道粗糙度值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["roughness"]
|
||||
|
||||
@router.get("/getpipeminorloss/")
|
||||
async def fastapi_get_pipe_minor_loss(network: str, pipe: str) -> float | None:
|
||||
@router.get("/getpipeminorloss/", summary="获取管道局部阻力系数", description="获取指定管道的局部阻力系数")
|
||||
async def fastapi_get_pipe_minor_loss(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取管道局部阻力系数。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
局部阻力系数,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["minor_loss"]
|
||||
|
||||
@router.get("/getpipestatus/")
|
||||
async def fastapi_get_pipe_status(network: str, pipe: str) -> str | None:
|
||||
@router.get("/getpipestatus/", summary="获取管道状态", description="获取指定管道的状态(开启或关闭)")
|
||||
async def fastapi_get_pipe_status(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取管道状态。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
管道状态(开启/关闭),如果不存在则返回None
|
||||
"""
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["status"]
|
||||
|
||||
@router.post("/setpipenode1/", response_model=None)
|
||||
async def fastapi_set_pipe_node1(network: str, pipe: str, node1: str) -> ChangeSet:
|
||||
@router.post("/setpipenode1/", response_model=None, summary="设置管道起始节点", description="设置指定管道的起始节点")
|
||||
async def fastapi_set_pipe_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
node1: str = Query(..., description="新的起始节点ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道起始节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
node1: 新的起始节点ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "node1": node1}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipenode2/", response_model=None)
|
||||
async def fastapi_set_pipe_node2(network: str, pipe: str, node2: str) -> ChangeSet:
|
||||
@router.post("/setpipenode2/", response_model=None, summary="设置管道终止节点", description="设置指定管道的终止节点")
|
||||
async def fastapi_set_pipe_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
node2: str = Query(..., description="新的终止节点ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道终止节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
node2: 新的终止节点ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "node2": node2}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipelength/", response_model=None)
|
||||
async def fastapi_set_pipe_length(network: str, pipe: str, length: float) -> ChangeSet:
|
||||
@router.post("/setpipelength/", response_model=None, summary="设置管道长度", description="设置指定管道的长度")
|
||||
async def fastapi_set_pipe_length(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
length: float = Query(..., description="新的管道长度(单位:米)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道长度。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
length: 新的管道长度
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "length": length}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipediameter/", response_model=None)
|
||||
@router.post("/setpipediameter/", response_model=None, summary="设置管道管径", description="设置指定管道的管径")
|
||||
async def fastapi_set_pipe_diameter(
|
||||
network: str, pipe: str, diameter: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
diameter: float = Query(..., description="新的管道管径(单位:毫米)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道管径。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
diameter: 新的管道管径
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "diameter": diameter}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpiperoughness/", response_model=None)
|
||||
@router.post("/setpiperoughness/", response_model=None, summary="设置管道粗糙度", description="设置指定管道的粗糙度")
|
||||
async def fastapi_set_pipe_roughness(
|
||||
network: str, pipe: str, roughness: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
roughness: float = Query(..., description="新的管道粗糙度值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道粗糙度。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
roughness: 新的管道粗糙度
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "roughness": roughness}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipeminorloss/", response_model=None)
|
||||
@router.post("/setpipeminorloss/", response_model=None, summary="设置管道局部阻力系数", description="设置指定管道的局部阻力系数")
|
||||
async def fastapi_set_pipe_minor_loss(
|
||||
network: str, pipe: str, minor_loss: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
minor_loss: float = Query(..., description="新的局部阻力系数值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道局部阻力系数。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
minor_loss: 新的局部阻力系数
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "minor_loss": minor_loss}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipestatus/", response_model=None)
|
||||
async def fastapi_set_pipe_status(network: str, pipe: str, status: str) -> ChangeSet:
|
||||
@router.post("/setpipestatus/", response_model=None, summary="设置管道状态", description="设置指定管道的状态(开启或关闭)")
|
||||
async def fastapi_set_pipe_status(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
status: str = Query(..., description="新的管道状态(开启/关闭)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置管道状态。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
status: 新的管道状态
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pipe, "status": status}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpipeproperties/")
|
||||
async def fastapi_get_pipe_properties(network: str, pipe: str) -> dict[str, Any]:
|
||||
@router.get("/getpipeproperties/", summary="获取管道属性", description="获取指定管道的所有属性信息")
|
||||
async def fastapi_get_pipe_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取管道的所有属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
|
||||
Returns:
|
||||
包含管道所有属性的字典
|
||||
"""
|
||||
return get_pipe(network, pipe)
|
||||
|
||||
@router.get("/getallpipeproperties/")
|
||||
async def fastapi_get_all_pipe_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallpipeproperties/", summary="获取所有管道属性", description="获取网络中所有管道的属性信息列表")
|
||||
async def fastapi_get_all_pipe_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取网络中所有管道的属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含所有管道属性的字典列表
|
||||
"""
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_pipes(network)
|
||||
return results
|
||||
|
||||
@router.post("/setpipeproperties/", response_model=None)
|
||||
@router.post("/setpipeproperties/", response_model=None, summary="设置管道属性", description="批量设置指定管道的多个属性")
|
||||
async def fastapi_set_pipe_properties(
|
||||
network: str, pipe: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe: str = Query(..., description="管道ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
批量设置管道属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe: 管道ID
|
||||
req: 请求体,包含要设置的属性及其值
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": pipe} | props
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,60 +1,203 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_pump,
|
||||
delete_pump,
|
||||
get_all_pumps,
|
||||
get_pump,
|
||||
get_pump_schema,
|
||||
set_pump,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpumpschema")
|
||||
async def fastapi_get_pump_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getpumpschema", summary="获取水泵模式", description="获取水泵对象的模式定义,包含所有可用字段及其类型")
|
||||
async def fastapi_get_pump_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取水泵数据模式定义。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含水泵模式信息的字典
|
||||
"""
|
||||
return get_pump_schema(network)
|
||||
|
||||
@router.post("/addpump/", response_model=None)
|
||||
@router.post("/addpump/", response_model=None, summary="添加水泵", description="向网络中添加新的水泵,需要提供水泵的基本参数如功率等")
|
||||
async def fastapi_add_pump(
|
||||
network: str, pump: str, node1: str, node2: str, power: float = 0.0
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵标识符"),
|
||||
node1: str = Query(..., description="水泵起始节点ID"),
|
||||
node2: str = Query(..., description="水泵终止节点ID"),
|
||||
power: float = Query(0.0, description="水泵功率(单位:千瓦)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新水泵到网络。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
node1: 起始节点ID
|
||||
node2: 终止节点ID
|
||||
power: 水泵功率
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次操作的变更信息
|
||||
"""
|
||||
ps = {"id": pump, "node1": node1, "node2": node2, "power": power}
|
||||
return add_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepump/", response_model=None)
|
||||
async def fastapi_delete_pump(network: str, pump: str) -> ChangeSet:
|
||||
@router.post("/deletepump/", response_model=None, summary="删除水泵", description="从网络中删除指定的水泵")
|
||||
async def fastapi_delete_pump(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="要删除的水泵ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除水泵。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次删除操作的变更信息
|
||||
"""
|
||||
ps = {"id": pump}
|
||||
return delete_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpumpnode1/")
|
||||
async def fastapi_get_pump_node1(network: str, pump: str) -> str | None:
|
||||
@router.get("/getpumpnode1/", summary="获取水泵起始节点", description="获取指定水泵的起始节点ID")
|
||||
async def fastapi_get_pump_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取水泵的起始节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
|
||||
Returns:
|
||||
起始节点ID,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pump(network, pump)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getpumpnode2/")
|
||||
async def fastapi_get_pump_node2(network: str, pump: str) -> str | None:
|
||||
@router.get("/getpumpnode2/", summary="获取水泵终止节点", description="获取指定水泵的终止节点ID")
|
||||
async def fastapi_get_pump_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取水泵的终止节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
|
||||
Returns:
|
||||
终止节点ID,如果不存在则返回None
|
||||
"""
|
||||
ps = get_pump(network, pump)
|
||||
return ps["node2"]
|
||||
|
||||
@router.post("/setpumpnode1/", response_model=None)
|
||||
async def fastapi_set_pump_node1(network: str, pump: str, node1: str) -> ChangeSet:
|
||||
@router.post("/setpumpnode1/", response_model=None, summary="设置水泵起始节点", description="设置指定水泵的起始节点")
|
||||
async def fastapi_set_pump_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID"),
|
||||
node1: str = Query(..., description="新的起始节点ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水泵起始节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
node1: 新的起始节点ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pump, "node1": node1}
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpumpnode2/", response_model=None)
|
||||
async def fastapi_set_pump_node2(network: str, pump: str, node2: str) -> ChangeSet:
|
||||
@router.post("/setpumpnode2/", response_model=None, summary="设置水泵终止节点", description="设置指定水泵的终止节点")
|
||||
async def fastapi_set_pump_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID"),
|
||||
node2: str = Query(..., description="新的终止节点ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水泵终止节点。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
node2: 新的终止节点ID
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
ps = {"id": pump, "node2": node2}
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpumpproperties/")
|
||||
async def fastapi_get_pump_properties(network: str, pump: str) -> dict[str, Any]:
|
||||
@router.get("/getpumpproperties/", summary="获取水泵属性", description="获取指定水泵的所有属性信息")
|
||||
async def fastapi_get_pump_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取水泵的所有属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
|
||||
Returns:
|
||||
包含水泵所有属性的字典
|
||||
"""
|
||||
return get_pump(network, pump)
|
||||
|
||||
@router.get("/getallpumpproperties/")
|
||||
async def fastapi_get_all_pump_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallpumpproperties/", summary="获取所有水泵属性", description="获取网络中所有水泵的属性信息列表")
|
||||
async def fastapi_get_all_pump_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取网络中所有水泵的属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含所有水泵属性的字典列表
|
||||
"""
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_pumps(network)
|
||||
return results
|
||||
|
||||
@router.post("/setpumpproperties/", response_model=None)
|
||||
@router.post("/setpumpproperties/", response_model=None, summary="设置水泵属性", description="批量设置指定水泵的多个属性")
|
||||
async def fastapi_set_pump_properties(
|
||||
network: str, pump: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pump: str = Query(..., description="水泵ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
批量设置水泵属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pump: 水泵ID
|
||||
req: 请求体,包含要设置的属性及其值
|
||||
|
||||
Returns:
|
||||
ChangeSet对象,包含本次修改的变更信息
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": pump} | props
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,6 +1,42 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_district_metering_area,
|
||||
add_region,
|
||||
add_service_area,
|
||||
add_virtual_district,
|
||||
calculate_district_metering_area_for_network,
|
||||
calculate_district_metering_area_for_nodes,
|
||||
calculate_district_metering_area_for_region,
|
||||
calculate_service_area,
|
||||
calculate_virtual_district,
|
||||
delete_district_metering_area,
|
||||
delete_region,
|
||||
delete_service_area,
|
||||
delete_virtual_district,
|
||||
generate_district_metering_area,
|
||||
generate_service_area,
|
||||
generate_sub_district_metering_area,
|
||||
generate_virtual_district,
|
||||
get_all_district_metering_area_ids,
|
||||
get_all_district_metering_areas,
|
||||
get_all_service_areas,
|
||||
get_all_virtual_districts,
|
||||
get_district_metering_area,
|
||||
get_district_metering_area_schema,
|
||||
get_region,
|
||||
get_region_schema,
|
||||
get_service_area,
|
||||
get_service_area_schema,
|
||||
get_virtual_district,
|
||||
get_virtual_district_schema,
|
||||
set_district_metering_area,
|
||||
set_region,
|
||||
set_service_area,
|
||||
set_virtual_district,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -8,64 +44,95 @@ router = APIRouter()
|
||||
# region 32
|
||||
############################################################
|
||||
|
||||
@router.get("/calculateregion/")
|
||||
async def fastapi_calculate_region(network: str, time_index: int) -> dict[str, Any]:
|
||||
return calculate_region(network, time_index)
|
||||
|
||||
@router.get("/getregionschema/")
|
||||
async def fastapi_get_region_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getregionschema/",
|
||||
summary="获取区域属性架构",
|
||||
description="获取指定水网的区域属性架构定义"
|
||||
)
|
||||
async def fastapi_get_region_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取区域的属性架构。"""
|
||||
return get_region_schema(network)
|
||||
|
||||
@router.get("/getregion/")
|
||||
async def fastapi_get_region(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getregion/",
|
||||
summary="获取区域信息",
|
||||
description="获取指定ID的区域详细信息"
|
||||
)
|
||||
async def fastapi_get_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="区域ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取区域的详细信息。"""
|
||||
return get_region(network, id)
|
||||
|
||||
@router.post("/setregion/", response_model=None)
|
||||
async def fastapi_set_region(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setregion/",
|
||||
response_model=None,
|
||||
summary="设置区域属性",
|
||||
description="修改指定区域的属性信息"
|
||||
)
|
||||
async def fastapi_set_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置区域属性。"""
|
||||
props = await req.json()
|
||||
return set_region(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addregion/", response_model=None)
|
||||
async def fastapi_add_region(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/addregion/",
|
||||
response_model=None,
|
||||
summary="添加新区域",
|
||||
description="向水网添加一个新的区域"
|
||||
)
|
||||
async def fastapi_add_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加新的区域。"""
|
||||
props = await req.json()
|
||||
return add_region(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deleteregion/", response_model=None)
|
||||
async def fastapi_delete_region(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deleteregion/",
|
||||
response_model=None,
|
||||
summary="删除区域",
|
||||
description="删除指定的区域"
|
||||
)
|
||||
async def fastapi_delete_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除区域。"""
|
||||
props = await req.json()
|
||||
return delete_region(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallregions/")
|
||||
async def fastapi_get_all_regions(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_regions(network)
|
||||
|
||||
@router.post("/generateregion/", response_model=None)
|
||||
async def fastapi_generate_region(
|
||||
network: str, inflate_delta: float
|
||||
) -> ChangeSet:
|
||||
return generate_region(network, inflate_delta)
|
||||
|
||||
|
||||
############################################################
|
||||
# district_metering_area 33
|
||||
############################################################
|
||||
|
||||
@router.get("/calculatedistrictmeteringarea/")
|
||||
async def fastapi_calculate_district_metering_area(
|
||||
network: str, req: Request
|
||||
) -> list[list[str]]:
|
||||
props = await req.json()
|
||||
nodes = props["nodes"]
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area(
|
||||
network, nodes, part_count, part_type
|
||||
)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareaforregion/")
|
||||
@router.get(
|
||||
"/calculatedistrictmeteringareaforregion/",
|
||||
summary="计算区域内DMA分区",
|
||||
description="为指定区域计算区域计量(DMA)分区方案"
|
||||
)
|
||||
async def fastapi_calculate_district_metering_area_for_region(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
计算区域内DMA分区。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"region": 区域ID(str),
|
||||
"part_count": 分区数量(int),
|
||||
"part_type": 分区类型(int)
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
region = props["region"]
|
||||
part_count = props["part_count"]
|
||||
@@ -74,32 +141,77 @@ async def fastapi_calculate_district_metering_area_for_region(
|
||||
network, region, part_count, part_type
|
||||
)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareafornetwork/")
|
||||
@router.get(
|
||||
"/calculatedistrictmeteringareafornetwork/",
|
||||
summary="计算整网DMA分区",
|
||||
description="为整个水网计算区域计量(DMA)分区方案"
|
||||
)
|
||||
async def fastapi_calculate_district_metering_area_for_network(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
计算整网DMA分区。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"part_count": 分区数量(int),
|
||||
"part_type": 分区类型(int)
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area_for_network(network, part_count, part_type)
|
||||
|
||||
@router.get("/getdistrictmeteringareaschema/")
|
||||
@router.get(
|
||||
"/getdistrictmeteringareaschema/",
|
||||
summary="获取DMA属性架构",
|
||||
description="获取指定水网的区域计量(DMA)属性架构定义"
|
||||
)
|
||||
async def fastapi_get_district_metering_area_schema(
|
||||
network: str,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取DMA的属性架构。"""
|
||||
return get_district_metering_area_schema(network)
|
||||
|
||||
@router.get("/getdistrictmeteringarea/")
|
||||
async def fastapi_get_district_metering_area(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getdistrictmeteringarea/",
|
||||
summary="获取DMA信息",
|
||||
description="获取指定ID的区域计量(DMA)详细信息"
|
||||
)
|
||||
async def fastapi_get_district_metering_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="DMA ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取DMA的详细信息。"""
|
||||
return get_district_metering_area(network, id)
|
||||
|
||||
@router.post("/setdistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_set_district_metering_area(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setdistrictmeteringarea/",
|
||||
response_model=None,
|
||||
summary="设置DMA属性",
|
||||
description="修改指定DMA的属性信息"
|
||||
)
|
||||
async def fastapi_set_district_metering_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置DMA属性。"""
|
||||
props = await req.json()
|
||||
return set_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/adddistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_add_district_metering_area(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/adddistrictmeteringarea/",
|
||||
response_model=None,
|
||||
summary="添加新DMA",
|
||||
description="向水网添加一个新的区域计量(DMA)"
|
||||
)
|
||||
async def fastapi_add_district_metering_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加新的DMA。"""
|
||||
props = await req.json()
|
||||
# boundary should be [(x,y), (x,y)]
|
||||
boundary = props.get("boundary", [])
|
||||
@@ -110,33 +222,73 @@ async def fastapi_add_district_metering_area(network: str, req: Request) -> Chan
|
||||
props["boundary"] = newBoundary
|
||||
return add_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletedistrictmeteringarea/", response_model=None)
|
||||
@router.post(
|
||||
"/deletedistrictmeteringarea/",
|
||||
response_model=None,
|
||||
summary="删除DMA",
|
||||
description="删除指定的区域计量(DMA)"
|
||||
)
|
||||
async def fastapi_delete_district_metering_area(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除DMA。"""
|
||||
props = await req.json()
|
||||
return delete_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getalldistrictmeteringareaids/")
|
||||
async def fastapi_get_all_district_metering_area_ids(network: str) -> list[str]:
|
||||
@router.get(
|
||||
"/getalldistrictmeteringareaids/",
|
||||
summary="获取所有DMA ID",
|
||||
description="获取指定水网中所有DMA的ID列表"
|
||||
)
|
||||
async def fastapi_get_all_district_metering_area_ids(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[str]:
|
||||
"""获取所有DMA的ID列表。"""
|
||||
return get_all_district_metering_area_ids(network)
|
||||
|
||||
@router.get("/getalldistrictmeteringareas/")
|
||||
async def getalldistrictmeteringareas(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getalldistrictmeteringareas/",
|
||||
summary="获取所有DMA",
|
||||
description="获取指定水网中所有DMA的详细信息"
|
||||
)
|
||||
async def getalldistrictmeteringareas(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取所有DMA的详细信息列表。"""
|
||||
return get_all_district_metering_areas(network)
|
||||
|
||||
@router.post("/generatedistrictmeteringarea/", response_model=None)
|
||||
@router.post(
|
||||
"/generatedistrictmeteringarea/",
|
||||
response_model=None,
|
||||
summary="生成DMA分区",
|
||||
description="根据参数自动生成水网的DMA分区方案"
|
||||
)
|
||||
async def fastapi_generate_district_metering_area(
|
||||
network: str, part_count: int, part_type: int, inflate_delta: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
part_count: int = Query(..., description="分区数量", gt=0),
|
||||
part_type: int = Query(..., description="分区类型"),
|
||||
inflate_delta: float = Query(..., description="膨胀参数")
|
||||
) -> ChangeSet:
|
||||
"""生成DMA分区。"""
|
||||
return generate_district_metering_area(
|
||||
network, part_count, part_type, inflate_delta
|
||||
)
|
||||
|
||||
@router.post("/generatesubdistrictmeteringarea/", response_model=None)
|
||||
@router.post(
|
||||
"/generatesubdistrictmeteringarea/",
|
||||
response_model=None,
|
||||
summary="生成DMA子分区",
|
||||
description="为指定DMA生成子DMA分区"
|
||||
)
|
||||
async def fastapi_generate_sub_district_metering_area(
|
||||
network: str, dma: str, part_count: int, part_type: int, inflate_delta: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
dma: str = Query(..., description="DMA ID"),
|
||||
part_count: int = Query(..., description="分区数量", gt=0),
|
||||
part_type: int = Query(..., description="分区类型"),
|
||||
inflate_delta: float = Query(..., description="膨胀参数")
|
||||
) -> ChangeSet:
|
||||
"""生成DMA子分区。"""
|
||||
return generate_sub_district_metering_area(
|
||||
network, dma, part_count, part_type, inflate_delta
|
||||
)
|
||||
@@ -146,43 +298,104 @@ async def fastapi_generate_sub_district_metering_area(
|
||||
# service_area 34
|
||||
############################################################
|
||||
|
||||
@router.get("/calculateservicearea/")
|
||||
@router.get(
|
||||
"/calculateservicearea/",
|
||||
summary="计算服务区",
|
||||
description="计算指定水网的服务区分区,返回全部时间步结果"
|
||||
)
|
||||
async def fastapi_calculate_service_area(
|
||||
network: str, time_index: int
|
||||
) -> dict[str, Any]:
|
||||
return calculate_service_area(network, time_index)
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> list[dict[str, list[str]]]:
|
||||
"""计算服务区分区,返回全部时间步结果。"""
|
||||
return calculate_service_area(network)
|
||||
|
||||
@router.get("/getserviceareaschema/")
|
||||
async def fastapi_get_service_area_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getserviceareaschema/",
|
||||
summary="获取服务区属性架构",
|
||||
description="获取指定水网的服务区属性架构定义"
|
||||
)
|
||||
async def fastapi_get_service_area_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取服务区的属性架构。"""
|
||||
return get_service_area_schema(network)
|
||||
|
||||
@router.get("/getservicearea/")
|
||||
async def fastapi_get_service_area(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getservicearea/",
|
||||
summary="获取服务区信息",
|
||||
description="获取指定ID的服务区详细信息"
|
||||
)
|
||||
async def fastapi_get_service_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="服务区ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取服务区的详细信息。"""
|
||||
return get_service_area(network, id)
|
||||
|
||||
@router.post("/setservicearea/", response_model=None)
|
||||
async def fastapi_set_service_area(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setservicearea/",
|
||||
response_model=None,
|
||||
summary="设置服务区属性",
|
||||
description="修改指定服务区的属性信息"
|
||||
)
|
||||
async def fastapi_set_service_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置服务区属性。"""
|
||||
props = await req.json()
|
||||
return set_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addservicearea/", response_model=None)
|
||||
async def fastapi_add_service_area(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/addservicearea/",
|
||||
response_model=None,
|
||||
summary="添加新服务区",
|
||||
description="向水网添加一个新的服务区"
|
||||
)
|
||||
async def fastapi_add_service_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加新的服务区。"""
|
||||
props = await req.json()
|
||||
return add_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deleteservicearea/", response_model=None)
|
||||
async def fastapi_delete_service_area(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deleteservicearea/",
|
||||
response_model=None,
|
||||
summary="删除服务区",
|
||||
description="删除指定的服务区"
|
||||
)
|
||||
async def fastapi_delete_service_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除服务区。"""
|
||||
props = await req.json()
|
||||
return delete_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallserviceareas/")
|
||||
async def fastapi_get_all_service_areas(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getallserviceareas/",
|
||||
summary="获取所有服务区",
|
||||
description="获取指定水网中的所有服务区信息"
|
||||
)
|
||||
async def fastapi_get_all_service_areas(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取所有服务区的信息列表。"""
|
||||
return get_all_service_areas(network)
|
||||
|
||||
@router.post("/generateservicearea/", response_model=None)
|
||||
@router.post(
|
||||
"/generateservicearea/",
|
||||
response_model=None,
|
||||
summary="生成服务区分区",
|
||||
description="根据参数自动生成水网的服务区分区"
|
||||
)
|
||||
async def fastapi_generate_service_area(
|
||||
network: str, inflate_delta: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inflate_delta: float = Query(..., description="膨胀参数")
|
||||
) -> ChangeSet:
|
||||
"""生成服务区分区。"""
|
||||
return generate_service_area(network, inflate_delta)
|
||||
|
||||
|
||||
@@ -190,52 +403,128 @@ async def fastapi_generate_service_area(
|
||||
# virtual_district 35
|
||||
############################################################
|
||||
|
||||
@router.get("/calculatevirtualdistrict/")
|
||||
@router.get(
|
||||
"/calculatevirtualdistrict/",
|
||||
summary="计算虚拟分区",
|
||||
description="根据指定的压力监测节点作为中心节点计算虚拟分区方案"
|
||||
)
|
||||
async def fastapi_calculate_virtual_district(
|
||||
network: str, centers: list[str]
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
centers: list[str] = Query(..., description="压力监测节点ID列表")
|
||||
) -> dict[str, list[Any]]:
|
||||
"""计算虚拟分区。"""
|
||||
return calculate_virtual_district(network, centers)
|
||||
|
||||
@router.get("/getvirtualdistrictschema/")
|
||||
@router.get(
|
||||
"/getvirtualdistrictschema/",
|
||||
summary="获取虚拟分区属性架构",
|
||||
description="获取指定水网的虚拟分区属性架构定义"
|
||||
)
|
||||
async def fastapi_get_virtual_district_schema(
|
||||
network: str,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取虚拟分区的属性架构。"""
|
||||
return get_virtual_district_schema(network)
|
||||
|
||||
@router.get("/getvirtualdistrict/")
|
||||
async def fastapi_get_virtual_district(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getvirtualdistrict/",
|
||||
summary="获取虚拟分区信息",
|
||||
description="获取指定ID的虚拟分区详细信息"
|
||||
)
|
||||
async def fastapi_get_virtual_district(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="虚拟分区ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取虚拟分区的详细信息。"""
|
||||
return get_virtual_district(network, id)
|
||||
|
||||
@router.post("/setvirtualdistrict/", response_model=None)
|
||||
async def fastapi_set_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setvirtualdistrict/",
|
||||
response_model=None,
|
||||
summary="设置虚拟分区属性",
|
||||
description="修改指定虚拟分区的属性信息"
|
||||
)
|
||||
async def fastapi_set_virtual_district(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置虚拟分区属性。"""
|
||||
props = await req.json()
|
||||
return set_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addvirtualdistrict/", response_model=None)
|
||||
async def fastapi_add_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/addvirtualdistrict/",
|
||||
response_model=None,
|
||||
summary="添加新虚拟分区",
|
||||
description="向水网添加一个新的虚拟分区"
|
||||
)
|
||||
async def fastapi_add_virtual_district(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""添加新的虚拟分区。"""
|
||||
props = await req.json()
|
||||
return add_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletevirtualdistrict/", response_model=None)
|
||||
async def fastapi_delete_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deletevirtualdistrict/",
|
||||
response_model=None,
|
||||
summary="删除虚拟分区",
|
||||
description="删除指定的虚拟分区"
|
||||
)
|
||||
async def fastapi_delete_virtual_district(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""删除虚拟分区。"""
|
||||
props = await req.json()
|
||||
return delete_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallvirtualdistrict/")
|
||||
async def fastapi_get_all_virtual_district(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getallvirtualdistrict/",
|
||||
summary="获取所有虚拟分区",
|
||||
description="获取指定水网中的所有虚拟分区信息"
|
||||
)
|
||||
async def fastapi_get_all_virtual_district(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取所有虚拟分区的信息列表。"""
|
||||
return get_all_virtual_districts(network)
|
||||
|
||||
@router.post("/generatevirtualdistrict/", response_model=None)
|
||||
@router.post(
|
||||
"/generatevirtualdistrict/",
|
||||
response_model=None,
|
||||
summary="生成虚拟分区",
|
||||
description="根据参数自动生成虚拟分区方案"
|
||||
)
|
||||
async def fastapi_generate_virtual_district(
|
||||
network: str, inflate_delta: float, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inflate_delta: float = Query(..., description="膨胀参数"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""生成虚拟分区。"""
|
||||
props = await req.json()
|
||||
return generate_virtual_district(network, props["centers"], inflate_delta)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareafornodes/")
|
||||
@router.get(
|
||||
"/calculatedistrictmeteringareafornodes/",
|
||||
summary="计算节点DMA分区",
|
||||
description="为指定节点集计算区域计量(DMA)分区方案"
|
||||
)
|
||||
async def fastapi_calculate_district_metering_area_for_nodes(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
计算节点DMA分区。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"nodes": 节点ID列表(list[str]),
|
||||
"part_count": 分区数量(int),
|
||||
"part_type": 分区类型(int)
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
nodes = props["nodes"]
|
||||
part_count = props["part_count"]
|
||||
|
||||
@@ -1,105 +1,422 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_reservoir,
|
||||
delete_reservoir,
|
||||
get_all_reservoirs,
|
||||
get_reservoir,
|
||||
get_reservoir_schema,
|
||||
set_reservoir,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getreservoirschema")
|
||||
async def fast_get_reservoir_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getreservoirschema",
|
||||
summary="获取水库模式",
|
||||
description="获取指定供水网络中所有水库的模式/属性字段定义"
|
||||
)
|
||||
async def fast_get_reservoir_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取水库模式定义。
|
||||
|
||||
该端点返回指定网络中水库对象的模式定义,包括所有可用的属性字段。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
水库属性的模式定义字典
|
||||
"""
|
||||
return get_reservoir_schema(network)
|
||||
|
||||
@router.post("/addreservoir/", response_model=None)
|
||||
@router.post(
|
||||
"/addreservoir/",
|
||||
response_model=None,
|
||||
summary="添加水库",
|
||||
description="在指定供水网络中添加新的水库/水源节点"
|
||||
)
|
||||
async def fastapi_add_reservoir(
|
||||
network: str, reservoir: str, x: float, y: float, head: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
x: float = Query(..., description="水库的X坐标"),
|
||||
y: float = Query(..., description="水库的Y坐标"),
|
||||
head: float = Query(..., description="水库的水头/总水头(米)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新的水库/水源节点。
|
||||
|
||||
在指定的供水网络中创建一个新的水库,并设置其坐标和水头参数。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
x: 水库的X坐标位置
|
||||
y: 水库的Y坐标位置
|
||||
head: 水库的供水水头(以米为单位)
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "x": x, "y": y, "head": head}
|
||||
return add_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletereservoir/", response_model=None)
|
||||
async def fastapi_delete_reservoir(network: str, reservoir: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deletereservoir/",
|
||||
response_model=None,
|
||||
summary="删除水库",
|
||||
description="从指定供水网络中删除指定的水库/水源节点"
|
||||
)
|
||||
async def fastapi_delete_reservoir(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="要删除的水库的唯一标识符")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除指定的水库节点。
|
||||
|
||||
从指定的供水网络中删除一个水库及其相关的所有连接关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir}
|
||||
return delete_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getreservoirhead/")
|
||||
async def fastapi_get_reservoir_head(network: str, reservoir: str) -> float | None:
|
||||
@router.get(
|
||||
"/getreservoirhead/",
|
||||
summary="获取水库水头",
|
||||
description="获取指定水库的供水水头/总水头值"
|
||||
)
|
||||
async def fastapi_get_reservoir_head(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水库的水头参数。
|
||||
|
||||
返回指定水库的供水水头(总水头),单位为米。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
水库的水头值(米),如果水库不存在则返回None
|
||||
"""
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["head"]
|
||||
|
||||
@router.get("/getreservoirpattern/")
|
||||
async def fastapi_get_reservoir_pattern(network: str, reservoir: str) -> str | None:
|
||||
@router.get(
|
||||
"/getreservoirpattern/",
|
||||
summary="获取水库模式",
|
||||
description="获取指定水库的运行模式/供水模式"
|
||||
)
|
||||
async def fastapi_get_reservoir_pattern(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取水库的运行模式。
|
||||
|
||||
返回指定水库的供水模式,如固定水头模式、时间序列模式等。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
水库的运行模式字符串,如果水库不存在则返回None
|
||||
"""
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["pattern"]
|
||||
|
||||
@router.get("/getreservoirx/")
|
||||
@router.get(
|
||||
"/getreservoirx/",
|
||||
summary="获取水库X坐标",
|
||||
description="获取指定水库的X坐标位置"
|
||||
)
|
||||
async def fastapi_get_reservoir_x(
|
||||
network: str, reservoir: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> dict[str, float] | None:
|
||||
"""
|
||||
获取水库的X坐标。
|
||||
|
||||
返回指定水库的X轴坐标值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
水库的X坐标值,如果水库不存在则返回None
|
||||
"""
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/getreservoiry/")
|
||||
@router.get(
|
||||
"/getreservoiry/",
|
||||
summary="获取水库Y坐标",
|
||||
description="获取指定水库的Y坐标位置"
|
||||
)
|
||||
async def fastapi_get_reservoir_y(
|
||||
network: str, reservoir: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> dict[str, float] | None:
|
||||
"""
|
||||
获取水库的Y坐标。
|
||||
|
||||
返回指定水库的Y轴坐标值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
水库的Y坐标值,如果水库不存在则返回None
|
||||
"""
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/getreservoircoord/")
|
||||
@router.get(
|
||||
"/getreservoircoord/",
|
||||
summary="获取水库坐标",
|
||||
description="获取指定水库的平面坐标(X和Y坐标)"
|
||||
)
|
||||
async def fastapi_get_reservoir_coord(
|
||||
network: str, reservoir: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> dict[str, float] | None:
|
||||
"""
|
||||
获取水库的坐标。
|
||||
|
||||
返回指定水库的平面坐标,包含水库ID、X坐标和Y坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
包含water库ID和X、Y坐标的字典,如果水库不存在则返回None
|
||||
"""
|
||||
ps = get_reservoir(network, reservoir)
|
||||
coord = {"id": reservoir, "x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.post("/setreservoirhead/", response_model=None)
|
||||
@router.post(
|
||||
"/setreservoirhead/",
|
||||
response_model=None,
|
||||
summary="设置水库水头",
|
||||
description="更新指定水库的供水水头/总水头值"
|
||||
)
|
||||
async def fastapi_set_reservoir_head(
|
||||
network: str, reservoir: str, head: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
head: float = Query(..., description="新的水头值(米)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的水头参数。
|
||||
|
||||
更新指定水库的供水水头(总水头)值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
head: 新的水头值(以米为单位)
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "head": head}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoirpattern/", response_model=None)
|
||||
@router.post(
|
||||
"/setreservoirpattern/",
|
||||
response_model=None,
|
||||
summary="设置水库模式",
|
||||
description="更新指定水库的运行模式/供水模式"
|
||||
)
|
||||
async def fastapi_set_reservoir_pattern(
|
||||
network: str, reservoir: str, pattern: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
pattern: str = Query(..., description="新的运行模式")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的运行模式。
|
||||
|
||||
更新指定水库的供水模式,如固定水头模式、时间序列模式等。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
pattern: 新的运行模式字符串
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "pattern": pattern}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoirx/", response_model=None)
|
||||
async def fastapi_set_reservoir_x(network: str, reservoir: str, x: float) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setreservoirx/",
|
||||
response_model=None,
|
||||
summary="设置水库X坐标",
|
||||
description="更新指定水库的X坐标位置"
|
||||
)
|
||||
async def fastapi_set_reservoir_x(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
x: float = Query(..., description="新的X坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的X坐标。
|
||||
|
||||
更新指定水库的X轴坐标值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
x: 新的X坐标值
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "x": x}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoiry/", response_model=None)
|
||||
async def fastapi_set_reservoir_y(network: str, reservoir: str, y: float) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setreservoiry/",
|
||||
response_model=None,
|
||||
summary="设置水库Y坐标",
|
||||
description="更新指定水库的Y坐标位置"
|
||||
)
|
||||
async def fastapi_set_reservoir_y(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
y: float = Query(..., description="新的Y坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的Y坐标。
|
||||
|
||||
更新指定水库的Y轴坐标值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
y: 新的Y坐标值
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "y": y}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoircoord/", response_model=None)
|
||||
@router.post(
|
||||
"/setreservoircoord/",
|
||||
response_model=None,
|
||||
summary="设置水库坐标",
|
||||
description="更新指定水库的平面坐标(X和Y坐标)"
|
||||
)
|
||||
async def fastapi_set_reservoir_coord(
|
||||
network: str, reservoir: str, x: float, y: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
x: float = Query(..., description="新的X坐标值"),
|
||||
y: float = Query(..., description="新的Y坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的坐标。
|
||||
|
||||
更新指定水库的平面坐标,包括X和Y坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
x: 新的X坐标值
|
||||
y: 新的Y坐标值
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": reservoir, "x": x, "y": y}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getreservoirproperties/")
|
||||
@router.get(
|
||||
"/getreservoirproperties/",
|
||||
summary="获取水库属性",
|
||||
description="获取指定水库的所有属性"
|
||||
)
|
||||
async def fastapi_get_reservoir_properties(
|
||||
network: str, reservoir: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取水库的所有属性。
|
||||
|
||||
返回指定水库的完整属性信息,包括ID、坐标、水头、模式等所有属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
|
||||
Returns:
|
||||
包含水库所有属性的字典
|
||||
"""
|
||||
return get_reservoir(network, reservoir)
|
||||
|
||||
@router.get("/getallreservoirproperties/")
|
||||
async def fastapi_get_all_reservoir_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
@router.get(
|
||||
"/getallreservoirproperties/",
|
||||
summary="获取所有水库属性",
|
||||
description="获取指定供水网络中所有水库的属性"
|
||||
)
|
||||
async def fastapi_get_all_reservoir_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取所有水库的属性。
|
||||
|
||||
返回指定供水网络中所有水库的完整属性信息列表。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含所有水库属性的字典列表
|
||||
"""
|
||||
results = get_all_reservoirs(network)
|
||||
return results
|
||||
|
||||
@router.post("/setreservoirproperties/", response_model=None)
|
||||
@router.post(
|
||||
"/setreservoirproperties/",
|
||||
response_model=None,
|
||||
summary="设置水库属性",
|
||||
description="批量更新指定水库的多个属性"
|
||||
)
|
||||
async def fastapi_set_reservoir_properties(
|
||||
network: str, reservoir: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
reservoir: str = Query(..., description="水库的唯一标识符"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水库的多个属性。
|
||||
|
||||
批量更新指定水库的属性。属性通过JSON请求体传递。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
reservoir: 水库的唯一标识符
|
||||
req: HTTP请求对象,包含JSON格式的属性数据
|
||||
|
||||
Returns:
|
||||
包含操作变更集的ChangeSet对象
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": reservoir} | props
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
get_tag,
|
||||
get_tag_schema,
|
||||
get_tags,
|
||||
set_tag,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -8,20 +15,52 @@ router = APIRouter()
|
||||
# tag 8.[TAGS]
|
||||
############################################################
|
||||
|
||||
@router.get("/gettagschema/")
|
||||
async def fastapi_get_tag_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/gettagschema/",
|
||||
summary="获取标签属性架构",
|
||||
description="获取指定水网的标签(Tag)属性架构定义"
|
||||
)
|
||||
async def fastapi_get_tag_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""获取标签的属性架构。"""
|
||||
return get_tag_schema(network)
|
||||
|
||||
@router.get("/gettag/")
|
||||
async def fastapi_get_tag(network: str, t_type: str, id: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/gettag/",
|
||||
summary="获取标签信息",
|
||||
description="获取指定类型和ID的标签信息"
|
||||
)
|
||||
async def fastapi_get_tag(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
t_type: str = Query(..., description="标签类型"),
|
||||
id: str = Query(..., description="元素ID")
|
||||
) -> dict[str, Any]:
|
||||
"""获取标签信息。"""
|
||||
return get_tag(network, t_type, id)
|
||||
|
||||
@router.get("/gettags/")
|
||||
async def fastapi_get_tags(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/gettags/",
|
||||
summary="获取所有标签",
|
||||
description="获取指定水网中的所有标签信息"
|
||||
)
|
||||
async def fastapi_get_tags(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取水网中所有标签的列表。"""
|
||||
tags = get_tags(network)
|
||||
return tags
|
||||
|
||||
@router.post("/settag/", response_model=None)
|
||||
async def fastapi_set_tag(network: str, req: Request) -> ChangeSet:
|
||||
@router.post(
|
||||
"/settag/",
|
||||
response_model=None,
|
||||
summary="设置标签",
|
||||
description="为指定元素设置或修改标签信息"
|
||||
)
|
||||
async def fastapi_set_tag(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""设置标签信息。"""
|
||||
props = await req.json()
|
||||
return set_tag(network, ChangeSet(props))
|
||||
|
||||
@@ -1,26 +1,62 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
add_tank,
|
||||
delete_tank,
|
||||
get_all_tanks,
|
||||
get_tank,
|
||||
get_tank_schema,
|
||||
set_tank,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/gettankschema")
|
||||
async def fast_get_tank_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/gettankschema", summary="获取水箱模式", description="获取指定网络的水箱数据结构模式定义")
|
||||
async def fast_get_tank_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取水箱的数据结构模式。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含水箱属性的模式定义字典
|
||||
"""
|
||||
return get_tank_schema(network)
|
||||
|
||||
@router.post("/addtank/", response_model=None)
|
||||
@router.post("/addtank/", summary="新增水箱", description="向指定网络中新增一个水箱", response_model=None)
|
||||
async def fastapi_add_tank(
|
||||
network: str,
|
||||
tank: str,
|
||||
x: float,
|
||||
y: float,
|
||||
elevation: float,
|
||||
init_level: float = 0,
|
||||
min_level: float = 0,
|
||||
max_level: float = 0,
|
||||
diameter: float = 0,
|
||||
min_vol: float = 0,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
x: float = Query(..., description="X坐标"),
|
||||
y: float = Query(..., description="Y坐标"),
|
||||
elevation: float = Query(..., description="标高"),
|
||||
init_level: float = Query(0, description="初始水位"),
|
||||
min_level: float = Query(0, description="最小水位"),
|
||||
max_level: float = Query(0, description="最大水位"),
|
||||
diameter: float = Query(0, description="直径"),
|
||||
min_vol: float = Query(0, description="最小体积"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
创建新水箱。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
x: X坐标
|
||||
y: Y坐标
|
||||
elevation: 水箱标高
|
||||
init_level: 初始水位,默认为0
|
||||
min_level: 最小水位,默认为0
|
||||
max_level: 最大水位,默认为0
|
||||
diameter: 水箱直径,默认为0
|
||||
min_vol: 最小体积,默认为0
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {
|
||||
"id": tank,
|
||||
"x": x,
|
||||
@@ -34,155 +70,497 @@ async def fastapi_add_tank(
|
||||
}
|
||||
return add_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletetank/", response_model=None)
|
||||
async def fastapi_delete_tank(network: str, tank: str) -> ChangeSet:
|
||||
@router.post("/deletetank/", summary="删除水箱", description="删除指定网络中的水箱", response_model=None)
|
||||
async def fastapi_delete_tank(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除指定的水箱。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank}
|
||||
return delete_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/gettankelevation/")
|
||||
async def fastapi_get_tank_elevation(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankelevation/", summary="获取水箱标高", description="获取指定水箱的标高值")
|
||||
async def fastapi_get_tank_elevation(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的标高。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱标高值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["elevation"]
|
||||
|
||||
@router.get("/gettankinitlevel/")
|
||||
async def fastapi_get_tank_init_level(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankinitlevel/", summary="获取水箱初始水位", description="获取指定水箱的初始水位值")
|
||||
async def fastapi_get_tank_init_level(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的初始水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱初始水位值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["init_level"]
|
||||
|
||||
@router.get("/gettankminlevel/")
|
||||
async def fastapi_get_tank_min_level(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankminlevel/", summary="获取水箱最小水位", description="获取指定水箱的最小水位值")
|
||||
async def fastapi_get_tank_min_level(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的最小水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱最小水位值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["min_level"]
|
||||
|
||||
@router.get("/gettankmaxlevel/")
|
||||
async def fastapi_get_tank_max_level(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankmaxlevel/", summary="获取水箱最大水位", description="获取指定水箱的最大水位值")
|
||||
async def fastapi_get_tank_max_level(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的最大水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱最大水位值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["max_level"]
|
||||
|
||||
@router.get("/gettankdiameter/")
|
||||
async def fastapi_get_tank_diameter(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankdiameter/", summary="获取水箱直径", description="获取指定水箱的直径值")
|
||||
async def fastapi_get_tank_diameter(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的直径。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱直径值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/gettankminvol/")
|
||||
async def fastapi_get_tank_min_vol(network: str, tank: str) -> float | None:
|
||||
@router.get("/gettankminvol/", summary="获取水箱最小体积", description="获取指定水箱的最小体积值")
|
||||
async def fastapi_get_tank_min_vol(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float | None:
|
||||
"""
|
||||
获取水箱的最小体积。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱最小体积值,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["min_vol"]
|
||||
|
||||
@router.get("/gettankvolcurve/")
|
||||
async def fastapi_get_tank_vol_curve(network: str, tank: str) -> str | None:
|
||||
@router.get("/gettankvolcurve/", summary="获取水箱容积曲线", description="获取指定水箱的容积曲线标识")
|
||||
async def fastapi_get_tank_vol_curve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取水箱的容积曲线。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱容积曲线标识,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["vol_curve"]
|
||||
|
||||
@router.get("/gettankoverflow/")
|
||||
async def fastapi_get_tank_overflow(network: str, tank: str) -> str | None:
|
||||
@router.get("/gettankoverflow/", summary="获取水箱溢流口", description="获取指定水箱的溢流口配置")
|
||||
async def fastapi_get_tank_overflow(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> str | None:
|
||||
"""
|
||||
获取水箱的溢流口配置。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱溢流口配置,如果不存在则返回None
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["overflow"]
|
||||
|
||||
@router.get("/gettankx/")
|
||||
async def fastapi_get_tank_x(network: str, tank: str) -> float:
|
||||
@router.get("/gettankx/", summary="获取水箱X坐标", description="获取指定水箱的X坐标值")
|
||||
async def fastapi_get_tank_x(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取水箱的X坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱X坐标值
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/gettanky/")
|
||||
async def fastapi_get_tank_y(network: str, tank: str) -> float:
|
||||
@router.get("/gettanky/", summary="获取水箱Y坐标", description="获取指定水箱的Y坐标值")
|
||||
async def fastapi_get_tank_y(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> float:
|
||||
"""
|
||||
获取水箱的Y坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
水箱Y坐标值
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/gettankcoord/")
|
||||
async def fastapi_get_tank_coord(network: str, tank: str) -> dict[str, float]:
|
||||
@router.get("/gettankcoord/", summary="获取水箱坐标", description="获取指定水箱的X和Y坐标")
|
||||
async def fastapi_get_tank_coord(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
获取水箱的坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
包含x和y坐标的字典
|
||||
"""
|
||||
ps = get_tank(network, tank)
|
||||
coord = {"x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.post("/settankelevation/", response_model=None)
|
||||
@router.post("/settankelevation/", summary="设置水箱标高", description="设置指定水箱的标高值", response_model=None)
|
||||
async def fastapi_set_tank_elevation(
|
||||
network: str, tank: str, elevation: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
elevation: float = Query(..., description="新的标高值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的标高。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
elevation: 新的标高值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "elevation": elevation}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankinitlevel/", response_model=None)
|
||||
@router.post("/settankinitlevel/", summary="设置水箱初始水位", description="设置指定水箱的初始水位值", response_model=None)
|
||||
async def fastapi_set_tank_init_level(
|
||||
network: str, tank: str, init_level: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
init_level: float = Query(..., description="新的初始水位值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的初始水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
init_level: 新的初始水位值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "init_level": init_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankminlevel/", response_model=None)
|
||||
@router.post("/settankminlevel/", summary="设置水箱最小水位", description="设置指定水箱的最小水位值", response_model=None)
|
||||
async def fastapi_set_tank_min_level(
|
||||
network: str, tank: str, min_level: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
min_level: float = Query(..., description="新的最小水位值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的最小水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
min_level: 新的最小水位值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "min_level": min_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankmaxlevel/", response_model=None)
|
||||
@router.post("/settankmaxlevel/", summary="设置水箱最大水位", description="设置指定水箱的最大水位值", response_model=None)
|
||||
async def fastapi_set_tank_max_level(
|
||||
network: str, tank: str, max_level: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
max_level: float = Query(..., description="新的最大水位值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的最大水位。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
max_level: 新的最大水位值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "max_level": max_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("settankdiameter//", response_model=None)
|
||||
@router.post("/settankdiameter/", summary="设置水箱直径", description="设置指定水箱的直径值", response_model=None)
|
||||
async def fastapi_set_tank_diameter(
|
||||
network: str, tank: str, diameter: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
diameter: float = Query(..., description="新的直径值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的直径。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
diameter: 新的直径值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "diameter": diameter}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankminvol/", response_model=None)
|
||||
@router.post("/settankminvol/", summary="设置水箱最小体积", description="设置指定水箱的最小体积值", response_model=None)
|
||||
async def fastapi_set_tank_min_vol(
|
||||
network: str, tank: str, min_vol: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
min_vol: float = Query(..., description="新的最小体积值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的最小体积。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
min_vol: 新的最小体积值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "min_vol": min_vol}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankvolcurve/", response_model=None)
|
||||
@router.post("/settankvolcurve/", summary="设置水箱容积曲线", description="设置指定水箱的容积曲线标识", response_model=None)
|
||||
async def fastapi_set_tank_vol_curve(
|
||||
network: str, tank: str, vol_curve: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
vol_curve: str = Query(..., description="新的容积曲线标识")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的容积曲线。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
vol_curve: 新的容积曲线标识
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "vol_curve": vol_curve}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankoverflow/", response_model=None)
|
||||
@router.post("/settankoverflow/", summary="设置水箱溢流口", description="设置指定水箱的溢流口配置", response_model=None)
|
||||
async def fastapi_set_tank_overflow(
|
||||
network: str, tank: str, overflow: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
overflow: str = Query(..., description="新的溢流口配置")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的溢流口配置。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
overflow: 新的溢流口配置
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "overflow": overflow}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankx/", response_model=None)
|
||||
async def fastapi_set_tank_x(network: str, tank: str, x: float) -> ChangeSet:
|
||||
@router.post("/settankx/", summary="设置水箱X坐标", description="设置指定水箱的X坐标值", response_model=None)
|
||||
async def fastapi_set_tank_x(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
x: float = Query(..., description="新的X坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的X坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
x: 新的X坐标值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "x": x}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settanky/", response_model=None)
|
||||
async def fastapi_set_tank_y(network: str, tank: str, y: float) -> ChangeSet:
|
||||
@router.post("/settanky/", summary="设置水箱Y坐标", description="设置指定水箱的Y坐标值", response_model=None)
|
||||
async def fastapi_set_tank_y(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
y: float = Query(..., description="新的Y坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的Y坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
y: 新的Y坐标值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "y": y}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankcoord/", response_model=None)
|
||||
@router.post("/settankcoord/", summary="设置水箱坐标", description="设置指定水箱的X和Y坐标", response_model=None)
|
||||
async def fastapi_set_tank_coord(
|
||||
network: str, tank: str, x: float, y: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
x: float = Query(..., description="新的X坐标值"),
|
||||
y: float = Query(..., description="新的Y坐标值")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置水箱的坐标。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
x: 新的X坐标值
|
||||
y: 新的Y坐标值
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
ps = {"id": tank, "x": x, "y": y}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/gettankproperties/")
|
||||
async def fastapi_get_tank_properties(network: str, tank: str) -> dict[str, Any]:
|
||||
@router.get("/gettankproperties/", summary="获取水箱属性", description="获取指定水箱的所有属性")
|
||||
async def fastapi_get_tank_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取水箱的所有属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
|
||||
Returns:
|
||||
包含水箱所有属性的字典
|
||||
"""
|
||||
return get_tank(network, tank)
|
||||
|
||||
@router.get("/getalltankproperties/")
|
||||
async def fastapi_get_all_tank_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getalltankproperties/", summary="获取所有水箱属性", description="获取指定网络中所有水箱的属性")
|
||||
async def fastapi_get_all_tank_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取网络中所有水箱的属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含所有水箱属性的字典列表
|
||||
"""
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_tanks(network)
|
||||
return results
|
||||
|
||||
@router.post("/settankproperties/", response_model=None)
|
||||
@router.post("/settankproperties/", summary="设置水箱属性", description="批量设置指定水箱的多个属性", response_model=None)
|
||||
async def fastapi_set_tank_properties(
|
||||
network: str, tank: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
tank: str = Query(..., description="水箱ID"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
批量设置水箱的属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
tank: 水箱ID
|
||||
req: 包含水箱属性的请求体(JSON格式)
|
||||
|
||||
Returns:
|
||||
包含变更信息的ChangeSet对象
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": tank} | props
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@@ -1,24 +1,55 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query, Path, Body
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import (
|
||||
Any,
|
||||
ChangeSet,
|
||||
VALVES_TYPE_PRV,
|
||||
add_valve,
|
||||
delete_valve,
|
||||
get_all_valves,
|
||||
get_valve,
|
||||
get_valve_schema,
|
||||
set_valve,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getvalveschema")
|
||||
async def fastapi_get_valve_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getvalveschema",
|
||||
summary="获取阀门架构",
|
||||
description="获取指定水网中所有阀门的架构和字段定义",
|
||||
)
|
||||
async def fastapi_get_valve_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取阀门架构。
|
||||
|
||||
返回指定水网中所有阀门类型的完整架构定义,包括字段名称、类型和默认值。
|
||||
"""
|
||||
return get_valve_schema(network)
|
||||
|
||||
@router.post("/addvalve/", response_model=None)
|
||||
@router.post(
|
||||
"/addvalve/",
|
||||
response_model=None,
|
||||
summary="添加阀门",
|
||||
description="在指定的水网中添加新的阀门",
|
||||
)
|
||||
async def fastapi_add_valve(
|
||||
network: str,
|
||||
valve: str,
|
||||
node1: str,
|
||||
node2: str,
|
||||
diameter: float = 0,
|
||||
v_type: str = VALVES_TYPE_PRV,
|
||||
setting: float = 0,
|
||||
minor_loss: float = 0,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
node1: str = Query(..., description="起点节点ID"),
|
||||
node2: str = Query(..., description="终点节点ID"),
|
||||
diameter: float = Query(0, description="阀门直径(mm)"),
|
||||
v_type: str = Query(VALVES_TYPE_PRV, description="阀门类型"),
|
||||
setting: float = Query(0, description="阀门开度/设置值"),
|
||||
minor_loss: float = Query(0, description="损失系数"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新的阀门。
|
||||
|
||||
在指定的水网中创建一个新的阀门,设置其连接的两个节点、直径、类型、开度和损失系数。
|
||||
"""
|
||||
ps = {
|
||||
"id": valve,
|
||||
"node1": node1,
|
||||
@@ -31,85 +62,271 @@ async def fastapi_add_valve(
|
||||
|
||||
return add_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletevalve/", response_model=None)
|
||||
async def fastapi_delete_valve(network: str, valve: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/deletevalve/",
|
||||
response_model=None,
|
||||
summary="删除阀门",
|
||||
description="从指定的水网中删除指定的阀门",
|
||||
)
|
||||
async def fastapi_delete_valve(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除阀门。
|
||||
|
||||
从指定的水网中删除指定ID的阀门。
|
||||
"""
|
||||
ps = {"id": valve}
|
||||
return delete_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getvalvenode1/")
|
||||
async def fastapi_get_valve_node1(network: str, valve: str) -> str | None:
|
||||
@router.get(
|
||||
"/getvalvenode1/",
|
||||
summary="获取阀门起点节点",
|
||||
description="获取指定阀门连接的起点节点ID",
|
||||
)
|
||||
async def fastapi_get_valve_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> str | None:
|
||||
"""
|
||||
获取阀门的起点节点。
|
||||
|
||||
返回指定阀门连接的起点(第一个)节点的ID。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getvalvenode2/")
|
||||
async def fastapi_get_valve_node2(network: str, valve: str) -> str | None:
|
||||
@router.get(
|
||||
"/getvalvenode2/",
|
||||
summary="获取阀门终点节点",
|
||||
description="获取指定阀门连接的终点节点ID",
|
||||
)
|
||||
async def fastapi_get_valve_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> str | None:
|
||||
"""
|
||||
获取阀门的终点节点。
|
||||
|
||||
返回指定阀门连接的终点(第二个)节点的ID。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["node2"]
|
||||
|
||||
@router.get("/getvalvediameter/")
|
||||
async def fastapi_get_valve_diameter(network: str, valve: str) -> float | None:
|
||||
@router.get(
|
||||
"/getvalvediameter/",
|
||||
summary="获取阀门直径",
|
||||
description="获取指定阀门的直径",
|
||||
)
|
||||
async def fastapi_get_valve_diameter(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> float | None:
|
||||
"""
|
||||
获取阀门的直径。
|
||||
|
||||
返回指定阀门的直径值(单位:mm)。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/getvalvetype/")
|
||||
async def fastapi_get_valve_type(network: str, valve: str) -> str | None:
|
||||
@router.get(
|
||||
"/getvalvetype/",
|
||||
summary="获取阀门类型",
|
||||
description="获取指定阀门的类型",
|
||||
)
|
||||
async def fastapi_get_valve_type(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> str | None:
|
||||
"""
|
||||
获取阀门的类型。
|
||||
|
||||
返回指定阀门的类型(例如:减压阀、调节阀等)。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["type"]
|
||||
|
||||
@router.get("/getvalvesetting/")
|
||||
async def fastapi_get_valve_setting(network: str, valve: str) -> float | None:
|
||||
@router.get(
|
||||
"/getvalvesetting/",
|
||||
summary="获取阀门开度",
|
||||
description="获取指定阀门的开度/设置值",
|
||||
)
|
||||
async def fastapi_get_valve_setting(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> float | None:
|
||||
"""
|
||||
获取阀门的开度。
|
||||
|
||||
返回指定阀门的开度/设置值。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["setting"]
|
||||
|
||||
@router.get("/getvalveminorloss/")
|
||||
async def fastapi_get_valve_minor_loss(network: str, valve: str) -> float | None:
|
||||
@router.get(
|
||||
"/getvalveminorloss/",
|
||||
summary="获取阀门损失系数",
|
||||
description="获取指定阀门的损失系数",
|
||||
)
|
||||
async def fastapi_get_valve_minor_loss(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> float | None:
|
||||
"""
|
||||
获取阀门的损失系数。
|
||||
|
||||
返回指定阀门的损失系数值,用于计算流体通过阀门的压力损失。
|
||||
"""
|
||||
ps = get_valve(network, valve)
|
||||
return ps["minor_loss"]
|
||||
|
||||
@router.post("/setvalvenode1/", response_model=None)
|
||||
async def fastapi_set_valve_node1(network: str, valve: str, node1: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setvalvenode1/",
|
||||
response_model=None,
|
||||
summary="设置阀门起点节点",
|
||||
description="设置指定阀门的起点节点",
|
||||
)
|
||||
async def fastapi_set_valve_node1(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
node1: str = Query(..., description="新的起点节点ID"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置阀门的起点节点。
|
||||
|
||||
更新指定阀门的起点节点连接。
|
||||
"""
|
||||
ps = {"id": valve, "node1": node1}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvenode2/", response_model=None)
|
||||
async def fastapi_set_valve_node2(network: str, valve: str, node2: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setvalvenode2/",
|
||||
response_model=None,
|
||||
summary="设置阀门终点节点",
|
||||
description="设置指定阀门的终点节点",
|
||||
)
|
||||
async def fastapi_set_valve_node2(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
node2: str = Query(..., description="新的终点节点ID"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置阀门的终点节点。
|
||||
|
||||
更新指定阀门的终点节点连接。
|
||||
"""
|
||||
ps = {"id": valve, "node2": node2}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvenodediameter/", response_model=None)
|
||||
@router.post(
|
||||
"/setvalvenodediameter/",
|
||||
response_model=None,
|
||||
summary="设置阀门直径",
|
||||
description="设置指定阀门的直径",
|
||||
)
|
||||
async def fastapi_set_valve_diameter(
|
||||
network: str, valve: str, diameter: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
diameter: float = Query(..., description="新的直径值(mm)"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置阀门的直径。
|
||||
|
||||
更新指定阀门的直径值。
|
||||
"""
|
||||
ps = {"id": valve, "diameter": diameter}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvetype/", response_model=None)
|
||||
async def fastapi_set_valve_type(network: str, valve: str, type: str) -> ChangeSet:
|
||||
@router.post(
|
||||
"/setvalvetype/",
|
||||
response_model=None,
|
||||
summary="设置阀门类型",
|
||||
description="设置指定阀门的类型",
|
||||
)
|
||||
async def fastapi_set_valve_type(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
type: str = Query(..., description="新的阀门类型"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置阀门的类型。
|
||||
|
||||
更新指定阀门的类型(例如:减压阀、调节阀等)。
|
||||
"""
|
||||
ps = {"id": valve, "type": type}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvesetting/", response_model=None)
|
||||
@router.post(
|
||||
"/setvalvesetting/",
|
||||
response_model=None,
|
||||
summary="设置阀门开度",
|
||||
description="设置指定阀门的开度/设置值",
|
||||
)
|
||||
async def fastapi_set_valve_setting(
|
||||
network: str, valve: str, setting: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
setting: float = Query(..., description="新的开度值"),
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
设置阀门的开度。
|
||||
|
||||
更新指定阀门的开度/设置值。
|
||||
"""
|
||||
ps = {"id": valve, "setting": setting}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getvalveproperties/")
|
||||
async def fastapi_get_valve_properties(network: str, valve: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getvalveproperties/",
|
||||
summary="获取阀门所有属性",
|
||||
description="获取指定阀门的所有属性",
|
||||
)
|
||||
async def fastapi_get_valve_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取阀门的所有属性。
|
||||
|
||||
返回指定阀门的完整属性集合,包括ID、节点、直径、类型、开度和损失系数。
|
||||
"""
|
||||
return get_valve(network, valve)
|
||||
|
||||
@router.get("/getallvalveproperties/")
|
||||
async def fastapi_get_all_valve_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get(
|
||||
"/getallvalveproperties/",
|
||||
summary="获取所有阀门属性",
|
||||
description="获取指定水网中所有阀门的属性",
|
||||
)
|
||||
async def fastapi_get_all_valve_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取所有阀门的属性。
|
||||
|
||||
返回指定水网中所有阀门的完整属性列表。
|
||||
"""
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_valves(network)
|
||||
return results
|
||||
|
||||
@router.post("/setvalveproperties/", response_model=None)
|
||||
@router.post(
|
||||
"/setvalveproperties/",
|
||||
response_model=None,
|
||||
summary="批量设置阀门属性",
|
||||
description="批量设置指定阀门的多个属性",
|
||||
)
|
||||
async def fastapi_set_valve_properties(
|
||||
network: str, valve: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
valve: str = Query(..., description="阀门ID"),
|
||||
req: Request = None,
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
批量设置阀门的属性。
|
||||
|
||||
更新指定阀门的一个或多个属性,通过JSON请求体传递要更新的属性。
|
||||
"""
|
||||
props = await req.json()
|
||||
ps = {"id": valve} | props
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
+401
-40
@@ -1,12 +1,15 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi import APIRouter, Request, HTTPException, Query, Path, Body, Depends
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
|
||||
from app.auth.project_dependencies import get_metadata_repository
|
||||
from app.domain.schemas.metadata import ProjectMetaResponse, GeoServerConfigResponse
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api import ChangeSet
|
||||
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
|
||||
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
list_project,
|
||||
have_project,
|
||||
create_project,
|
||||
@@ -39,30 +42,106 @@ inpDir = "data/" # Assuming data directory exists or is defined somewhere.
|
||||
router = APIRouter()
|
||||
lockedPrjs: Dict[str, str] = {}
|
||||
|
||||
@router.get("/listprojects/")
|
||||
@router.get("/project_info/", summary="获取项目信息", description="从数据库获取项目的详细信息,包括地图范围等。", response_model=ProjectMetaResponse)
|
||||
async def get_project_info_endpoint(
|
||||
network: str = Query(..., description="管网名称(或项目代码)"),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
"""
|
||||
获取项目信息
|
||||
|
||||
- **network**: 管网名称(或项目代码)
|
||||
"""
|
||||
project_detail = await metadata_repo.get_project_detail_by_code(network)
|
||||
if not project_detail:
|
||||
raise HTTPException(status_code=404, detail=f"Project {network} not found")
|
||||
|
||||
geoserver_payload = None
|
||||
if project_detail.geoserver:
|
||||
geoserver_payload = GeoServerConfigResponse(
|
||||
gs_base_url=project_detail.geoserver.gs_base_url,
|
||||
gs_admin_user=project_detail.geoserver.gs_admin_user,
|
||||
gs_datastore_name=project_detail.geoserver.gs_datastore_name,
|
||||
default_extent=project_detail.geoserver.default_extent,
|
||||
srid=project_detail.geoserver.srid,
|
||||
)
|
||||
|
||||
return ProjectMetaResponse(
|
||||
project_id=project_detail.project_id,
|
||||
name=project_detail.name,
|
||||
code=project_detail.code,
|
||||
description=project_detail.description,
|
||||
gs_workspace=project_detail.gs_workspace,
|
||||
map_extent=project_detail.map_extent,
|
||||
status=project_detail.status,
|
||||
project_role="viewer", # Default role for public access
|
||||
geoserver=geoserver_payload
|
||||
)
|
||||
|
||||
@router.get("/listprojects/", summary="获取项目列表", description="获取服务器上所有可用的供水管网项目名称列表。")
|
||||
async def list_projects_endpoint() -> list[str]:
|
||||
"""
|
||||
获取项目列表
|
||||
|
||||
返回所有已创建项目的名称列表。
|
||||
"""
|
||||
return list_project()
|
||||
|
||||
@router.get("/haveproject/")
|
||||
async def have_project_endpoint(network: str):
|
||||
@router.get("/haveproject/", summary="检查项目是否存在", description="检查指定名称的项目是否存在。")
|
||||
async def have_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
检查项目是否存在
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
return have_project(network)
|
||||
|
||||
@router.post("/createproject/")
|
||||
async def create_project_endpoint(network: str):
|
||||
@router.post("/createproject/", summary="创建新项目", description="创建一个新的供水管网项目。如果项目已存在,可能会覆盖或报错(取决于底层实现)。")
|
||||
async def create_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
创建新项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
create_project(network)
|
||||
return network
|
||||
|
||||
@router.post("/deleteproject/")
|
||||
async def delete_project_endpoint(network: str):
|
||||
@router.post("/deleteproject/", summary="删除项目", description="永久删除指定的供水管网项目。此操作不可恢复。")
|
||||
async def delete_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
删除项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
delete_project(network)
|
||||
return True
|
||||
|
||||
@router.get("/isprojectopen/")
|
||||
async def is_project_open_endpoint(network: str):
|
||||
@router.get("/isprojectopen/", summary="检查项目是否已打开", description="检查指定项目是否已被加载到内存中。")
|
||||
async def is_project_open_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
检查项目是否已打开
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
return is_project_open(network)
|
||||
|
||||
@router.post("/openproject/")
|
||||
async def open_project_endpoint(network: str):
|
||||
@router.post("/openproject/", summary="打开项目", description="将指定项目加载到内存中,并初始化数据库连接池。")
|
||||
async def open_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
打开项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
open_project(network)
|
||||
|
||||
# 尝试连接指定数据库
|
||||
@@ -88,18 +167,43 @@ async def open_project_endpoint(network: str):
|
||||
|
||||
return network
|
||||
|
||||
@router.post("/closeproject/")
|
||||
async def close_project_endpoint(network: str):
|
||||
@router.post("/closeproject/", summary="关闭项目", description="将指定项目从内存中卸载,释放资源。")
|
||||
async def close_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
关闭项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
close_project(network)
|
||||
return True
|
||||
|
||||
@router.post("/copyproject/")
|
||||
async def copy_project_endpoint(source: str, target: str):
|
||||
@router.post("/copyproject/", summary="复制项目", description="将现有项目复制为新项目。")
|
||||
async def copy_project_endpoint(
|
||||
source: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
target: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
复制项目
|
||||
|
||||
- **source**: 管网名称(或数据库名称)
|
||||
- **target**: 管网名称(或数据库名称)
|
||||
"""
|
||||
copy_project(source, target)
|
||||
return True
|
||||
|
||||
@router.post("/importinp/")
|
||||
async def import_inp_endpoint(network: str, req: Request):
|
||||
@router.post("/importinp/", summary="导入 INP 文件内容", description="将 INP 格式的文本内容导入到指定项目中。")
|
||||
async def import_inp_endpoint(
|
||||
req: Request,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
):
|
||||
"""
|
||||
导入 INP 文件内容
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
|
||||
"""
|
||||
jo_root = await req.json()
|
||||
inp_text = jo_root["inp"]
|
||||
ps = {"inp": inp_text}
|
||||
@@ -107,8 +211,17 @@ async def import_inp_endpoint(network: str, req: Request):
|
||||
print(ret)
|
||||
return ret
|
||||
|
||||
@router.get("/exportinp/", response_model=None)
|
||||
async def export_inp_endpoint(network: str, version: str) -> ChangeSet:
|
||||
@router.get("/exportinp/", response_model=None, summary="导出项目为 ChangeSet", description="导出项目的变更集 (ChangeSet),包含顶点、SCADA 元素、DMA、SA、VD 等信息。")
|
||||
async def export_inp_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
version: str = Query(..., description="版本号 (通常用于增量更新)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
导出项目为 ChangeSet
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **version**: 版本号
|
||||
"""
|
||||
cs = export_inp(network, version)
|
||||
op = cs.operations[0]
|
||||
open_project(network)
|
||||
@@ -131,30 +244,75 @@ async def export_inp_endpoint(network: str, version: str) -> ChangeSet:
|
||||
|
||||
return cs
|
||||
|
||||
@router.post("/readinp/")
|
||||
async def read_inp_endpoint(network: str, inp: str) -> bool:
|
||||
@router.post("/readinp/", summary="读取 INP 文件到项目", description="从服务器文件系统中读取指定的 INP 文件并加载到项目中。")
|
||||
async def read_inp_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inp: str = Query(..., description="INP 文件名 (不包含路径)")
|
||||
) -> bool:
|
||||
"""
|
||||
读取 INP 文件到项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **inp**: INP 文件名
|
||||
"""
|
||||
read_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/dumpinp/")
|
||||
async def dump_inp_endpoint(network: str, inp: str) -> bool:
|
||||
@router.get("/dumpinp/", summary="导出项目到 INP 文件", description="将项目当前状态保存为 INP 文件到服务器文件系统。")
|
||||
async def dump_inp_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inp: str = Query(..., description="目标文件名")
|
||||
) -> bool:
|
||||
"""
|
||||
导出项目到 INP 文件
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **inp**: 目标文件名
|
||||
"""
|
||||
dump_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/isprojectlocked/")
|
||||
async def is_project_locked_endpoint(network: str, req: Request):
|
||||
@router.get("/isprojectlocked/", summary="检查项目是否被锁定", description="检查指定项目是否处于锁定状态。")
|
||||
async def is_project_locked_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
检查项目是否被锁定
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
return network in lockedPrjs.keys()
|
||||
|
||||
@router.get("/isprojectlockedbyme/")
|
||||
async def is_project_locked_by_me_endpoint(network: str, req: Request):
|
||||
@router.get("/isprojectlockedbyme/", summary="检查项目是否被当前用户锁定", description="检查指定项目是否被当前客户端 (IP) 锁定。")
|
||||
async def is_project_locked_by_me_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
检查项目是否被当前用户锁定
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
client_host = req.client.host
|
||||
return lockedPrjs.get(network) == client_host
|
||||
|
||||
# 0 successfully locked
|
||||
# 1 already locked by you
|
||||
# 2 locked by others
|
||||
@router.post("/lockproject/")
|
||||
async def lock_project_endpoint(network: str, req: Request):
|
||||
@router.post("/lockproject/", summary="锁定项目", description="锁定指定项目以防止并发修改。")
|
||||
async def lock_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
锁定项目
|
||||
|
||||
返回值:
|
||||
- **0**: 锁定成功
|
||||
- **1**: 已被当前用户锁定
|
||||
- **2**: 已被其他用户锁定
|
||||
"""
|
||||
client_host = req.client.host
|
||||
if not network in lockedPrjs.keys():
|
||||
lockedPrjs[network] = client_host
|
||||
@@ -165,8 +323,16 @@ async def lock_project_endpoint(network: str, req: Request):
|
||||
else:
|
||||
return 2
|
||||
|
||||
@router.post("/unlockproject/")
|
||||
def unlock_project_endpoint(network: str, req: Request):
|
||||
@router.post("/unlockproject/", summary="解锁项目", description="释放对项目的锁定。")
|
||||
def unlock_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
解锁项目
|
||||
|
||||
只有锁定者才能解锁。
|
||||
"""
|
||||
client_host = req.client.host
|
||||
if lockedPrjs.get(network) == client_host:
|
||||
print("delete key")
|
||||
@@ -176,8 +342,17 @@ def unlock_project_endpoint(network: str, req: Request):
|
||||
return False
|
||||
|
||||
# inp file operations
|
||||
@router.post("/uploadinp/", status_code=status.HTTP_200_OK)
|
||||
async def fastapi_upload_inp(afile: bytes, name: str):
|
||||
@router.post("/uploadinp/", status_code=status.HTTP_200_OK, summary="上传 INP 文件", description="上传 INP 文件到服务器数据目录。")
|
||||
async def fastapi_upload_inp(
|
||||
afile: bytes = Body(..., description="文件二进制内容"),
|
||||
name: str = Query(..., description="保存的文件名")
|
||||
):
|
||||
"""
|
||||
上传 INP 文件
|
||||
|
||||
- **afile**: 文件内容
|
||||
- **name**: 文件名
|
||||
"""
|
||||
if not os.path.exists(inpDir):
|
||||
os.makedirs(inpDir, exist_ok=True)
|
||||
|
||||
@@ -186,8 +361,16 @@ async def fastapi_upload_inp(afile: bytes, name: str):
|
||||
f.write(afile)
|
||||
return True
|
||||
|
||||
@router.get("/downloadinp/", status_code=status.HTTP_200_OK)
|
||||
async def fastapi_download_inp(name: str, response: Response):
|
||||
@router.get("/downloadinp/", status_code=status.HTTP_200_OK, summary="下载 INP 文件", description="从服务器数据目录下载指定的 INP 文件。")
|
||||
async def fastapi_download_inp(
|
||||
name: str = Query(..., description="文件名"),
|
||||
response: Response = None
|
||||
):
|
||||
"""
|
||||
下载 INP 文件
|
||||
|
||||
- **name**: 文件名
|
||||
"""
|
||||
filePath = inpDir + name
|
||||
if os.path.exists(filePath):
|
||||
return FileResponse(
|
||||
@@ -198,8 +381,186 @@ async def fastapi_download_inp(name: str, response: Response):
|
||||
return True
|
||||
|
||||
# DingZQ, 2024-12-28, convert v3 to v2
|
||||
@router.get("/convertv3tov2/", response_model=None)
|
||||
async def fastapi_convert_v3_to_v2(req: Request) -> ChangeSet:
|
||||
@router.get("/convertv3tov2/", response_model=None, summary="转换 INP V3 为 V2", description="将 EPANET 3.0 格式的 INP 内容转换为 2.x 格式。")
|
||||
async def fastapi_convert_v3_to_v2(
|
||||
req: Request
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
转换 INP V3 为 V2
|
||||
|
||||
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
|
||||
"""
|
||||
network = "v3Tov2"
|
||||
jo_root = await req.json()
|
||||
inp = jo_root["inp"]
|
||||
cs = convert_inp_v3_to_v2(inp)
|
||||
op = cs.operations[0]
|
||||
open_project(network)
|
||||
op["vertex"] = json.dumps(get_all_vertices(network))
|
||||
op["scada"] = json.dumps(get_all_scada_elements(network))
|
||||
op["dma"] = json.dumps(get_all_district_metering_areas(network))
|
||||
op["sa"] = json.dumps(get_all_service_areas(network))
|
||||
op["vd"] = json.dumps(get_all_virtual_districts(network))
|
||||
op["legend"] = get_extension_data(network, "legend")
|
||||
|
||||
db = get_extension_data(network, "scada_db")
|
||||
print(db)
|
||||
scada_db = ""
|
||||
if db:
|
||||
scada_db = db
|
||||
print(scada_db)
|
||||
op["scada_db"] = scada_db
|
||||
|
||||
close_project(network)
|
||||
|
||||
return cs
|
||||
|
||||
@router.post("/readinp/", summary="读取 INP 文件到项目", description="从服务器文件系统中读取指定的 INP 文件并加载到项目中。")
|
||||
async def read_inp_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inp: str = Query(..., description="INP 文件名 (不包含路径)")
|
||||
) -> bool:
|
||||
"""
|
||||
读取 INP 文件到项目
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **inp**: INP 文件名
|
||||
"""
|
||||
read_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/dumpinp/", summary="导出项目到 INP 文件", description="将项目当前状态保存为 INP 文件到服务器文件系统。")
|
||||
async def dump_inp_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inp: str = Query(..., description="目标文件名")
|
||||
) -> bool:
|
||||
"""
|
||||
导出项目到 INP 文件
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **inp**: 目标文件名
|
||||
"""
|
||||
dump_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/isprojectlocked/", summary="检查项目是否被锁定", description="检查指定项目是否处于锁定状态。")
|
||||
async def is_project_locked_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
检查项目是否被锁定
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
return network in lockedPrjs.keys()
|
||||
|
||||
@router.get("/isprojectlockedbyme/", summary="检查项目是否被当前用户锁定", description="检查指定项目是否被当前客户端 (IP) 锁定。")
|
||||
async def is_project_locked_by_me_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
检查项目是否被当前用户锁定
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
"""
|
||||
client_host = req.client.host
|
||||
return lockedPrjs.get(network) == client_host
|
||||
|
||||
# 0 successfully locked
|
||||
# 1 already locked by you
|
||||
# 2 locked by others
|
||||
@router.post("/lockproject/", summary="锁定项目", description="锁定指定项目以防止并发修改。")
|
||||
async def lock_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
锁定项目
|
||||
|
||||
返回值:
|
||||
- **0**: 锁定成功
|
||||
- **1**: 已被当前用户锁定
|
||||
- **2**: 已被其他用户锁定
|
||||
"""
|
||||
client_host = req.client.host
|
||||
if not network in lockedPrjs.keys():
|
||||
lockedPrjs[network] = client_host
|
||||
return 0
|
||||
else:
|
||||
if lockedPrjs.get(network) == client_host:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
@router.post("/unlockproject/", summary="解锁项目", description="释放对项目的锁定。")
|
||||
def unlock_project_endpoint(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
):
|
||||
"""
|
||||
解锁项目
|
||||
|
||||
只有锁定者才能解锁。
|
||||
"""
|
||||
client_host = req.client.host
|
||||
if lockedPrjs.get(network) == client_host:
|
||||
print("delete key")
|
||||
del lockedPrjs[network]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# inp file operations
|
||||
@router.post("/uploadinp/", status_code=status.HTTP_200_OK, summary="上传 INP 文件", description="上传 INP 文件到服务器数据目录。")
|
||||
async def fastapi_upload_inp(
|
||||
afile: bytes = Body(..., description="文件二进制内容"),
|
||||
name: str = Query(..., description="保存的文件名")
|
||||
):
|
||||
"""
|
||||
上传 INP 文件
|
||||
|
||||
- **afile**: 文件内容
|
||||
- **name**: 文件名
|
||||
"""
|
||||
if not os.path.exists(inpDir):
|
||||
os.makedirs(inpDir, exist_ok=True)
|
||||
|
||||
filePath = inpDir + str(name)
|
||||
with open(filePath, "wb") as f:
|
||||
f.write(afile)
|
||||
return True
|
||||
|
||||
@router.get("/downloadinp/", status_code=status.HTTP_200_OK, summary="下载 INP 文件", description="从服务器数据目录下载指定的 INP 文件。")
|
||||
async def fastapi_download_inp(
|
||||
name: str = Query(..., description="文件名"),
|
||||
response: Response = None
|
||||
):
|
||||
"""
|
||||
下载 INP 文件
|
||||
|
||||
- **name**: 文件名
|
||||
"""
|
||||
filePath = inpDir + name
|
||||
if os.path.exists(filePath):
|
||||
return FileResponse(
|
||||
filePath, media_type="application/octet-stream", filename="inp.inp"
|
||||
)
|
||||
else:
|
||||
response.status_code = status.HTTP_400_BAD_REQUEST
|
||||
return True
|
||||
|
||||
# DingZQ, 2024-12-28, convert v3 to v2
|
||||
@router.get("/convertv3tov2/", response_model=None, summary="转换 INP V3 为 V2", description="将 EPANET 3.0 格式的 INP 内容转换为 2.x 格式。")
|
||||
async def fastapi_convert_v3_to_v2(
|
||||
req: Request
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
转换 INP V3 为 V2
|
||||
|
||||
- **req**: 请求体,需包含 `{"inp": "..."}` 结构
|
||||
"""
|
||||
network = "v3Tov2"
|
||||
jo_root = await req.json()
|
||||
inp = jo_root["inp"]
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
import app.native.wndb as wndb
|
||||
from app.infra.db.postgresql.scheme import SchemeRepository
|
||||
from app.auth.project_dependencies import get_project_pg_connection
|
||||
from app.services import project_info
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_database_connection(
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取数据库连接"""
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
@router.get("/scada-info", summary="获取SCADA信息", description="使用连接池查询所有SCADA信息")
|
||||
async def get_scada_info_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有SCADA信息
|
||||
获取所有SCADA信息
|
||||
|
||||
返回项目中所有的SCADA设备信息
|
||||
"""
|
||||
try:
|
||||
# 使用ScadaRepository查询SCADA信息
|
||||
scada_data = await ScadaRepository.get_scadas(conn)
|
||||
_ = conn
|
||||
network_name = project_info.name
|
||||
scada_data = wndb.get_all_scada_info(network_name) if network_name else []
|
||||
return {"success": True, "data": scada_data, "count": len(scada_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
@@ -32,30 +36,32 @@ async def get_scada_info_with_connection(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scheme-list")
|
||||
@router.get("/scheme-list", summary="获取方案列表", description="使用连接池查询所有方案信息")
|
||||
async def get_scheme_list_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有方案信息
|
||||
获取所有方案信息
|
||||
|
||||
返回项目中所有方案的详细信息
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询方案信息
|
||||
scheme_data = await SchemeRepository.get_schemes(conn)
|
||||
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/burst-locate-result")
|
||||
@router.get("/burst-locate-result", summary="获取爆管定位结果", description="使用连接池查询所有爆管定位结果")
|
||||
async def get_burst_locate_result_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有爆管定位结果
|
||||
获取所有爆管定位结果
|
||||
|
||||
返回项目中所有的爆管定位分析结果
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
burst_data = await SchemeRepository.get_burst_locate_results(conn)
|
||||
return {"success": True, "data": burst_data, "count": len(burst_data)}
|
||||
except Exception as e:
|
||||
@@ -64,16 +70,18 @@ async def get_burst_locate_result_with_connection(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/burst-locate-result/{burst_incident}")
|
||||
@router.get("/burst-locate-result/{burst_incident}", summary="按事件查询爆管定位结果", description="根据爆管事件ID查询对应的爆管定位结果")
|
||||
async def get_burst_locate_result_by_incident(
|
||||
burst_incident: str,
|
||||
burst_incident: str = Path(..., description="爆管事件ID"),
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
根据 burst_incident 查询爆管定位结果
|
||||
根据爆管事件ID查询爆管定位结果
|
||||
|
||||
参数:
|
||||
burst_incident: 爆管事件的唯一标识符
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
return await SchemeRepository.get_burst_locate_result_by_incident(
|
||||
conn, burst_incident
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, List, Dict
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Query, Path
|
||||
from app.services.tjnetwork import (
|
||||
get_pipe_risk_probability_now,
|
||||
get_pipe_risk_probability,
|
||||
@@ -10,35 +10,118 @@ from app.services.tjnetwork import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpiperiskprobabilitynow/")
|
||||
@router.get(
|
||||
"/getpiperiskprobabilitynow/",
|
||||
summary="获取管道当前风险概率",
|
||||
description="获取指定管道当前时刻的风险概率值"
|
||||
)
|
||||
async def fastapi_get_pipe_risk_probability_now(
|
||||
network: str, pipe_id: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe_id: str = Query(..., description="管道ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取管道当前风险概率。
|
||||
|
||||
查询指定管道在当前时刻的风险概率值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe_id: 管道ID
|
||||
|
||||
Returns:
|
||||
包含风险概率信息的字典
|
||||
"""
|
||||
return get_pipe_risk_probability_now(network, pipe_id)
|
||||
|
||||
|
||||
@router.get("/getpiperiskprobability/")
|
||||
@router.get(
|
||||
"/getpiperiskprobability/",
|
||||
summary="获取管道风险概率历史",
|
||||
description="获取指定管道的风险概率历史数据"
|
||||
)
|
||||
async def fastapi_get_pipe_risk_probability(
|
||||
network: str, pipe_id: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe_id: str = Query(..., description="管道ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取管道风险概率历史。
|
||||
|
||||
查询指定管道的历史风险概率数据。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe_id: 管道ID
|
||||
|
||||
Returns:
|
||||
包含风险概率历史的字典
|
||||
"""
|
||||
return get_pipe_risk_probability(network, pipe_id)
|
||||
|
||||
|
||||
@router.get("/getpipesriskprobability/")
|
||||
@router.get(
|
||||
"/getpipesriskprobability/",
|
||||
summary="批量获取多条管道风险概率",
|
||||
description="批量获取多条管道的风险概率值"
|
||||
)
|
||||
async def fastapi_get_pipes_risk_probability(
|
||||
network: str, pipe_ids: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
pipe_ids: str = Query(..., description="逗号分隔的管道ID列表")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
批量获取多条管道风险概率。
|
||||
|
||||
查询多条指定管道的风险概率值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
pipe_ids: 逗号分隔的管道ID列表(例如:pipe1,pipe2,pipe3)
|
||||
|
||||
Returns:
|
||||
包含多条管道风险概率的列表
|
||||
"""
|
||||
pipeids = pipe_ids.split(",")
|
||||
return get_pipes_risk_probability(network, pipeids)
|
||||
|
||||
|
||||
@router.get("/getnetworkpiperiskprobabilitynow/")
|
||||
@router.get(
|
||||
"/getnetworkpiperiskprobabilitynow/",
|
||||
summary="获取整个网络的管道风险概率",
|
||||
description="获取指定网络中所有管道的当前风险概率值"
|
||||
)
|
||||
async def fastapi_get_network_pipe_risk_probability_now(
|
||||
network: str,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取整个网络的管道风险概率。
|
||||
|
||||
查询指定网络中所有管道在当前时刻的风险概率值。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含网络内所有管道风险概率的列表
|
||||
"""
|
||||
return get_network_pipe_risk_probability_now(network)
|
||||
|
||||
|
||||
@router.get("/getpiperiskprobabilitygeometries/")
|
||||
async def fastapi_get_pipe_risk_probability_geometries(network: str) -> dict[str, Any]:
|
||||
@router.get(
|
||||
"/getpiperiskprobabilitygeometries/",
|
||||
summary="获取管道风险几何信息",
|
||||
description="获取指定网络中管道的风险相关几何数据"
|
||||
)
|
||||
async def fastapi_get_pipe_risk_probability_geometries(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取管道风险几何信息。
|
||||
|
||||
查询指定网络中管道的地理和风险相关的几何数据。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
包含几何信息和风险数据的字典
|
||||
"""
|
||||
return get_pipe_risk_probability_geometries(network)
|
||||
|
||||
+416
-58
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Request
|
||||
from app.native.api import ChangeSet
|
||||
from fastapi import APIRouter, Request, Query
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
get_scada_info,
|
||||
get_all_scada_info,
|
||||
get_scada_device_schema,
|
||||
@@ -31,139 +31,497 @@ from app.services.tjnetwork import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getscadaproperties/")
|
||||
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
|
||||
@router.get("/getscadaproperties/", summary="获取SCADA属性", tags=["SCADA基础"])
|
||||
async def fast_get_scada_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scada: str = Query(..., description="SCADA设备ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取单个SCADA设备的属性信息
|
||||
|
||||
根据管网名称和SCADA设备ID获取该设备的完整属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
scada: SCADA设备ID
|
||||
|
||||
Returns:
|
||||
SCADA设备的属性字典
|
||||
"""
|
||||
return get_scada_info(network, scada)
|
||||
|
||||
@router.get("/getallscadaproperties/")
|
||||
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallscadaproperties/", summary="获取所有SCADA属性", tags=["SCADA基础"])
|
||||
async def fast_get_all_scada_properties(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定管网所有SCADA设备的属性信息
|
||||
|
||||
查询该管网下所有已配置的SCADA设备的属性列表。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA设备属性列表
|
||||
"""
|
||||
return get_all_scada_info(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_device 29
|
||||
# scada_device 设备管理
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadadeviceschema/")
|
||||
async def fastapi_get_scada_device_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getscadadeviceschema/", summary="获取SCADA设备架构", tags=["SCADA设备"])
|
||||
async def fastapi_get_scada_device_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取SCADA设备的数据架构
|
||||
|
||||
返回SCADA设备表的字段定义和类型信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA设备的字段架构信息
|
||||
"""
|
||||
return get_scada_device_schema(network)
|
||||
|
||||
@router.get("/getscadadevice/")
|
||||
async def fastapi_get_scada_device(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get("/getscadadevice/", summary="获取SCADA设备", tags=["SCADA设备"])
|
||||
async def fastapi_get_scada_device(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="SCADA设备ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取单个SCADA设备的信息
|
||||
|
||||
根据设备ID查询该设备的详细信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
id: SCADA设备ID
|
||||
|
||||
Returns:
|
||||
SCADA设备信息
|
||||
"""
|
||||
return get_scada_device(network, id)
|
||||
|
||||
@router.post("/setscadadevice/", response_model=None)
|
||||
async def fastapi_set_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setscadadevice/", response_model=None, summary="更新SCADA设备", tags=["SCADA设备"])
|
||||
async def fastapi_set_scada_device(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
更新SCADA设备信息
|
||||
|
||||
修改指定SCADA设备的属性。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要更新的设备属性
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadadevice/", response_model=None)
|
||||
async def fastapi_add_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addscadadevice/", response_model=None, summary="添加SCADA设备", tags=["SCADA设备"])
|
||||
async def fastapi_add_scada_device(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新的SCADA设备
|
||||
|
||||
在指定管网中添加一个新的SCADA设备。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含新设备的属性
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadadevice/", response_model=None)
|
||||
async def fastapi_delete_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletescadadevice/", response_model=None, summary="删除SCADA设备", tags=["SCADA设备"])
|
||||
async def fastapi_delete_scada_device(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除SCADA设备
|
||||
|
||||
从指定管网中删除一个SCADA设备。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要删除的设备ID
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadadevice/", response_model=None)
|
||||
async def fastapi_clean_scada_device(network: str) -> ChangeSet:
|
||||
@router.post("/cleanscadadevice/", response_model=None, summary="清空SCADA设备表", tags=["SCADA设备"])
|
||||
async def fastapi_clean_scada_device(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
清空SCADA设备表
|
||||
|
||||
删除指定管网中所有的SCADA设备。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
return clean_scada_device(network)
|
||||
|
||||
@router.get("/getallscadadeviceids/")
|
||||
async def fastapi_get_all_scada_device_ids(network: str) -> list[str]:
|
||||
@router.get("/getallscadadeviceids/", summary="获取所有SCADA设备ID", tags=["SCADA设备"])
|
||||
async def fastapi_get_all_scada_device_ids(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[str]:
|
||||
"""
|
||||
获取指定管网所有SCADA设备的ID列表
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA设备ID列表
|
||||
"""
|
||||
return get_all_scada_device_ids(network)
|
||||
|
||||
@router.get("/getallscadadevices/")
|
||||
async def fastapi_get_all_scada_devices(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallscadadevices/", summary="获取所有SCADA设备", tags=["SCADA设备"])
|
||||
async def fastapi_get_all_scada_devices(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定管网所有SCADA设备的完整信息
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA设备信息列表
|
||||
"""
|
||||
return get_all_scada_devices(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_device_data 30
|
||||
# scada_device_data 设备数据管理
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadadevicedataschema/")
|
||||
@router.get("/getscadadevicedataschema/", summary="获取SCADA设备数据架构", tags=["SCADA设备数据"])
|
||||
async def fastapi_get_scada_device_data_schema(
|
||||
network: str,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取SCADA设备数据的表结构
|
||||
|
||||
返回SCADA设备数据表的字段定义和类型信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA设备数据的字段架构信息
|
||||
"""
|
||||
return get_scada_device_data_schema(network)
|
||||
|
||||
@router.get("/getscadadevicedata/")
|
||||
async def fastapi_get_scada_device_data(network: str, device_id: str) -> dict[str, Any]:
|
||||
@router.get("/getscadadevicedata/", summary="获取SCADA设备数据", tags=["SCADA设备数据"])
|
||||
async def fastapi_get_scada_device_data(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
device_id: str = Query(..., description="SCADA设备ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取单个SCADA设备的数据
|
||||
|
||||
查询指定设备的监测数据或配置数据。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
device_id: SCADA设备ID
|
||||
|
||||
Returns:
|
||||
SCADA设备数据
|
||||
"""
|
||||
return get_scada_device_data(network, device_id)
|
||||
|
||||
@router.post("/setscadadevicedata/", response_model=None)
|
||||
async def fastapi_set_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setscadadevicedata/", response_model=None, summary="更新SCADA设备数据", tags=["SCADA设备数据"])
|
||||
async def fastapi_set_scada_device_data(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
更新SCADA设备数据
|
||||
|
||||
修改指定SCADA设备的数据。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要更新的数据
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadadevicedata/", response_model=None)
|
||||
async def fastapi_add_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addscadadevicedata/", response_model=None, summary="添加SCADA设备数据", tags=["SCADA设备数据"])
|
||||
async def fastapi_add_scada_device_data(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新的SCADA设备数据
|
||||
|
||||
为指定SCADA设备添加新的数据记录。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含新数据的内容
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadadevicedata/", response_model=None)
|
||||
async def fastapi_delete_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletescadadevicedata/", response_model=None, summary="删除SCADA设备数据", tags=["SCADA设备数据"])
|
||||
async def fastapi_delete_scada_device_data(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除SCADA设备数据
|
||||
|
||||
删除指定SCADA设备的数据记录。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要删除的数据ID
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadadevicedata/", response_model=None)
|
||||
async def fastapi_clean_scada_device_data(network: str) -> ChangeSet:
|
||||
@router.post("/cleanscadadevicedata/", response_model=None, summary="清空SCADA设备数据表", tags=["SCADA设备数据"])
|
||||
async def fastapi_clean_scada_device_data(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
清空SCADA设备数据表
|
||||
|
||||
删除指定管网中所有SCADA设备的数据。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
return clean_scada_device_data(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_element 31
|
||||
# scada_element SCADA元素映射
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadaelementschema/")
|
||||
@router.get("/getscadaelementschema/", summary="获取SCADA元素架构", tags=["SCADA元素映射"])
|
||||
async def fastapi_get_scada_element_schema(
|
||||
network: str,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取SCADA元素映射的表结构
|
||||
|
||||
返回SCADA元素映射表的字段定义和类型信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA元素映射的字段架构信息
|
||||
"""
|
||||
return get_scada_element_schema(network)
|
||||
|
||||
@router.get("/getscadaelements/")
|
||||
async def fastapi_get_scada_elements(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getscadaelements/", summary="获取所有SCADA元素映射", tags=["SCADA元素映射"])
|
||||
async def fastapi_get_scada_elements(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定管网所有SCADA元素映射
|
||||
|
||||
查询所有SCADA设备与管网元素(节点/管道)的映射关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA元素映射列表
|
||||
"""
|
||||
return get_all_scada_elements(network)
|
||||
|
||||
@router.get("/getscadaelement/")
|
||||
async def fastapi_get_scada_element(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get("/getscadaelement/", summary="获取单个SCADA元素映射", tags=["SCADA元素映射"])
|
||||
async def fastapi_get_scada_element(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="SCADA元素映射ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取单个SCADA元素映射的信息
|
||||
|
||||
根据ID查询特定的SCADA设备与管网元素的映射关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
id: SCADA元素映射ID
|
||||
|
||||
Returns:
|
||||
SCADA元素映射信息
|
||||
"""
|
||||
return get_scada_element(network, id)
|
||||
|
||||
@router.post("/setscadaelement/", response_model=None)
|
||||
async def fastapi_set_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/setscadaelement/", response_model=None, summary="更新SCADA元素映射", tags=["SCADA元素映射"])
|
||||
async def fastapi_set_scada_element(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
更新SCADA元素映射
|
||||
|
||||
修改SCADA设备与管网元素的映射关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要更新的映射信息
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return set_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadaelement/", response_model=None)
|
||||
async def fastapi_add_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/addscadaelement/", response_model=None, summary="添加SCADA元素映射", tags=["SCADA元素映射"])
|
||||
async def fastapi_add_scada_element(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
添加新的SCADA元素映射
|
||||
|
||||
创建SCADA设备与管网元素的新映射关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含新映射的信息
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return add_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadaelement/", response_model=None)
|
||||
async def fastapi_delete_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/deletescadaelement/", response_model=None, summary="删除SCADA元素映射", tags=["SCADA元素映射"])
|
||||
async def fastapi_delete_scada_element(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
删除SCADA元素映射
|
||||
|
||||
移除SCADA设备与管网元素的映射关系。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
req: 请求体,包含要删除的映射ID
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
props = await req.json()
|
||||
return delete_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadaelement/", response_model=None)
|
||||
async def fastapi_clean_scada_element(network: str) -> ChangeSet:
|
||||
@router.post("/cleanscadaelement/", response_model=None, summary="清空SCADA元素映射表", tags=["SCADA元素映射"])
|
||||
async def fastapi_clean_scada_element(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
清空SCADA元素映射表
|
||||
|
||||
删除指定管网中所有的SCADA元素映射。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
变更集合信息
|
||||
"""
|
||||
return clean_scada_element(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_info 38
|
||||
# scada_info SCADA信息
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadainfoschema/")
|
||||
async def fastapi_get_scada_info_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
@router.get("/getscadainfoschema/", summary="获取SCADA信息架构", tags=["SCADA信息"])
|
||||
async def fastapi_get_scada_info_schema(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
获取SCADA信息表的结构
|
||||
|
||||
返回SCADA信息表的字段定义和类型信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA信息的字段架构信息
|
||||
"""
|
||||
return get_scada_info_schema(network)
|
||||
|
||||
@router.get("/getscadainfo/")
|
||||
async def fastapi_get_scada_info(network: str, id: str) -> dict[str, Any]:
|
||||
@router.get("/getscadainfo/", summary="获取SCADA信息", tags=["SCADA信息"])
|
||||
async def fastapi_get_scada_info(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
id: str = Query(..., description="SCADA信息ID")
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取单个SCADA信息
|
||||
|
||||
根据ID查询SCADA的详细配置信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
id: SCADA信息ID
|
||||
|
||||
Returns:
|
||||
SCADA信息详情
|
||||
"""
|
||||
return get_scada_info(network, id)
|
||||
|
||||
@router.get("/getallscadainfo/")
|
||||
async def fastapi_get_all_scada_info(network: str) -> list[dict[str, Any]]:
|
||||
@router.get("/getallscadainfo/", summary="获取所有SCADA信息", tags=["SCADA信息"])
|
||||
async def fastapi_get_all_scada_info(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取指定管网所有SCADA的信息
|
||||
|
||||
查询该管网下所有已配置的SCADA的完整信息。
|
||||
|
||||
Args:
|
||||
network: 管网名称(或数据库名称)
|
||||
|
||||
Returns:
|
||||
SCADA信息列表
|
||||
"""
|
||||
return get_all_scada_info(network)
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Query
|
||||
from typing import Any, List, Dict
|
||||
from app.services.tjnetwork import get_scheme_schema, get_scheme, get_all_schemes
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getschemeschema/")
|
||||
async def fastapi_get_scheme_schema(network: str) -> dict[str, dict[Any, Any]]:
|
||||
@router.get("/getschemeschema/", summary="获取方案模式", description="获取指定网络的方案模式定义")
|
||||
async def fastapi_get_scheme_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[Any, Any]]:
|
||||
"""
|
||||
获取方案模式定义
|
||||
|
||||
返回指定网络的方案模式结构定义
|
||||
"""
|
||||
return get_scheme_schema(network)
|
||||
|
||||
@router.get("/getscheme/")
|
||||
async def fastapi_get_scheme(network: str, schema_name: str) -> dict[Any, Any]:
|
||||
@router.get("/getscheme/", summary="获取单个方案", description="根据名称获取指定的方案信息")
|
||||
async def fastapi_get_scheme(network: str = Query(..., description="管网名称(或数据库名称)"), schema_name: str = Query(..., description="方案名称")) -> dict[Any, Any]:
|
||||
"""
|
||||
获取单个方案详情
|
||||
|
||||
返回指定网络中指定名称的方案详细信息
|
||||
"""
|
||||
return get_scheme(network, schema_name)
|
||||
|
||||
@router.get("/getallschemes/")
|
||||
async def fastapi_get_all_schemes(network: str) -> list[dict[Any, Any]]:
|
||||
@router.get("/getallschemes/", summary="获取所有方案", description="获取指定网络的所有方案信息")
|
||||
async def fastapi_get_all_schemes(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
|
||||
"""
|
||||
获取所有方案列表
|
||||
|
||||
返回指定网络中所有可用的方案
|
||||
"""
|
||||
return get_all_schemes(network)
|
||||
|
||||
+433
-289
@@ -4,20 +4,17 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Query
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Query, Path, Body
|
||||
from fastapi.responses import PlainTextResponse
|
||||
import app.infra.db.influxdb.api as influxdb_api
|
||||
import app.services.simulation as simulation
|
||||
import app.services.globals as globals
|
||||
from app.infra.cache.redis_client import redis_client
|
||||
from app.services.tjnetwork import (
|
||||
run_project,
|
||||
run_project_return_dict,
|
||||
run_inp,
|
||||
dump_output,
|
||||
)
|
||||
from app.algorithms.simulations import (
|
||||
from app.algorithms.simulation.scenarios import (
|
||||
burst_analysis,
|
||||
valve_close_analysis,
|
||||
flushing_analysis,
|
||||
@@ -26,12 +23,11 @@ from app.algorithms.simulations import (
|
||||
# scheduling_analysis,
|
||||
pressure_regulation,
|
||||
)
|
||||
from app.algorithms.sensors import (
|
||||
from app.algorithms.sensor import (
|
||||
pressure_sensor_placement_sensitivity,
|
||||
pressure_sensor_placement_kmeans,
|
||||
)
|
||||
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
|
||||
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
|
||||
|
||||
from app.services.network_import import network_update
|
||||
from app.services.simulation_ops import (
|
||||
project_management,
|
||||
@@ -39,176 +35,180 @@ from app.services.simulation_ops import (
|
||||
daily_scheduling_simulation,
|
||||
)
|
||||
from app.services.valve_isolation import analyze_valve_isolation
|
||||
from pydantic import BaseModel
|
||||
from app.services.time_api import parse_aware_time, parse_utc_time
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunSimulationManuallyByDate(BaseModel):
|
||||
name: str
|
||||
simulation_date: str
|
||||
start_time: str
|
||||
duration: int
|
||||
name: str = Field(..., description="管网名称(或数据库名称)")
|
||||
start_time: str = Field(..., description="开始时间 (ISO 8601 / RFC3339,必须显式带时区)")
|
||||
duration: int = Field(..., gt=0, description="持续时间 (分钟)")
|
||||
|
||||
@field_validator("start_time")
|
||||
@classmethod
|
||||
def validate_start_time_timezone(cls, value: str) -> str:
|
||||
parse_aware_time(value, field_name="start_time")
|
||||
return value
|
||||
|
||||
|
||||
class BurstAnalysis(BaseModel):
|
||||
name: str
|
||||
modify_pattern_start_time: str
|
||||
burst_ID: List[str] | str | None = None
|
||||
burst_size: List[float] | float | int | None = None
|
||||
modify_total_duration: int = 900
|
||||
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_variable_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_valve_opening: Optional[dict[str, float]] = None
|
||||
scheme_name: Optional[str] = None
|
||||
name: str = Field(..., description="管网名称(或数据库名称)")
|
||||
modify_pattern_start_time: str = Field(..., description="模式修改开始时间 (ISO 8601)")
|
||||
burst_ID: List[str] | str | None = Field(None, description="爆管节点/管段ID列表")
|
||||
burst_size: List[float] | float | int | None = Field(None, description="爆管流量大小")
|
||||
modify_total_duration: int = Field(900, description="模拟总时长 (秒)")
|
||||
modify_fixed_pump_pattern: Optional[dict[str, list]] = Field(None, description="定速泵模式修改")
|
||||
modify_variable_pump_pattern: Optional[dict[str, list]] = Field(None, description="变速泵模式修改")
|
||||
modify_valve_opening: Optional[dict[str, float]] = Field(None, description="阀门开度修改")
|
||||
scheme_name: Optional[str] = Field(None, description="方案名称")
|
||||
|
||||
|
||||
class SchedulingAnalysis(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_id: str
|
||||
water_plant_output_id: str
|
||||
time_delta: Optional[int] = 300
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
start_time: str = Field(..., description="开始时间")
|
||||
pump_control: dict = Field(..., description="泵控制策略")
|
||||
tank_id: str = Field(..., description="水箱ID")
|
||||
water_plant_output_id: str = Field(..., description="水厂出水ID")
|
||||
time_delta: Optional[int] = Field(300, description="时间步长 (秒)")
|
||||
|
||||
|
||||
class PressureRegulation(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_init_level: Optional[dict] = None
|
||||
duration: Optional[int] = 900
|
||||
scheme_name: Optional[str] = None
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
start_time: str = Field(..., description="开始时间")
|
||||
pump_control: dict = Field(..., description="泵控制策略")
|
||||
tank_init_level: Optional[dict] = Field(None, description="水箱初始水位")
|
||||
duration: Optional[int] = Field(900, description="持续时间 (秒)")
|
||||
scheme_name: Optional[str] = Field(None, description="方案名称")
|
||||
|
||||
|
||||
class ProjectManagement(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_init_level: Optional[dict] = None
|
||||
region_demand: Optional[dict] = None
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
start_time: str = Field(..., description="开始时间")
|
||||
pump_control: dict = Field(..., description="泵控制策略")
|
||||
tank_init_level: Optional[dict] = Field(None, description="水箱初始水位")
|
||||
region_demand: Optional[dict] = Field(None, description="区域需水量控制")
|
||||
|
||||
|
||||
class DailySchedulingAnalysis(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
reservoir_id: str
|
||||
tank_id: str
|
||||
water_plant_output_id: str
|
||||
time_delta: Optional[int] = 300
|
||||
network: str = Field(..., description="管网名称(或数据库名称)")
|
||||
start_time: str = Field(..., description="开始时间")
|
||||
pump_control: dict = Field(..., description="泵控制策略")
|
||||
reservoir_id: str = Field(..., description="水库ID")
|
||||
tank_id: str = Field(..., description="水箱ID")
|
||||
water_plant_output_id: str = Field(..., description="水厂出水ID")
|
||||
time_delta: Optional[int] = Field(300, description="时间步长 (秒)")
|
||||
|
||||
|
||||
class PumpFailureState(BaseModel):
|
||||
time: str
|
||||
pump_status: dict
|
||||
time: str = Field(..., description="故障发生时间")
|
||||
pump_status: dict = Field(..., description="泵状态字典")
|
||||
|
||||
|
||||
class PressureSensorPlacement(BaseModel):
|
||||
name: str
|
||||
scheme_name: str
|
||||
sensor_number: int
|
||||
min_diameter: int = 0
|
||||
username: str
|
||||
name: str = Field(..., description="管网名称(或数据库名称)")
|
||||
scheme_name: str = Field(..., description="方案名称")
|
||||
sensor_number: int = Field(..., description="传感器数量")
|
||||
min_diameter: int = Field(0, description="最小管径限制")
|
||||
username: str = Field(..., description="用户名")
|
||||
|
||||
|
||||
def run_simulation_manually_by_date(
|
||||
network_name: str, base_date: datetime, start_time: str, duration: int
|
||||
network_name: str, start_time: datetime, duration: int
|
||||
) -> None:
|
||||
time_parts = list(map(int, start_time.split(":")))
|
||||
if len(time_parts) == 2:
|
||||
start_hour, start_minute = time_parts
|
||||
start_second = 0
|
||||
elif len(time_parts) == 3:
|
||||
start_hour, start_minute, start_second = time_parts
|
||||
else:
|
||||
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
|
||||
|
||||
start_datetime = base_date.replace(
|
||||
hour=start_hour, minute=start_minute, second=start_second
|
||||
)
|
||||
end_datetime = start_datetime + timedelta(minutes=duration)
|
||||
current_time = start_datetime
|
||||
end_datetime = start_time + timedelta(minutes=duration)
|
||||
current_time = start_time
|
||||
while current_time < end_datetime:
|
||||
iso_time = current_time.strftime("%Y-%m-%dT%H:%M:%S") + "+08:00"
|
||||
simulation.run_simulation(
|
||||
name=network_name,
|
||||
simulation_type="realtime",
|
||||
modify_pattern_start_time=iso_time,
|
||||
modify_pattern_start_time=current_time.isoformat(timespec="seconds"),
|
||||
)
|
||||
current_time += timedelta(minutes=15)
|
||||
|
||||
|
||||
# 必须用这个PlainTextResponse,不然每个key都有引号
|
||||
@router.get("/runproject/", response_class=PlainTextResponse)
|
||||
async def run_project_endpoint(network: str) -> str:
|
||||
lock_key = "exclusive_api_lock"
|
||||
timeout = 120 # 锁自动过期时间(秒)
|
||||
|
||||
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
|
||||
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
|
||||
|
||||
if not acquired:
|
||||
raise HTTPException(status_code=409, detail="is in simulation")
|
||||
else:
|
||||
try:
|
||||
return run_project(network)
|
||||
finally:
|
||||
# 手动释放锁(可选,依赖过期时间自动释放更安全)
|
||||
redis_client.delete(lock_key)
|
||||
@router.get("/runproject/", response_class=PlainTextResponse, summary="运行项目模拟", description="基于指定的管网项目运行标准水力模拟,返回纯文本格式的模拟报告。")
|
||||
async def run_project_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> str:
|
||||
"""
|
||||
运行项目模拟
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
|
||||
运行指定管网项目的标准水力模拟并返回文本报告。
|
||||
"""
|
||||
return run_project(network)
|
||||
|
||||
|
||||
# DingZQ, 2025-02-04, 返回dict[str, Any]
|
||||
# output 和 report
|
||||
# output 是 json
|
||||
# report 是 text
|
||||
@router.get("/runprojectreturndict/")
|
||||
async def run_project_return_dict_endpoint(network: str) -> dict[str, Any]:
|
||||
lock_key = "exclusive_api_lock"
|
||||
timeout = 120 # 锁自动过期时间(秒)
|
||||
|
||||
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
|
||||
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
|
||||
|
||||
if not acquired:
|
||||
raise HTTPException(status_code=409, detail="is in simulation")
|
||||
else:
|
||||
try:
|
||||
return run_project_return_dict(network)
|
||||
finally:
|
||||
# 手动释放锁(可选,依赖过期时间自动释放更安全)
|
||||
redis_client.delete(lock_key)
|
||||
@router.get("/runprojectreturndict/", summary="运行项目模拟(返回字典)", description="基于指定的管网项目运行标准水力模拟,返回JSON格式的字典,包含输出数据和报告文本。")
|
||||
async def run_project_return_dict_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, Any]:
|
||||
"""
|
||||
运行项目模拟(返回字典)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
|
||||
返回字典包含:
|
||||
- output: JSON格式的模拟输出数据
|
||||
- report: 文本格式的模拟报告
|
||||
|
||||
运行指定管网项目的标准水力模拟并返回字典结果。
|
||||
"""
|
||||
return run_project_return_dict(network)
|
||||
|
||||
|
||||
# put in inp folder, name without extension
|
||||
@router.get("/runinp/")
|
||||
async def run_inp_endpoint(network: str) -> str:
|
||||
@router.get("/runinp/", summary="运行INP文件", description="运行指定INP文件格式的管网模型进行水力模拟。INP文件应该放在inp文件夹中,参数为文件名不含扩展名。")
|
||||
async def run_inp_endpoint(network: str = Query(..., description="inp文件名(不含扩展名)")) -> str:
|
||||
"""
|
||||
运行INP文件
|
||||
|
||||
- **network**: inp文件名(不含扩展名)
|
||||
|
||||
从inp文件夹中读取指定的INP文件并运行模拟。
|
||||
"""
|
||||
return run_inp(network)
|
||||
|
||||
|
||||
# path is absolute path
|
||||
@router.get("/dumpoutput/")
|
||||
async def dump_output_endpoint(output: str) -> str:
|
||||
@router.get("/dumpoutput/", summary="导出模拟输出", description="导出指定路径的模拟输出文件内容。参数应为绝对路径。")
|
||||
async def dump_output_endpoint(output: str = Query(..., description="模拟输出文件的绝对路径")) -> str:
|
||||
"""
|
||||
导出模拟输出
|
||||
|
||||
- **output**: 模拟输出文件的绝对路径
|
||||
|
||||
读取并返回指定路径的模拟输出内容。
|
||||
"""
|
||||
return dump_output(output)
|
||||
|
||||
|
||||
# Analysis Endpoints
|
||||
@router.get("/burstanalysis/")
|
||||
async def burst_analysis_endpoint(
|
||||
network: str, pipe_id: str, start_time: str, end_time: str, burst_flow: float
|
||||
):
|
||||
return burst_analysis(network, pipe_id, start_time, end_time, burst_flow)
|
||||
|
||||
|
||||
@router.get("/burst_analysis/")
|
||||
@router.get("/burst_analysis/", summary="爆管分析(高级)", description="高级版本的爆管分析,支持在指定时间点修改泵控制模式和阀门开度,以分析这些改变对爆管影响的作用。支持固定泵和变速泵的独立控制。")
|
||||
async def fastapi_burst_analysis(
|
||||
network: str = Query(...),
|
||||
modify_pattern_start_time: str = Query(...),
|
||||
burst_ID: list[str] = Query(...),
|
||||
burst_size: list[float] = Query(...),
|
||||
modify_total_duration: int = Query(...),
|
||||
scheme_name: str = Query(...),
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
modify_pattern_start_time: str = Query(..., description="模式修改开始时间(ISO 8601格式)"),
|
||||
burst_ID: list[str] = Query(..., description="爆管节点/管段ID列表"),
|
||||
burst_size: list[float] = Query(..., description="对应各爆管点的爆管流量大小列表(L/s)"),
|
||||
modify_total_duration: int = Query(..., description="模拟总时长(秒)"),
|
||||
scheme_name: str = Query(..., description="分析方案名称"),
|
||||
) -> str:
|
||||
"""
|
||||
爆管分析(高级版本)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **modify_pattern_start_time**: 模式修改开始时间
|
||||
- **burst_ID**: 爆管节点/管段ID列表
|
||||
- **burst_size**: 爆管流量大小列表(与burst_ID对应)
|
||||
- **modify_total_duration**: 模拟总时长(秒)
|
||||
- **scheme_name**: 分析方案名称
|
||||
|
||||
支持在指定时间修改泵控制模式和阀门开度。
|
||||
"""
|
||||
burst_analysis(
|
||||
name=network,
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
@@ -220,74 +220,100 @@ async def fastapi_burst_analysis(
|
||||
return "success"
|
||||
|
||||
|
||||
@router.get("/valvecloseanalysis/")
|
||||
async def valve_close_analysis_endpoint(
|
||||
network: str, valve_id: str, start_time: str, end_time: str
|
||||
):
|
||||
return valve_close_analysis(network, valve_id, start_time, end_time)
|
||||
|
||||
|
||||
@router.get("/valve_close_analysis/", response_class=PlainTextResponse)
|
||||
@router.get("/valve_close_analysis/", response_class=PlainTextResponse, summary="阀门关闭分析(高级)", description="高级版本的阀门关闭分析,支持同时关闭多个阀门,并在指定持续时间内进行模拟。返回纯文本格式的分析结果。")
|
||||
async def fastapi_valve_close_analysis(
|
||||
network: str,
|
||||
start_time: str,
|
||||
valves: List[str] = Query(...),
|
||||
duration: int | None = None,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
start_time: str = Query(..., description="阀门关闭开始时间(ISO 8601格式)"),
|
||||
valves: List[str] = Query(..., description="要关闭的阀门ID列表"),
|
||||
duration: int | None = Query(None, description="模拟持续时间(秒),默认900秒"),
|
||||
scheme_name: str = Query(..., description="阀门关闭方案名称"),
|
||||
) -> str:
|
||||
"""
|
||||
阀门关闭分析(高级版本)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 阀门关闭开始时间
|
||||
- **valves**: 要关闭的阀门ID列表
|
||||
- **duration**: 模拟持续时间(秒,可选,默认900)
|
||||
- **scheme_name**: 阀门关闭方案名称
|
||||
|
||||
支持同时关闭多个阀门进行分析。
|
||||
"""
|
||||
result = valve_close_analysis(
|
||||
name=network,
|
||||
modify_pattern_start_time=start_time,
|
||||
modify_total_duration=duration or 900,
|
||||
modify_valve_opening={valve_id: 0.0 for valve_id in valves},
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/valve_isolation_analysis/")
|
||||
@router.get("/valve_isolation_analysis/", summary="阀门隔离分析", description="分析当发生突发事件时,通过关闭指定阀门进行隔离,确定哪些阀门必须关闭、哪些可选关闭,以及隔离的可行性。")
|
||||
async def valve_isolation_endpoint(
|
||||
network: str,
|
||||
accident_element: List[str] = Query(...),
|
||||
disabled_valves: List[str] = Query(None),
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
accident_element: List[str] = Query(..., description="发生事故的管段/节点ID列表"),
|
||||
disabled_valves: List[str] = Query(None, description="已故障的阀门ID列表(可选)"),
|
||||
):
|
||||
result = {
|
||||
"accident_element": "P461309",
|
||||
"accident_elements": ["P461309"],
|
||||
"affected_nodes": [
|
||||
"J316629_A",
|
||||
"J317037_B",
|
||||
"J317060_B",
|
||||
"J408189_B",
|
||||
"J499996",
|
||||
"J524940",
|
||||
"J535933",
|
||||
"J58841",
|
||||
],
|
||||
"isolatable": True,
|
||||
"must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
|
||||
"optional_valves": [],
|
||||
}
|
||||
"""
|
||||
阀门隔离分析
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **accident_element**: 发生事故的管段/节点ID列表
|
||||
- **disabled_valves**: 已故障的阀门ID列表(可选)
|
||||
|
||||
返回隔离方案,包括:
|
||||
- must_close_valves: 必须关闭的阀门列表
|
||||
- optional_valves: 可选关闭的阀门列表
|
||||
- affected_nodes: 受影响的节点列表
|
||||
- isolatable: 是否可以有效隔离
|
||||
"""
|
||||
# result = {
|
||||
# "accident_element": "P461309",
|
||||
# "accident_elements": ["P461309"],
|
||||
# "affected_nodes": [
|
||||
# "J316629_A",
|
||||
# "J317037_B",
|
||||
# "J317060_B",
|
||||
# "J408189_B",
|
||||
# "J499996",
|
||||
# "J524940",
|
||||
# "J535933",
|
||||
# "J58841",
|
||||
# ],
|
||||
# "isolatable": True,
|
||||
# "must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
|
||||
# "optional_valves": [],
|
||||
# }
|
||||
result = analyze_valve_isolation(network, accident_element, disabled_valves)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/flushinganalysis/")
|
||||
async def flushing_analysis_endpoint(
|
||||
network: str, pipe_id: str, start_time: str, duration: float, flow: float
|
||||
):
|
||||
return flushing_analysis(network, pipe_id, start_time, duration, flow)
|
||||
|
||||
|
||||
@router.get("/flushing_analysis/", response_class=PlainTextResponse)
|
||||
@router.get("/flushing_analysis/", response_class=PlainTextResponse, summary="冲洗分析(高级)", description="高级版本的冲洗分析,支持同时开启多个阀门进行冲洗,指定排污节点,并设置固定的冲洗流量。返回纯文本格式的分析结果。")
|
||||
async def fastapi_flushing_analysis(
|
||||
network: str,
|
||||
start_time: str,
|
||||
valves: List[str] = Query(...),
|
||||
valves_k: List[float] = Query(...),
|
||||
drainage_node_ID: str = Query(...),
|
||||
flush_flow: float = 0,
|
||||
duration: int | None = None,
|
||||
scheme_name: str | None = None,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
start_time: str = Query(..., description="冲洗开始时间(ISO 8601格式)"),
|
||||
valves: List[str] = Query(..., description="要开启的阀门ID列表"),
|
||||
valves_k: List[float] = Query(..., description="对应各阀门的开度列表(0-1)"),
|
||||
drainage_node_ID: str = Query(..., description="排污节点ID"),
|
||||
flush_flow: float = Query(0, description="冲洗流量(L/s),0表示自动计算"),
|
||||
duration: int | None = Query(None, description="模拟持续时间(秒),默认900秒"),
|
||||
scheme_name: str = Query(..., description="冲洗方案名称"),
|
||||
) -> str:
|
||||
"""
|
||||
冲洗分析(高级版本)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 冲洗开始时间
|
||||
- **valves**: 要开启的阀门ID列表
|
||||
- **valves_k**: 各阀门的开度列表(0-1,与valves对应)
|
||||
- **drainage_node_ID**: 排污节点ID
|
||||
- **flush_flow**: 冲洗流量(L/s)
|
||||
- **duration**: 模拟持续时间(秒,可选,默认900)
|
||||
- **scheme_name**: 冲洗方案名称
|
||||
|
||||
支持多阀联合冲洗操作。
|
||||
"""
|
||||
valve_opening = {
|
||||
valve_id: float(valves_k[idx]) for idx, valve_id in enumerate(valves)
|
||||
}
|
||||
@@ -303,16 +329,29 @@ async def fastapi_flushing_analysis(
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/contaminant_simulation/", response_class=PlainTextResponse)
|
||||
@router.get("/contaminant_simulation/", response_class=PlainTextResponse, summary="污染物模拟", description="对管网中的污染物扩散进行模拟,评估污染源对管网的影响范围和浓度分布。支持指定污染源位置、污染浓度和扩散模式。")
|
||||
async def fastapi_contaminant_simulation(
|
||||
network: str,
|
||||
start_time: str,
|
||||
source: str,
|
||||
concentration: float,
|
||||
duration: int,
|
||||
scheme_name: str | None = None,
|
||||
pattern: str | None = None,
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
start_time: str = Query(..., description="污染开始时间(ISO 8601格式)"),
|
||||
source: str = Query(..., description="污染源节点ID"),
|
||||
concentration: float = Query(..., description="污染浓度(mg/L)"),
|
||||
duration: int = Query(..., description="模拟持续时间(秒)"),
|
||||
scheme_name: str = Query(..., description="模拟方案名称"),
|
||||
pattern: str | None = Query(None, description="污染源模式ID(可选)"),
|
||||
) -> str:
|
||||
"""
|
||||
污染物模拟
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 污染开始时间
|
||||
- **source**: 污染源节点ID
|
||||
- **concentration**: 污染浓度(mg/L)
|
||||
- **duration**: 模拟持续时间(秒)
|
||||
- **scheme_name**: 模拟方案名称
|
||||
- **pattern**: 污染源模式ID(可选)
|
||||
|
||||
用于评估管网中污染物的传播和影响范围。
|
||||
"""
|
||||
result = contaminant_simulation(
|
||||
name=network,
|
||||
modify_pattern_start_time=start_time,
|
||||
@@ -325,15 +364,21 @@ async def fastapi_contaminant_simulation(
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/ageanalysis/")
|
||||
async def age_analysis_endpoint(network: str):
|
||||
return age_analysis(network)
|
||||
|
||||
|
||||
@router.get("/age_analysis/", response_class=PlainTextResponse)
|
||||
@router.get("/age_analysis/", response_class=PlainTextResponse, summary="水龄分析(高级)", description="高级版本的水龄分析,在指定时间点进行分析,支持自定义模拟持续时间。返回纯文本格式的分析结果。")
|
||||
async def fastapi_age_analysis(
|
||||
network: str, start_time: str, end_time: str, duration: int
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
start_time: str = Query(..., description="分析开始时间(ISO 8601格式)"),
|
||||
duration: int = Query(..., description="模拟持续时间(秒)"),
|
||||
) -> str:
|
||||
"""
|
||||
水龄分析(高级版本)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 分析开始时间
|
||||
- **duration**: 模拟持续时间(秒)
|
||||
|
||||
分析指定时间段内管网中各节点的水体停留时间。
|
||||
"""
|
||||
result = age_analysis(network, start_time, duration)
|
||||
return result or "success"
|
||||
|
||||
@@ -343,15 +388,39 @@ async def fastapi_age_analysis(
|
||||
# return scheduling_analysis(network)
|
||||
|
||||
|
||||
@router.get("/pressureregulation/")
|
||||
@router.get("/pressureregulation/", summary="压力调节(基础)", description="对管网的压力进行调节分析,通过控制泵的运行来维持目标节点的目标压力。此为基础版本。")
|
||||
async def pressure_regulation_endpoint(
|
||||
network: str, target_node: str, target_pressure: float
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
target_node: str = Query(..., description="目标节点ID"),
|
||||
target_pressure: float = Query(..., description="目标压力值(kPa)"),
|
||||
):
|
||||
"""
|
||||
压力调节(基础版本)
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **target_node**: 目标节点ID
|
||||
- **target_pressure**: 目标压力值(kPa)
|
||||
|
||||
通过泵控制维持目标节点的压力。
|
||||
"""
|
||||
return pressure_regulation(network, target_node, target_pressure)
|
||||
|
||||
|
||||
@router.post("/pressure_regulation/")
|
||||
async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
|
||||
@router.post("/pressure_regulation/", summary="压力调节(高级)", description="高级版本的压力调节分析,通过JSON请求体提供详细的控制参数,包括固定泵和变速泵的独立控制、水箱初始水位等。")
|
||||
async def fastapi_pressure_regulation(data: PressureRegulation = Body(..., description="压力调节控制参数")) -> str:
|
||||
"""
|
||||
压力调节(高级版本)
|
||||
|
||||
请求体参数:
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 控制开始时间
|
||||
- **pump_control**: 泵控制策略字典
|
||||
- **tank_init_level**: 水箱初始水位字典(可选)
|
||||
- **duration**: 模拟持续时间(秒,可选,默认900)
|
||||
- **scheme_name**: 控制方案名称(可选)
|
||||
|
||||
支持固定泵和变速泵的独立控制。
|
||||
"""
|
||||
item = data.dict()
|
||||
simulation.query_corresponding_element_id_and_query_id(item["network"])
|
||||
fixed_pumps = set(globals.fixed_pumps_id.keys())
|
||||
@@ -375,13 +444,20 @@ async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
|
||||
return "success"
|
||||
|
||||
|
||||
@router.get("/projectmanagement/")
|
||||
async def project_management_endpoint(network: str):
|
||||
return project_management(network)
|
||||
|
||||
|
||||
@router.post("/project_management/")
|
||||
async def fastapi_project_management(data: ProjectManagement) -> str:
|
||||
@router.post("/project_management/", summary="项目管理(高级)", description="高级版本的项目管理,通过JSON请求体提供详细的控制参数,包括泵控制策略、水箱初始水位和区域需水量控制。")
|
||||
async def fastapi_project_management(data: ProjectManagement = Body(..., description="项目管理控制参数")) -> str:
|
||||
"""
|
||||
项目管理(高级版本)
|
||||
|
||||
请求体参数:
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 管理开始时间
|
||||
- **pump_control**: 泵控制策略字典
|
||||
- **tank_init_level**: 水箱初始水位字典(可选)
|
||||
- **region_demand**: 区域需水量控制字典(可选)
|
||||
|
||||
支持多维度的项目管理。
|
||||
"""
|
||||
item = data.dict()
|
||||
return project_management(
|
||||
prj_name=item["network"],
|
||||
@@ -397,8 +473,21 @@ async def fastapi_project_management(data: ProjectManagement) -> str:
|
||||
# return daily_scheduling_analysis(network)
|
||||
|
||||
|
||||
@router.post("/scheduling_analysis/")
|
||||
async def fastapi_scheduling_analysis(data: SchedulingAnalysis) -> str:
|
||||
@router.post("/scheduling_analysis/", summary="排程分析", description="对管网的供水排程进行分析,优化泵的运行时间和出水流量,平衡水厂出水、水箱进出水,满足用户需求。")
|
||||
async def fastapi_scheduling_analysis(data: SchedulingAnalysis = Body(..., description="排程分析参数")) -> str:
|
||||
"""
|
||||
排程分析
|
||||
|
||||
请求体参数:
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 分析开始时间
|
||||
- **pump_control**: 泵控制策略字典
|
||||
- **tank_id**: 水箱ID
|
||||
- **water_plant_output_id**: 水厂出水ID
|
||||
- **time_delta**: 时间步长(秒,可选,默认300)
|
||||
|
||||
用于优化供水排程。
|
||||
"""
|
||||
item = data.dict()
|
||||
return scheduling_simulation(
|
||||
item["network"],
|
||||
@@ -410,8 +499,22 @@ async def fastapi_scheduling_analysis(data: SchedulingAnalysis) -> str:
|
||||
)
|
||||
|
||||
|
||||
@router.post("/daily_scheduling_analysis/")
|
||||
async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis) -> str:
|
||||
@router.post("/daily_scheduling_analysis/", summary="日排程分析", description="对管网的每日供水排程进行分析,优化水库、水厂、水箱和用户需求的协调,制定合理的每日排程方案。")
|
||||
async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis = Body(..., description="日排程分析参数")) -> str:
|
||||
"""
|
||||
日排程分析
|
||||
|
||||
请求体参数:
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **start_time**: 分析开始时间
|
||||
- **pump_control**: 泵控制策略字典
|
||||
- **reservoir_id**: 水库ID
|
||||
- **tank_id**: 水箱ID
|
||||
- **water_plant_output_id**: 水厂出水ID
|
||||
- **time_delta**: 时间步长(秒,可选,默认300)
|
||||
|
||||
用于制定每日供水排程方案。
|
||||
"""
|
||||
item = data.dict()
|
||||
return daily_scheduling_simulation(
|
||||
item["network"],
|
||||
@@ -423,8 +526,15 @@ async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis) -> st
|
||||
)
|
||||
|
||||
|
||||
@router.post("/network_project/")
|
||||
async def fastapi_network_project(file: UploadFile = File()) -> str:
|
||||
@router.post("/network_project/", summary="导入网络项目", description="通过上传INP格式的管网文件导入新的网络项目。系统将自动处理文件并执行模拟。")
|
||||
async def fastapi_network_project(file: UploadFile = File(..., description="INP格式的管网文件")) -> str:
|
||||
"""
|
||||
导入网络项目
|
||||
|
||||
- **file**: 上传的INP格式管网文件
|
||||
|
||||
系统将上传的文件保存到inp文件夹并执行模拟。
|
||||
"""
|
||||
temp_file_dir = "./inp/"
|
||||
if not os.path.exists(temp_file_dir):
|
||||
os.mkdir(temp_file_dir)
|
||||
@@ -435,13 +545,15 @@ async def fastapi_network_project(file: UploadFile = File()) -> str:
|
||||
return run_inp(temp_file_name)
|
||||
|
||||
|
||||
@router.get("/networkupdate/")
|
||||
async def network_update_endpoint(network: str):
|
||||
return network_update(network)
|
||||
|
||||
|
||||
@router.post("/network_update/")
|
||||
async def fastapi_network_update(file: UploadFile = File()) -> str:
|
||||
@router.post("/network_update/", summary="管网更新(高级)", description="通过上传更新文件对管网进行高级的更新操作。系统将处理更新文件并应用到数据库。")
|
||||
async def fastapi_network_update(file: UploadFile = File(..., description="包含管网更新信息的文件")) -> str:
|
||||
"""
|
||||
管网更新(高级版本)
|
||||
|
||||
- **file**: 包含管网更新信息的文件
|
||||
|
||||
系统将处理上传的文件并应用管网更新。
|
||||
"""
|
||||
default_folder = "./"
|
||||
temp_file_name = f'network_update_{datetime.now().strftime("%Y%m%d")}'
|
||||
temp_file_path = os.path.join(default_folder, temp_file_name)
|
||||
@@ -459,8 +571,17 @@ async def fastapi_network_update(file: UploadFile = File()) -> str:
|
||||
# return pump_failure(network, pump_id, time)
|
||||
|
||||
|
||||
@router.post("/pump_failure/")
|
||||
async def fastapi_pump_failure(data: PumpFailureState) -> str:
|
||||
@router.post("/pump_failure/", summary="泵故障管理", description="记录和管理泵的故障状态,包括故障发生时间和受影响的泵列表。系统将记录故障日志并更新泵状态。")
|
||||
async def fastapi_pump_failure(data: PumpFailureState = Body(..., description="泵故障状态信息")) -> str:
|
||||
"""
|
||||
泵故障管理
|
||||
|
||||
请求体参数:
|
||||
- **time**: 故障发生时间
|
||||
- **pump_status**: 泵状态字典,包含第一阶段和第二阶段泵的故障状态
|
||||
|
||||
系统将验证泵信息的有效性并更新故障状态文件。
|
||||
"""
|
||||
item = data.dict()
|
||||
with open("./pump_failure_message.txt", "a", encoding="utf-8-sig") as f1:
|
||||
f1.write("[{}] {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), item))
|
||||
@@ -494,19 +615,46 @@ async def fastapi_pump_failure(data: PumpFailureState) -> str:
|
||||
return json.dumps("SUCCESS")
|
||||
|
||||
|
||||
@router.get("/pressuresensorplacementsensitivity/")
|
||||
@router.get("/pressuresensorplacementsensitivity/", summary="压力传感器放置-灵敏度分析(基础)", description="基于灵敏度分析方法,为指定管网项目确定最优的压力传感器放置位置。此为基础版本。")
|
||||
async def pressure_sensor_placement_sensitivity_endpoint(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
name: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Query(..., description="放置方案名称"),
|
||||
sensor_number: int = Query(..., description="传感器数量"),
|
||||
min_diameter: int = Query(..., description="最小管径限制(毫米)"),
|
||||
username: str = Query(..., description="用户名"),
|
||||
):
|
||||
"""
|
||||
压力传感器放置-灵敏度分析(基础版本)
|
||||
|
||||
- **name**: 管网名称(或数据库名称)
|
||||
- **scheme_name**: 放置方案名称
|
||||
- **sensor_number**: 传感器数量
|
||||
- **min_diameter**: 最小管径限制(毫米)
|
||||
- **username**: 用户名
|
||||
|
||||
基于灵敏度分析方法确定传感器放置位置。
|
||||
"""
|
||||
return pressure_sensor_placement_sensitivity(
|
||||
name, scheme_name, sensor_number, min_diameter, username
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pressure_sensor_placement_sensitivity/")
|
||||
@router.post("/pressure_sensor_placement_sensitivity/", summary="压力传感器放置-灵敏度分析(高级)", description="高级版本的压力传感器放置分析,通过JSON请求体提供详细参数。基于灵敏度分析方法确定最优放置位置。")
|
||||
async def fastapi_pressure_sensor_placement_sensitivity(
|
||||
data: PressureSensorPlacement,
|
||||
data: PressureSensorPlacement = Body(..., description="传感器放置分析参数"),
|
||||
) -> None:
|
||||
"""
|
||||
压力传感器放置-灵敏度分析(高级版本)
|
||||
|
||||
请求体参数:
|
||||
- **name**: 管网名称(或数据库名称)
|
||||
- **scheme_name**: 放置方案名称
|
||||
- **sensor_number**: 传感器数量
|
||||
- **min_diameter**: 最小管径限制(毫米)
|
||||
- **username**: 用户名
|
||||
|
||||
基于灵敏度分析方法确定压力传感器的最优放置位置。
|
||||
"""
|
||||
item = data.dict()
|
||||
pressure_sensor_placement_sensitivity(
|
||||
name=item["name"],
|
||||
@@ -517,19 +665,46 @@ async def fastapi_pressure_sensor_placement_sensitivity(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/pressuresensorplacementkmeans/")
|
||||
@router.get("/pressuresensorplacementkmeans/", summary="压力传感器放置-KMeans聚类分析(基础)", description="基于KMeans聚类算法,为指定管网项目确定压力传感器的最优放置位置。此为基础版本。")
|
||||
async def pressure_sensor_placement_kmeans_endpoint(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
name: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Query(..., description="放置方案名称"),
|
||||
sensor_number: int = Query(..., description="传感器数量"),
|
||||
min_diameter: int = Query(..., description="最小管径限制(毫米)"),
|
||||
username: str = Query(..., description="用户名"),
|
||||
):
|
||||
"""
|
||||
压力传感器放置-KMeans聚类分析(基础版本)
|
||||
|
||||
- **name**: 管网名称(或数据库名称)
|
||||
- **scheme_name**: 放置方案名称
|
||||
- **sensor_number**: 传感器数量
|
||||
- **min_diameter**: 最小管径限制(毫米)
|
||||
- **username**: 用户名
|
||||
|
||||
基于KMeans聚类算法确定传感器放置位置。
|
||||
"""
|
||||
return pressure_sensor_placement_kmeans(
|
||||
name, scheme_name, sensor_number, min_diameter, username
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pressure_sensor_placement_kmeans/")
|
||||
@router.post("/pressure_sensor_placement_kmeans/", summary="压力传感器放置-KMeans聚类分析(高级)", description="高级版本的压力传感器放置分析,通过JSON请求体提供详细参数。基于KMeans聚类算法确定最优放置位置。")
|
||||
async def fastapi_pressure_sensor_placement_kmeans(
|
||||
data: PressureSensorPlacement,
|
||||
data: PressureSensorPlacement = Body(..., description="传感器放置分析参数"),
|
||||
) -> None:
|
||||
"""
|
||||
压力传感器放置-KMeans聚类分析(高级版本)
|
||||
|
||||
请求体参数:
|
||||
- **name**: 管网名称(或数据库名称)
|
||||
- **scheme_name**: 放置方案名称
|
||||
- **sensor_number**: 传感器数量
|
||||
- **min_diameter**: 最小管径限制(毫米)
|
||||
- **username**: 用户名
|
||||
|
||||
基于KMeans聚类算法确定压力传感器的最优放置位置。
|
||||
"""
|
||||
item = data.dict()
|
||||
pressure_sensor_placement_kmeans(
|
||||
name=item["name"],
|
||||
@@ -540,16 +715,31 @@ async def fastapi_pressure_sensor_placement_kmeans(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sensorplacementscheme/create")
|
||||
@router.post("/sensorplacementscheme/create", summary="传感器放置方案创建", description="创建新的传感器放置方案,支持灵敏度分析和KMeans聚类两种方法。根据指定的方法自动计算最优的传感器放置位置。")
|
||||
async def fastapi_pressure_sensor_placement(
|
||||
network: str = Query(...),
|
||||
scheme_name: str = Query(...),
|
||||
sensor_type: str = Query(...),
|
||||
method: str = Query(...),
|
||||
sensor_count: int = Query(...),
|
||||
min_diameter: int = Query(0),
|
||||
user_name: str = Query(...),
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
scheme_name: str = Query(..., description="放置方案名称"),
|
||||
sensor_type: str = Query(..., description="传感器类型"),
|
||||
method: str = Query(..., description="放置方法('sensitivity'或'kmeans')"),
|
||||
sensor_count: int = Query(..., description="传感器数量"),
|
||||
min_diameter: int = Query(0, description="最小管径限制(毫米),默认0"),
|
||||
user_name: str = Query(..., description="用户名"),
|
||||
) -> str:
|
||||
"""
|
||||
传感器放置方案创建
|
||||
|
||||
- **network**: 管网名称(或数据库名称)
|
||||
- **scheme_name**: 放置方案名称
|
||||
- **sensor_type**: 传感器类型
|
||||
- **method**: 放置方法('sensitivity'或'kmeans')
|
||||
- **sensor_count**: 传感器数量
|
||||
- **min_diameter**: 最小管径限制(毫米,默认0)
|
||||
- **user_name**: 用户名
|
||||
|
||||
支持两种放置方法:
|
||||
- sensitivity: 基于灵敏度分析
|
||||
- kmeans: 基于KMeans聚类
|
||||
"""
|
||||
if method not in ["sensitivity", "kmeans"]:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid method. Must be 'sensitivity' or 'kmeans'"
|
||||
@@ -573,68 +763,22 @@ async def fastapi_pressure_sensor_placement(
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/scadadevicedatacleaning/")
|
||||
async def fastapi_scada_device_data_cleaning(
|
||||
network: str = Query(...),
|
||||
ids_list: List[str] = Query(...),
|
||||
start_time: str = Query(...),
|
||||
end_time: str = Query(...),
|
||||
user_name: str = Query(...),
|
||||
) -> str:
|
||||
item = {
|
||||
"network": network,
|
||||
"ids": ids_list,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"user_name": user_name,
|
||||
}
|
||||
query_ids_list = item["ids"][0].split(",")
|
||||
scada_data = influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids_list,
|
||||
start_time=item["start_time"],
|
||||
end_time=item["end_time"],
|
||||
)
|
||||
scada_device_info = influxdb_api.query_pg_scada_info(item["network"])
|
||||
scada_device_info_dict = {info["id"]: info for info in scada_device_info}
|
||||
type_groups: dict[str, list[str]] = {}
|
||||
for device_id in query_ids_list:
|
||||
device_info = scada_device_info_dict.get(device_id, {})
|
||||
device_type = device_info.get("type", "unknown")
|
||||
type_groups.setdefault(device_type, []).append(device_id)
|
||||
for device_type, device_ids in type_groups.items():
|
||||
if device_type not in ["pressure", "pipe_flow"]:
|
||||
continue
|
||||
type_scada_data = {
|
||||
device_id: scada_data[device_id]
|
||||
for device_id in device_ids
|
||||
if device_id in scada_data
|
||||
}
|
||||
if not type_scada_data:
|
||||
continue
|
||||
time_list = [record["time"] for record in next(iter(type_scada_data.values()))]
|
||||
df = pd.DataFrame({"time": time_list})
|
||||
for device_id in device_ids:
|
||||
if device_id in type_scada_data:
|
||||
values = [record["value"] for record in type_scada_data[device_id]]
|
||||
df[device_id] = values
|
||||
if device_type == "pressure":
|
||||
cleaned_value_df = pressure_data_clean.clean_pressure_data_df_km(df)
|
||||
elif device_type == "pipe_flow":
|
||||
cleaned_value_df = flow_data_clean.clean_flow_data_df_kf(df)
|
||||
cleaned_value_df = pd.DataFrame(cleaned_value_df)
|
||||
cleaned_df = pd.concat([df["time"], cleaned_value_df], axis=1)
|
||||
influxdb_api.import_multicolumn_data_from_dict(
|
||||
data_dict=cleaned_df.to_dict("list"),
|
||||
raw=False,
|
||||
)
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/runsimulationmanuallybydate/")
|
||||
@router.post("/runsimulationmanuallybydate/", summary="手动运行日期指定模拟", description="根据指定的开始时间和持续时间,手动运行水力模拟。开始时间必须是显式带时区的 ISO 8601 / RFC3339 时间。")
|
||||
async def fastapi_run_simulation_manually_by_date(
|
||||
data: RunSimulationManuallyByDate,
|
||||
data: RunSimulationManuallyByDate = Body(..., description="模拟运行参数"),
|
||||
) -> dict[str, str]:
|
||||
item = data.dict()
|
||||
"""
|
||||
手动运行日期指定模拟
|
||||
|
||||
请求体参数:
|
||||
- **name**: 管网名称(或数据库名称)
|
||||
- **start_time**: 开始时间(ISO 8601 / RFC3339,必须显式带时区)
|
||||
- **duration**: 模拟持续时间(分钟)
|
||||
|
||||
系统将从指定时间开始,按15分钟间隔多次运行模拟。
|
||||
每次模拟间隔15分钟,直至达到指定的总持续时间。
|
||||
"""
|
||||
item = data.model_dump()
|
||||
try:
|
||||
simulation.query_corresponding_element_id_and_query_id(item["name"])
|
||||
simulation.query_corresponding_pattern_id_and_query_id(item["name"])
|
||||
@@ -661,10 +805,10 @@ async def fastapi_run_simulation_manually_by_date(
|
||||
globals.source_outflow_region_id,
|
||||
globals.realtime_region_pipe_flow_and_demand_id,
|
||||
)
|
||||
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
|
||||
start_time = parse_utc_time(item["start_time"], field_name="start_time")
|
||||
run_simulation_manually_by_date(
|
||||
item["name"], base_date, item["start_time"], item["duration"]
|
||||
item["name"], start_time, item["duration"]
|
||||
)
|
||||
return {"status": "success"}
|
||||
except Exception as exc:
|
||||
return {"status": "error", "message": str(exc)}
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from app.native.api import ChangeSet
|
||||
from fastapi import APIRouter, Request, Query
|
||||
from app.services.tjnetwork import (
|
||||
ChangeSet,
|
||||
get_current_operation,
|
||||
execute_undo,
|
||||
execute_redo,
|
||||
@@ -22,90 +22,183 @@ from app.services.tjnetwork import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcurrentoperationid/")
|
||||
async def get_current_operation_id_endpoint(network: str) -> int:
|
||||
@router.get("/getcurrentoperationid/", summary="获取当前操作ID", description="获取网络当前的操作ID")
|
||||
async def get_current_operation_id_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> int:
|
||||
"""
|
||||
获取当前操作ID
|
||||
|
||||
返回网络当前正在执行的操作ID
|
||||
"""
|
||||
return get_current_operation(network)
|
||||
|
||||
@router.post("/undo/")
|
||||
async def undo_endpoint(network: str):
|
||||
@router.post("/undo/", summary="撤销操作", description="撤销网络上最后的一个操作")
|
||||
async def undo_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")):
|
||||
"""
|
||||
撤销操作
|
||||
|
||||
撤销网络上最近执行的一个操作
|
||||
"""
|
||||
return execute_undo(network)
|
||||
|
||||
@router.post("/redo/")
|
||||
async def redo_endpoint(network: str):
|
||||
@router.post("/redo/", summary="重做操作", description="重做网络上被撤销的操作")
|
||||
async def redo_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")):
|
||||
"""
|
||||
重做操作
|
||||
|
||||
重做网络上被撤销的操作
|
||||
"""
|
||||
return execute_redo(network)
|
||||
|
||||
@router.get("/getsnapshots/")
|
||||
async def list_snapshot_endpoint(network: str) -> list[tuple[int, str]]:
|
||||
@router.get("/getsnapshots/", summary="获取快照列表", description="获取网络中的所有快照")
|
||||
async def list_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[tuple[int, str]]:
|
||||
"""
|
||||
获取快照列表
|
||||
|
||||
返回网络中所有可用的快照及其信息
|
||||
"""
|
||||
return list_snapshot(network)
|
||||
|
||||
@router.get("/havesnapshot/")
|
||||
async def have_snapshot_endpoint(network: str, tag: str) -> bool:
|
||||
@router.get("/havesnapshot/", summary="检查快照是否存在", description="检查指定标签的快照是否存在")
|
||||
async def have_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> bool:
|
||||
"""
|
||||
检查快照是否存在
|
||||
|
||||
返回指定标签的快照是否存在
|
||||
"""
|
||||
return have_snapshot(network, tag)
|
||||
|
||||
@router.get("/havesnapshotforoperation/")
|
||||
async def have_snapshot_for_operation_endpoint(network: str, operation: int) -> bool:
|
||||
@router.get("/havesnapshotforoperation/", summary="检查操作快照是否存在", description="检查指定操作ID的快照是否存在")
|
||||
async def have_snapshot_for_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="操作ID")) -> bool:
|
||||
"""
|
||||
检查操作快照是否存在
|
||||
|
||||
返回指定操作ID的快照是否存在
|
||||
"""
|
||||
return have_snapshot_for_operation(network, operation)
|
||||
|
||||
@router.get("/havesnapshotforcurrentoperation/")
|
||||
async def have_snapshot_for_current_operation_endpoint(network: str) -> bool:
|
||||
@router.get("/havesnapshotforcurrentoperation/", summary="检查当前操作快照是否存在", description="检查当前操作的快照是否存在")
|
||||
async def have_snapshot_for_current_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> bool:
|
||||
"""
|
||||
检查当前操作快照是否存在
|
||||
|
||||
返回当前操作的快照是否存在
|
||||
"""
|
||||
return have_snapshot_for_current_operation(network)
|
||||
|
||||
@router.post("/takesnapshotforoperation/")
|
||||
@router.post("/takesnapshotforoperation/", summary="为操作创建快照", description="为指定的操作创建快照")
|
||||
async def take_snapshot_for_operation_endpoint(
|
||||
network: str, operation: int, tag: str
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
operation: int = Query(..., description="操作ID"),
|
||||
tag: str = Query(..., description="快照标签")
|
||||
) -> None:
|
||||
"""
|
||||
为操作创建快照
|
||||
|
||||
为指定操作创建一个带标签的快照
|
||||
"""
|
||||
return take_snapshot_for_operation(network, operation, tag)
|
||||
|
||||
@router.post("/takesnapshotforcurrentoperation")
|
||||
async def take_snapshot_for_current_operation_endpoint(network: str, tag: str) -> None:
|
||||
@router.post("/takesnapshotforcurrentoperation", summary="为当前操作创建快照", description="为当前操作创建快照")
|
||||
async def take_snapshot_for_current_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
|
||||
"""
|
||||
为当前操作创建快照
|
||||
|
||||
为网络当前操作创建一个快照
|
||||
"""
|
||||
return take_snapshot_for_current_operation(network, tag)
|
||||
|
||||
# 兼容旧拼写: takenapshotforcurrentoperation
|
||||
@router.post("/takenapshotforcurrentoperation")
|
||||
async def take_snapshot_for_current_operation_legacy_endpoint(
|
||||
network: str, tag: str
|
||||
) -> None:
|
||||
@router.post("/takenapshotforcurrentoperation", summary="为当前操作创建快照(兼容模式)", description="为当前操作创建快照(兼容旧的API路径)")
|
||||
async def take_snapshot_for_current_operation_legacy_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
|
||||
"""
|
||||
为当前操作创建快照(兼容模式)
|
||||
|
||||
兼容旧的API路径,为网络当前操作创建一个快照
|
||||
"""
|
||||
return take_snapshot_for_current_operation(network, tag)
|
||||
|
||||
@router.post("/takesnapshot/")
|
||||
async def take_snapshot_endpoint(network: str, tag: str) -> None:
|
||||
@router.post("/takesnapshot/", summary="创建快照", description="为网络创建一个快照")
|
||||
async def take_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签")) -> None:
|
||||
"""
|
||||
创建快照
|
||||
|
||||
为网络创建一个带标签的快照
|
||||
"""
|
||||
return take_snapshot(network, tag)
|
||||
|
||||
@router.post("/picksnapshot/", response_model=None)
|
||||
async def pick_snapshot_endpoint(network: str, tag: str, discard: bool = False) -> ChangeSet:
|
||||
@router.post("/picksnapshot/", summary="选择快照", description="选择并恢复到指定的快照", response_model=None)
|
||||
async def pick_snapshot_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), tag: str = Query(..., description="快照标签"), discard: bool = Query(False, description="是否丢弃当前更改")) -> ChangeSet:
|
||||
"""
|
||||
选择快照
|
||||
|
||||
选择并恢复到指定的快照
|
||||
"""
|
||||
return pick_snapshot(network, tag, discard)
|
||||
|
||||
@router.post("/pickoperation/", response_model=None)
|
||||
@router.post("/pickoperation/", summary="选择操作", description="选择并恢复到指定的操作", response_model=None)
|
||||
async def pick_operation_endpoint(
|
||||
network: str, operation: int, discard: bool = False
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
operation: int = Query(..., description="操作ID"),
|
||||
discard: bool = Query(False, description="是否丢弃当前更改")
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
选择操作
|
||||
|
||||
选择并恢复到指定的操作
|
||||
"""
|
||||
return pick_operation(network, operation, discard)
|
||||
|
||||
@router.get("/syncwithserver/", response_model=None)
|
||||
async def sync_with_server_endpoint(network: str, operation: int) -> ChangeSet:
|
||||
@router.get("/syncwithserver/", summary="与服务器同步", description="将网络与服务器同步到指定操作", response_model=None)
|
||||
async def sync_with_server_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="目标操作ID")) -> ChangeSet:
|
||||
"""
|
||||
与服务器同步
|
||||
|
||||
将网络与服务器同步到指定的操作
|
||||
"""
|
||||
return sync_with_server(network, operation)
|
||||
|
||||
@router.post("/batch/", response_model=None)
|
||||
async def execute_batch_commands_endpoint(network: str, req: Request) -> ChangeSet:
|
||||
@router.post("/batch/", summary="执行批量命令", description="执行多个网络操作命令", response_model=None)
|
||||
async def execute_batch_commands_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), req: Request = None) -> ChangeSet:
|
||||
"""
|
||||
执行批量命令
|
||||
|
||||
在网络上执行多个操作命令
|
||||
"""
|
||||
jo_root = await req.json()
|
||||
cs: ChangeSet = ChangeSet()
|
||||
cs.operations = jo_root["operations"]
|
||||
rcs = execute_batch_commands(network, cs)
|
||||
return rcs
|
||||
|
||||
@router.post("/compressedbatch/", response_model=None)
|
||||
@router.post("/compressedbatch/", summary="执行压缩批量命令", description="执行压缩的批量命令", response_model=None)
|
||||
async def execute_compressed_batch_commands_endpoint(
|
||||
network: str, req: Request
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> ChangeSet:
|
||||
"""
|
||||
执行压缩批量命令
|
||||
|
||||
执行压缩格式的批量命令
|
||||
"""
|
||||
jo_root = await req.json()
|
||||
cs: ChangeSet = ChangeSet()
|
||||
cs.operations = jo_root["operations"]
|
||||
return execute_batch_command(network, cs)
|
||||
|
||||
@router.get("/getrestoreoperation/")
|
||||
async def get_restore_operation_endpoint(network: str) -> int:
|
||||
@router.get("/getrestoreoperation/", summary="获取恢复操作ID", description="获取网络的恢复操作ID")
|
||||
async def get_restore_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)")) -> int:
|
||||
"""
|
||||
获取恢复操作ID
|
||||
|
||||
返回网络的恢复操作ID
|
||||
"""
|
||||
return get_restore_operation(network)
|
||||
|
||||
@router.post("/setrestoreoperation/")
|
||||
async def set_restore_operation_endpoint(network: str, operation: int) -> None:
|
||||
@router.post("/setrestoreoperation/", summary="设置恢复操作ID", description="设置网络的恢复操作ID")
|
||||
async def set_restore_operation_endpoint(network: str = Query(..., description="管网名称(或数据库名称)"), operation: int = Query(..., description="操作ID")) -> None:
|
||||
"""
|
||||
设置恢复操作ID
|
||||
|
||||
设置网络的恢复操作ID
|
||||
"""
|
||||
return set_restore_operation(network, operation)
|
||||
|
||||
@@ -0,0 +1,263 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from app.infra.db.timescaledb.composite_queries import CompositeQueries
|
||||
from .dependencies import get_timescale_connection, get_postgres_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/composite/scada-simulation", summary="获取SCADA关联的模拟数据")
|
||||
async def get_scada_associated_simulation_data(
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
device_ids: str = Query(..., description="SCADA设备ID列表,逗号分隔"),
|
||||
scheme_type: str = Query(None, description="方案类型,若为空则查询实时数据"),
|
||||
scheme_name: str = Query(None, description="方案名称,若为空则查询实时数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
获取SCADA关联的link/node模拟值
|
||||
|
||||
根据传入的SCADA device_ids,找到关联的link/node,
|
||||
并根据对应的type,查询对应的模拟数据。支持查询实时或方案数据。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
device_ids: SCADA设备ID列表,用逗号分隔
|
||||
scheme_type: 方案类型,若为空则查询实时数据
|
||||
scheme_name: 方案名称,若为空则查询实时数据
|
||||
timescale_conn: TimescaleDB连接
|
||||
postgres_conn: PostgreSQL连接
|
||||
|
||||
Returns:
|
||||
SCADA关联的模拟数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误,未找到数据时返回404错误
|
||||
"""
|
||||
try:
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
|
||||
if scheme_type and scheme_name:
|
||||
result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
|
||||
timescale_conn,
|
||||
postgres_conn,
|
||||
device_ids_list,
|
||||
start_time,
|
||||
end_time,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
await CompositeQueries.get_scada_associated_realtime_simulation_data(
|
||||
timescale_conn,
|
||||
postgres_conn,
|
||||
device_ids_list,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="No simulation data found")
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/element-simulation", summary="获取管网元素的模拟数据")
|
||||
async def get_feature_simulation_data(
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
feature_infos: str = Query(
|
||||
..., description="特征信息,格式: id1:type1,id2:type2,type为pipe(管道)或junction(节点)"
|
||||
),
|
||||
scheme_type: str = Query(None, description="方案类型,若为空则查询实时数据"),
|
||||
scheme_name: str = Query(None, description="方案名称,若为空则查询实时数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
获取link/node模拟值
|
||||
|
||||
根据传入的featureInfos,找到关联的link/node,
|
||||
并根据对应的type,查询对应的模拟数据。支持查询实时或方案数据。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
feature_infos: 格式为 "element_id1:type1,element_id2:type2"
|
||||
例如: "P1:pipe,J1:junction"
|
||||
scheme_type: 方案类型,若为空则查询实时数据
|
||||
scheme_name: 方案名称,若为空则查询实时数据
|
||||
timescale_conn: TimescaleDB连接
|
||||
|
||||
Returns:
|
||||
管网元素的模拟数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当feature_infos为空返回400错误,未找到数据返回404错误,其他错误返回400错误
|
||||
"""
|
||||
try:
|
||||
feature_infos_list = []
|
||||
if feature_infos:
|
||||
for item in feature_infos.split(","):
|
||||
item = item.strip()
|
||||
if ":" in item:
|
||||
element_id, element_type = item.split(":", 1)
|
||||
feature_infos_list.append(
|
||||
(element_id.strip(), element_type.strip())
|
||||
)
|
||||
|
||||
if not feature_infos_list:
|
||||
raise HTTPException(status_code=400, detail="feature_infos cannot be empty")
|
||||
|
||||
if scheme_type and scheme_name:
|
||||
result = await CompositeQueries.get_scheme_simulation_data(
|
||||
timescale_conn,
|
||||
feature_infos_list,
|
||||
start_time,
|
||||
end_time,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
)
|
||||
else:
|
||||
result = await CompositeQueries.get_realtime_simulation_data(
|
||||
timescale_conn,
|
||||
feature_infos_list,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="No simulation data found")
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/element-scada", summary="获取管网元素关联的SCADA监测数据")
|
||||
async def get_element_associated_scada_data(
|
||||
element_id: str = Query(..., description="管网元素ID(管道或节点)"),
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
获取link/node关联的SCADA监测值
|
||||
|
||||
根据传入的link/node id,匹配SCADA信息,
|
||||
如果存在关联的SCADA device_id,获取实际的监测数据。
|
||||
|
||||
Args:
|
||||
element_id: 管网元素ID
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
use_cleaned: 是否使用清洗后的数据,默认为False使用原始数据
|
||||
timescale_conn: TimescaleDB连接
|
||||
postgres_conn: PostgreSQL连接
|
||||
|
||||
Returns:
|
||||
管网元素关联的SCADA监测数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误,未找到关联数据返回404错误
|
||||
"""
|
||||
try:
|
||||
result = await CompositeQueries.get_element_associated_scada_data(
|
||||
timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No associated SCADA data found"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/composite/clean-scada", summary="清洗SCADA监测数据")
|
||||
async def clean_scada_data(
|
||||
device_ids: str = Query(..., description="设备ID列表或 'all' 表示清洗所有设备"),
|
||||
start_time: datetime = Query(..., description="清洗数据的开始时间"),
|
||||
end_time: datetime = Query(..., description="清洗数据的结束时间"),
|
||||
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
清洗SCADA监测数据
|
||||
|
||||
根据device_ids查询monitored_value,清洗后更新cleaned_value。
|
||||
支持清洗指定设备或所有设备的数据。
|
||||
|
||||
Args:
|
||||
device_ids: 设备ID列表,用逗号分隔,或 'all' 表示清洗所有设备
|
||||
start_time: 清洗数据的开始时间
|
||||
end_time: 清洗数据的结束时间
|
||||
timescale_conn: TimescaleDB连接
|
||||
postgres_conn: PostgreSQL连接
|
||||
|
||||
Returns:
|
||||
清洗结果信息
|
||||
|
||||
Raises:
|
||||
HTTPException: 当清洗过程出现错误时返回400错误
|
||||
"""
|
||||
try:
|
||||
if device_ids == "all":
|
||||
device_ids_list = []
|
||||
else:
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
return await CompositeQueries.clean_scada_data(
|
||||
timescale_conn, postgres_conn, device_ids_list, start_time, end_time
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/pipeline-health-prediction", summary="预测管道健康状况")
|
||||
async def predict_pipeline_health(
|
||||
query_time: datetime = Query(..., description="查询时间"),
|
||||
network_name: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
timescale_conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
预测管道健康状况
|
||||
|
||||
根据管网名称和当前时间,查询管道信息和实时数据,
|
||||
使用随机生存森林模型预测管道的生存概率。
|
||||
|
||||
Args:
|
||||
query_time: 查询时间
|
||||
network_name: 管网名称(或数据库名称)
|
||||
timescale_conn: TimescaleDB连接
|
||||
|
||||
Returns:
|
||||
预测结果列表,每个元素包含 link_id 和对应的生存函数
|
||||
|
||||
Raises:
|
||||
HTTPException: 当模型文件不存在返回404错误,其他错误返回400或500错误
|
||||
"""
|
||||
try:
|
||||
return await CompositeQueries.predict_pipeline_health(
|
||||
timescale_conn, network_name, query_time
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
|
||||
@@ -0,0 +1,19 @@
|
||||
from fastapi import Depends
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from app.auth.project_dependencies import (
|
||||
get_project_pg_connection,
|
||||
get_project_timescale_connection,
|
||||
)
|
||||
|
||||
|
||||
async def get_timescale_connection(
|
||||
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_postgres_connection(
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
yield conn
|
||||
@@ -0,0 +1,289 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from app.infra.db.timescaledb.repositories.realtime import RealtimeRepository
|
||||
from .dependencies import get_timescale_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TIME_WITH_TZ_DESC = "ISO 8601 / RFC 3339 时间,必须显式带时区;可直接传 UTC+8,服务端会先转换为 UTC 再处理。"
|
||||
TIME_RANGE_START_DESC = f"时间范围开始时间。{TIME_WITH_TZ_DESC}"
|
||||
TIME_RANGE_END_DESC = f"时间范围结束时间。{TIME_WITH_TZ_DESC}"
|
||||
|
||||
|
||||
@router.post("/realtime/links/batch", status_code=201, summary="批量插入实时管道数据")
|
||||
async def insert_realtime_links(
|
||||
data: List[dict] = Body(..., description="管道数据列表,每项包含管道ID、时间戳等信息"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection)
|
||||
):
|
||||
"""
|
||||
批量插入实时管道数据
|
||||
|
||||
将管道的实时监测数据批量插入时间序列数据库。
|
||||
|
||||
Args:
|
||||
data: 管道数据列表
|
||||
|
||||
Returns:
|
||||
插入成功的记录数
|
||||
"""
|
||||
await RealtimeRepository.insert_links_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/realtime/links",
|
||||
summary="查询实时管道数据",
|
||||
description="按时间范围查询实时管道数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
|
||||
)
|
||||
async def get_realtime_links(
|
||||
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
|
||||
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
查询指定时间范围内的实时管道数据
|
||||
|
||||
根据时间范围查询所有实时管道的监测值。传入时间必须显式包含时区,
|
||||
可以直接使用 UTC+8,服务端会先统一转换为 UTC 再参与数据库查询。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
|
||||
Returns:
|
||||
实时管道数据列表
|
||||
"""
|
||||
return await RealtimeRepository.get_links_by_time_range(conn, start_time, end_time)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/realtime/links",
|
||||
summary="删除实时管道数据",
|
||||
description="按时间范围删除实时管道数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端按请求中的绝对时间删除对应 UTC 数据。",
|
||||
)
|
||||
async def delete_realtime_links(
|
||||
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
|
||||
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
删除指定时间范围内的实时管道数据
|
||||
|
||||
删除在指定时间范围内的所有实时管道监测数据。
|
||||
|
||||
Args:
|
||||
start_time: 删除开始时间
|
||||
end_time: 删除结束时间
|
||||
|
||||
Returns:
|
||||
删除结果信息
|
||||
"""
|
||||
await RealtimeRepository.delete_links_by_time_range(conn, start_time, end_time)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.patch("/realtime/links/{link_id}/field", summary="更新实时管道字段")
|
||||
async def update_realtime_link_field(
|
||||
link_id: str = Path(..., description="管道ID"),
|
||||
time: datetime = Query(..., description=f"要更新记录的时间戳。{TIME_WITH_TZ_DESC}"),
|
||||
field: str = Query(..., description="要更新的字段名称"),
|
||||
value: float = Query(..., description="更新的字段值"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
更新指定管道的字段值
|
||||
|
||||
更新实时管道在特定时间的某个字段数据。
|
||||
|
||||
Args:
|
||||
link_id: 管道ID
|
||||
time: 数据时间戳
|
||||
field: 字段名称
|
||||
value: 字段新值
|
||||
|
||||
Returns:
|
||||
更新结果信息
|
||||
|
||||
Raises:
|
||||
HTTPException: 当字段不存在或更新失败时返回400错误
|
||||
"""
|
||||
try:
|
||||
await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/realtime/nodes/batch", status_code=201, summary="批量插入实时节点数据")
|
||||
async def insert_realtime_nodes(
|
||||
data: List[dict] = Body(..., description="节点数据列表,每项包含节点ID、时间戳等信息"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection)
|
||||
):
|
||||
"""
|
||||
批量插入实时节点数据
|
||||
|
||||
将节点的实时监测数据批量插入时间序列数据库。
|
||||
|
||||
Args:
|
||||
data: 节点数据列表
|
||||
|
||||
Returns:
|
||||
插入成功的记录数
|
||||
"""
|
||||
await RealtimeRepository.insert_nodes_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/realtime/nodes",
|
||||
summary="查询实时节点数据",
|
||||
description="按时间范围查询实时节点数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
|
||||
)
|
||||
async def get_realtime_nodes(
|
||||
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
|
||||
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
查询指定时间范围内的实时节点数据
|
||||
|
||||
根据时间范围查询所有实时节点的监测值。传入时间必须显式包含时区,
|
||||
可以直接使用 UTC+8,服务端会先统一转换为 UTC 再参与数据库查询。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
|
||||
Returns:
|
||||
实时节点数据列表
|
||||
"""
|
||||
return await RealtimeRepository.get_nodes_by_time_range(conn, start_time, end_time)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/realtime/nodes",
|
||||
summary="删除实时节点数据",
|
||||
description="按时间范围删除实时节点数据。start_time 和 end_time 必须显式带时区;允许传 UTC+8,服务端按请求中的绝对时间删除对应 UTC 数据。",
|
||||
)
|
||||
async def delete_realtime_nodes(
|
||||
start_time: datetime = Query(..., description=TIME_RANGE_START_DESC),
|
||||
end_time: datetime = Query(..., description=TIME_RANGE_END_DESC),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
删除指定时间范围内的实时节点数据
|
||||
|
||||
删除在指定时间范围内的所有实时节点监测数据。
|
||||
|
||||
Args:
|
||||
start_time: 删除开始时间
|
||||
end_time: 删除结束时间
|
||||
|
||||
Returns:
|
||||
删除结果信息
|
||||
"""
|
||||
await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/realtime/simulation/store", status_code=201, summary="存储实时模拟结果")
|
||||
async def store_realtime_simulation_result(
|
||||
node_result_list: List[dict] = Body(..., description="节点模拟结果列表"),
|
||||
link_result_list: List[dict] = Body(..., description="管道模拟结果列表"),
|
||||
result_start_time: str = Query(..., description=f"模拟结果开始时间。{TIME_WITH_TZ_DESC}"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
存储实时模拟结果到时间序列数据库
|
||||
|
||||
将节点和管道的实时模拟计算结果批量存储到TimescaleDB数据库。
|
||||
|
||||
Args:
|
||||
node_result_list: 节点模拟结果列表
|
||||
link_result_list: 管道模拟结果列表
|
||||
result_start_time: 模拟结果对应的起始时间
|
||||
|
||||
Returns:
|
||||
存储结果信息
|
||||
"""
|
||||
await RealtimeRepository.store_realtime_simulation_result(
|
||||
conn, node_result_list, link_result_list, result_start_time
|
||||
)
|
||||
return {"message": "Simulation results stored successfully"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/realtime/query/by-time-property",
|
||||
summary="按时间和属性查询实时数据",
|
||||
description="查询指定时间点的实时属性值。query_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
|
||||
)
|
||||
async def query_realtime_records_by_time_property(
|
||||
query_time: str = Query(..., description=f"查询时间。{TIME_WITH_TZ_DESC}"),
|
||||
type: str = Query(..., description="数据类型,pipe(管道)或 junction(节点)"),
|
||||
property: str = Query(..., description="要查询的属性名称"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按指定时间和属性查询所有实时监测数据
|
||||
|
||||
查询在特定时间点,所有指定类型元素的特定属性值。
|
||||
|
||||
Args:
|
||||
query_time: 查询时间
|
||||
type: 元素类型(pipe或junction)
|
||||
property: 属性名称
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
results = await RealtimeRepository.query_all_record_by_time_property(
|
||||
conn, query_time, type, property
|
||||
)
|
||||
return {"results": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/realtime/query/by-id-time",
|
||||
summary="按ID和时间查询实时模拟数据",
|
||||
description="查询指定元素在某一时间点的实时模拟结果。query_time 必须显式带时区;允许传 UTC+8,服务端会先归一化为 UTC 再执行查询。",
|
||||
)
|
||||
async def query_realtime_simulation_by_id_time(
|
||||
id: str = Query(..., description="元素ID(管道ID或节点ID)"),
|
||||
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
|
||||
query_time: str = Query(..., description=f"查询时间。{TIME_WITH_TZ_DESC}"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按指定ID和时间查询实时模拟结果
|
||||
|
||||
查询特定元素在某一时间点的实时模拟数据。
|
||||
|
||||
Args:
|
||||
id: 元素ID
|
||||
type: 元素类型(pipe或junction)
|
||||
query_time: 查询时间
|
||||
|
||||
Returns:
|
||||
模拟结果数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
results = await RealtimeRepository.query_simulation_result_by_id_time(
|
||||
conn, id, type, query_time
|
||||
)
|
||||
return {"result": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -0,0 +1,159 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from app.infra.db.timescaledb.repositories.scada import ScadaRepository
|
||||
from .dependencies import get_timescale_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/scada/batch", status_code=201, summary="批量插入SCADA监测数据")
|
||||
async def insert_scada_data(
|
||||
data: List[dict] = Body(..., description="SCADA设备监测数据列表"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
批量插入SCADA监测数据
|
||||
|
||||
将多个设备的实时监测数据批量插入时间序列数据库。
|
||||
|
||||
Args:
|
||||
data: SCADA设备监测数据列表,每项包含device_id、时间戳和监测值等信息
|
||||
|
||||
Returns:
|
||||
插入成功的记录数
|
||||
"""
|
||||
await ScadaRepository.insert_scada_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scada/by-ids-time-range", summary="按设备ID和时间范围查询SCADA数据")
|
||||
async def get_scada_by_ids_time_range(
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
device_ids: str = Query(
|
||||
..., description="设备ID列表,逗号分隔,如 'device1,device2,device3'"
|
||||
),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按设备ID和时间范围查询SCADA监测数据
|
||||
|
||||
查询多个设备在指定时间范围内的所有监测数据。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
device_ids: 设备ID列表,用逗号分隔
|
||||
|
||||
Returns:
|
||||
SCADA监测数据列表
|
||||
"""
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()] if device_ids else []
|
||||
)
|
||||
return await ScadaRepository.get_scada_by_ids_time_range(
|
||||
conn, device_ids_list, start_time, end_time
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scada/by-ids-field-time-range", summary="按设备ID、字段和时间范围查询SCADA数据"
|
||||
)
|
||||
async def get_scada_field_by_ids_time_range(
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
field: str = Query(..., description="要查询的字段名称"),
|
||||
device_ids: str = Query(
|
||||
..., description="设备ID列表,逗号分隔,如 'device1,device2,device3'"
|
||||
),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按设备ID、字段和时间范围查询特定SCADA数据
|
||||
|
||||
查询多个设备在指定时间范围内的特定字段监测数据。
|
||||
|
||||
Args:
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
field: 字段名称
|
||||
device_ids: 设备ID列表,用逗号分隔
|
||||
|
||||
Returns:
|
||||
SCADA字段数据列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当字段不存在或查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
return await ScadaRepository.get_scada_field_by_id_time_range(
|
||||
conn, device_ids_list, start_time, end_time, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scada/{device_id}/field", summary="更新SCADA设备字段")
|
||||
async def update_scada_field(
|
||||
device_id: str = Path(..., description="设备ID"),
|
||||
time: datetime = Query(..., description="更新数据的时间戳"),
|
||||
field: str = Query(..., description="要更新的字段名称"),
|
||||
value: float = Query(..., description="更新的字段值"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
更新指定设备的字段值
|
||||
|
||||
更新SCADA设备在特定时间的某个字段监测数据。
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
time: 数据时间戳
|
||||
field: 字段名称
|
||||
value: 字段新值
|
||||
|
||||
Returns:
|
||||
更新结果信息
|
||||
|
||||
Raises:
|
||||
HTTPException: 当字段不存在或更新失败时返回400错误
|
||||
"""
|
||||
try:
|
||||
await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scada/by-id-time-range", summary="按设备ID和时间范围删除SCADA数据")
|
||||
async def delete_scada_data(
|
||||
device_id: str = Query(..., description="设备ID"),
|
||||
start_time: datetime = Query(..., description="删除开始时间"),
|
||||
end_time: datetime = Query(..., description="删除结束时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
删除指定设备和时间范围内的SCADA数据
|
||||
|
||||
删除在指定时间范围内的特定设备监测数据。
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
start_time: 删除开始时间
|
||||
end_time: 删除结束时间
|
||||
|
||||
Returns:
|
||||
删除结果信息
|
||||
"""
|
||||
await ScadaRepository.delete_scada_by_id_time_range(
|
||||
conn, device_id, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
@@ -0,0 +1,391 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Body
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from app.infra.db.timescaledb.repositories.scheme import SchemeRepository
|
||||
from .dependencies import get_timescale_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/scheme/links/batch", status_code=201, summary="批量插入方案管道数据")
|
||||
async def insert_scheme_links(
|
||||
data: List[dict] = Body(..., description="方案管道数据列表"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
批量插入方案管道数据
|
||||
|
||||
将特定方案的管道模拟数据批量插入时间序列数据库。
|
||||
|
||||
Args:
|
||||
data: 方案管道数据列表
|
||||
|
||||
Returns:
|
||||
插入成功的记录数
|
||||
"""
|
||||
await SchemeRepository.insert_links_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scheme/links", summary="查询方案管道数据")
|
||||
async def get_scheme_links(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
查询指定方案和时间范围内的管道数据
|
||||
|
||||
根据方案和时间范围查询管道的模拟值。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
|
||||
Returns:
|
||||
方案管道数据列表
|
||||
"""
|
||||
return await SchemeRepository.get_links_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scheme/links/{link_id}/field", summary="查询方案管道字段数据")
|
||||
async def get_scheme_link_field(
|
||||
link_id: str = Path(..., description="管道ID"),
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
field: str = Query(..., description="要查询的字段名称"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
查询指定方案管道的特定字段数据
|
||||
|
||||
查询特定方案中指定管道在时间范围内的特定字段值。
|
||||
|
||||
Args:
|
||||
link_id: 管道ID
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
field: 字段名称
|
||||
|
||||
Returns:
|
||||
字段数据列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
return await SchemeRepository.get_link_field_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time, link_id, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scheme/links/{link_id}/field", summary="更新方案管道字段")
|
||||
async def update_scheme_link_field(
|
||||
link_id: str = Path(..., description="管道ID"),
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
time: datetime = Query(..., description="更新数据的时间戳"),
|
||||
field: str = Query(..., description="要更新的字段名称"),
|
||||
value: float = Query(..., description="更新的字段值"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
更新指定方案管道的字段值
|
||||
|
||||
更新特定方案中指定管道在某个时间的字段数据。
|
||||
|
||||
Args:
|
||||
link_id: 管道ID
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
time: 数据时间戳
|
||||
field: 字段名称
|
||||
value: 字段新值
|
||||
|
||||
Returns:
|
||||
更新结果信息
|
||||
|
||||
Raises:
|
||||
HTTPException: 当字段不存在或更新失败时返回400错误
|
||||
"""
|
||||
try:
|
||||
await SchemeRepository.update_link_field(
|
||||
conn, time, scheme_type, scheme_name, link_id, field, value
|
||||
)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scheme/links", summary="删除方案管道数据")
|
||||
async def delete_scheme_links(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
start_time: datetime = Query(..., description="删除开始时间"),
|
||||
end_time: datetime = Query(..., description="删除结束时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
删除指定方案和时间范围内的管道数据
|
||||
|
||||
删除在指定方案和时间范围内的所有管道模拟数据。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
start_time: 删除开始时间
|
||||
end_time: 删除结束时间
|
||||
|
||||
Returns:
|
||||
删除结果信息
|
||||
"""
|
||||
await SchemeRepository.delete_links_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/scheme/nodes/batch", status_code=201, summary="批量插入方案节点数据")
|
||||
async def insert_scheme_nodes(
|
||||
data: List[dict] = Body(..., description="方案节点数据列表"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
批量插入方案节点数据
|
||||
|
||||
将特定方案的节点模拟数据批量插入时间序列数据库。
|
||||
|
||||
Args:
|
||||
data: 方案节点数据列表
|
||||
|
||||
Returns:
|
||||
插入成功的记录数
|
||||
"""
|
||||
await SchemeRepository.insert_nodes_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scheme/nodes/{node_id}/field", summary="查询方案节点字段数据")
|
||||
async def get_scheme_node_field(
|
||||
node_id: str = Path(..., description="节点ID"),
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
start_time: datetime = Query(..., description="查询开始时间"),
|
||||
end_time: datetime = Query(..., description="查询结束时间"),
|
||||
field: str = Query(..., description="要查询的字段名称"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
查询指定方案节点的特定字段数据
|
||||
|
||||
查询特定方案中指定节点在时间范围内的特定字段值。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
start_time: 查询开始时间
|
||||
end_time: 查询结束时间
|
||||
field: 字段名称
|
||||
|
||||
Returns:
|
||||
字段数据列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
return await SchemeRepository.get_node_field_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time, node_id, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scheme/nodes/{node_id}/field", summary="更新方案节点字段")
|
||||
async def update_scheme_node_field(
|
||||
node_id: str = Path(..., description="节点ID"),
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
time: datetime = Query(..., description="更新数据的时间戳"),
|
||||
field: str = Query(..., description="要更新的字段名称"),
|
||||
value: float = Query(..., description="更新的字段值"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
更新指定方案节点的字段值
|
||||
|
||||
更新特定方案中指定节点在某个时间的字段数据。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
time: 数据时间戳
|
||||
field: 字段名称
|
||||
value: 字段新值
|
||||
|
||||
Returns:
|
||||
更新结果信息
|
||||
|
||||
Raises:
|
||||
HTTPException: 当字段不存在或更新失败时返回400错误
|
||||
"""
|
||||
try:
|
||||
await SchemeRepository.update_node_field(
|
||||
conn, time, scheme_type, scheme_name, node_id, field, value
|
||||
)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scheme/nodes", summary="删除方案节点数据")
|
||||
async def delete_scheme_nodes(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
start_time: datetime = Query(..., description="删除开始时间"),
|
||||
end_time: datetime = Query(..., description="删除结束时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
删除指定方案和时间范围内的节点数据
|
||||
|
||||
删除在指定方案和时间范围内的所有节点模拟数据。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
start_time: 删除开始时间
|
||||
end_time: 删除结束时间
|
||||
|
||||
Returns:
|
||||
删除结果信息
|
||||
"""
|
||||
await SchemeRepository.delete_nodes_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/scheme/simulation/store", status_code=201, summary="存储方案模拟结果")
|
||||
async def store_scheme_simulation_result(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
node_result_list: List[dict] = Body(..., description="节点模拟结果列表"),
|
||||
link_result_list: List[dict] = Body(..., description="管道模拟结果列表"),
|
||||
result_start_time: str = Query(..., description="模拟结果开始时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
存储方案模拟结果到时间序列数据库
|
||||
|
||||
将特定方案的节点和管道模拟计算结果批量存储到TimescaleDB数据库。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
node_result_list: 节点模拟结果列表
|
||||
link_result_list: 管道模拟结果列表
|
||||
result_start_time: 模拟结果对应的起始时间
|
||||
|
||||
Returns:
|
||||
存储结果信息
|
||||
"""
|
||||
await SchemeRepository.store_scheme_simulation_result(
|
||||
conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
node_result_list,
|
||||
link_result_list,
|
||||
result_start_time,
|
||||
)
|
||||
return {"message": "Scheme simulation results stored successfully"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/scheme/query/by-scheme-time-property", summary="按方案、时间和属性查询数据"
|
||||
)
|
||||
async def query_scheme_records_by_scheme_time_property(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
query_time: str = Query(..., description="查询时间"),
|
||||
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
|
||||
property: str = Query(..., description="要查询的属性名称"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按指定方案、时间和属性查询所有方案数据
|
||||
|
||||
查询在特定方案和时间点,所有指定类型元素的特定属性值。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
query_time: 查询时间
|
||||
type: 元素类型(pipe或junction)
|
||||
property: 属性名称
|
||||
|
||||
Returns:
|
||||
查询结果列表
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
results = await SchemeRepository.query_all_record_by_scheme_time_property(
|
||||
conn, scheme_type, scheme_name, query_time, type, property
|
||||
)
|
||||
return {"results": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/scheme/query/by-id-time", summary="按ID和时间查询方案模拟数据")
|
||||
async def query_scheme_simulation_by_id_time(
|
||||
scheme_type: str = Query(..., description="方案类型"),
|
||||
scheme_name: str = Query(..., description="方案名称"),
|
||||
id: str = Query(..., description="元素ID(管道ID或节点ID)"),
|
||||
type: str = Query(..., description="元素类型,pipe(管道)或 junction(节点)"),
|
||||
query_time: str = Query(..., description="查询时间"),
|
||||
conn: AsyncConnection = Depends(get_timescale_connection),
|
||||
):
|
||||
"""
|
||||
按指定ID和时间查询方案模拟结果
|
||||
|
||||
查询特定方案中的元素在某一时间点的模拟数据。
|
||||
|
||||
Args:
|
||||
scheme_type: 方案类型
|
||||
scheme_name: 方案名称
|
||||
id: 元素ID
|
||||
type: 元素类型(pipe或junction)
|
||||
query_time: 查询时间
|
||||
|
||||
Returns:
|
||||
模拟结果数据
|
||||
|
||||
Raises:
|
||||
HTTPException: 当查询参数无效时返回400错误
|
||||
"""
|
||||
try:
|
||||
result = await SchemeRepository.query_scheme_simulation_result_by_id_time(
|
||||
conn, scheme_type, scheme_name, id, type, query_time
|
||||
)
|
||||
return {"result": result}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -3,178 +3,213 @@
|
||||
|
||||
演示权限控制的使用
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from app.domain.schemas.user import UserResponse, UserUpdate, UserCreate
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserInDB
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.infra.db.metadb.repositories.user_repository import UserRepository
|
||||
from app.auth.dependencies import get_user_repository, get_current_active_user
|
||||
from app.auth.permissions import get_current_admin, require_role, check_resource_owner
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[UserResponse])
|
||||
|
||||
@router.get(
|
||||
"/",
|
||||
summary="列出所有用户",
|
||||
description="获取用户列表(仅管理员)",
|
||||
response_model=List[UserResponse],
|
||||
)
|
||||
async def list_users(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
skip: int = Query(0, ge=0, description="跳过的用户数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大用户数"),
|
||||
current_user: UserInDB = Depends(require_role(UserRole.ADMIN)),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> List[UserResponse]:
|
||||
"""
|
||||
获取用户列表(仅管理员)
|
||||
获取用户列表
|
||||
|
||||
获取系统中所有的用户信息(需要管理员权限)
|
||||
"""
|
||||
users = await user_repo.get_all_users(skip=skip, limit=limit)
|
||||
return [UserResponse.model_validate(user) for user in users]
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
|
||||
@router.get(
|
||||
"/{user_id}",
|
||||
summary="获取用户详情",
|
||||
description="获取指定用户的详细信息",
|
||||
response_model=UserResponse,
|
||||
)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
user_id: int = Path(..., gt=0, description="用户ID"),
|
||||
current_user: UserInDB = Depends(get_current_active_user),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserResponse:
|
||||
"""
|
||||
获取用户详情
|
||||
|
||||
|
||||
管理员可查看所有用户,普通用户只能查看自己
|
||||
"""
|
||||
# 检查权限
|
||||
if not check_resource_owner(user_id, current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to view this user"
|
||||
detail="You don't have permission to view this user",
|
||||
)
|
||||
|
||||
|
||||
user = await user_repo.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
|
||||
@router.put(
|
||||
"/{user_id}",
|
||||
summary="更新用户信息",
|
||||
description="更新指定用户的信息",
|
||||
response_model=UserResponse,
|
||||
)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_update: UserUpdate,
|
||||
user_id: int = Path(..., gt=0, description="用户ID"),
|
||||
user_update: UserUpdate = None,
|
||||
current_user: UserInDB = Depends(get_current_active_user),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserResponse:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
|
||||
管理员可更新所有用户,普通用户只能更新自己(且不能修改角色)
|
||||
"""
|
||||
# 检查用户是否存在
|
||||
target_user = await user_repo.get_user_by_id(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
# 权限检查
|
||||
is_owner = current_user.id == user_id
|
||||
is_admin = UserRole(current_user.role).has_permission(UserRole.ADMIN)
|
||||
|
||||
|
||||
if not is_owner and not is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to update this user"
|
||||
detail="You don't have permission to update this user",
|
||||
)
|
||||
|
||||
|
||||
# 非管理员不能修改角色和激活状态
|
||||
if not is_admin:
|
||||
if user_update.role is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can change user roles"
|
||||
detail="Only admins can change user roles",
|
||||
)
|
||||
if user_update.is_active is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can change user active status"
|
||||
detail="Only admins can change user active status",
|
||||
)
|
||||
|
||||
|
||||
# 更新用户
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user"
|
||||
detail="Failed to update user",
|
||||
)
|
||||
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
|
||||
@router.delete("/{user_id}", summary="删除用户", description="删除指定用户(仅管理员)")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
user_id: int = Path(..., gt=0, description="用户ID"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> dict:
|
||||
"""
|
||||
删除用户(仅管理员)
|
||||
删除用户
|
||||
|
||||
删除指定用户(需要管理员权限,不能删除自己)
|
||||
"""
|
||||
# 不能删除自己
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You cannot delete your own account"
|
||||
detail="You cannot delete your own account",
|
||||
)
|
||||
|
||||
|
||||
success = await user_repo.delete_user(user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
return {"message": "User deleted successfully"}
|
||||
|
||||
@router.post("/{user_id}/activate")
|
||||
|
||||
@router.post(
|
||||
"/{user_id}/activate",
|
||||
summary="激活用户",
|
||||
description="激活指定用户账户(仅管理员)",
|
||||
response_model=UserResponse,
|
||||
)
|
||||
async def activate_user(
|
||||
user_id: int,
|
||||
user_id: int = Path(..., gt=0, description="用户ID"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserResponse:
|
||||
"""
|
||||
激活用户(仅管理员)
|
||||
激活用户
|
||||
|
||||
激活指定用户的账户(需要管理员权限)
|
||||
"""
|
||||
user_update = UserUpdate(is_active=True)
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
|
||||
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@router.post("/{user_id}/deactivate")
|
||||
|
||||
@router.post(
|
||||
"/{user_id}/deactivate",
|
||||
summary="停用用户",
|
||||
description="停用指定用户账户(仅管理员)",
|
||||
response_model=UserResponse,
|
||||
)
|
||||
async def deactivate_user(
|
||||
user_id: int,
|
||||
user_id: int = Path(..., gt=0, description="用户ID"),
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserResponse:
|
||||
"""
|
||||
停用用户(仅管理员)
|
||||
停用用户
|
||||
|
||||
停用指定用户的账户(需要管理员权限,不能停用自己)
|
||||
"""
|
||||
# 不能停用自己
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You cannot deactivate your own account"
|
||||
detail="You cannot deactivate your own account",
|
||||
)
|
||||
|
||||
|
||||
user_update = UserUpdate(is_active=False)
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
|
||||
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Query
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import Any, get_all_users, get_user, get_user_schema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -8,14 +8,29 @@ router = APIRouter()
|
||||
# user 39
|
||||
###########################################################
|
||||
|
||||
@router.get("/getuserschema/")
|
||||
async def fastapi_get_user_schema(network: str) -> dict[str, dict[Any, Any]]:
|
||||
@router.get("/getuserschema/", summary="获取用户模式", description="获取指定网络的用户模式定义")
|
||||
async def fastapi_get_user_schema(network: str = Query(..., description="管网名称(或数据库名称)")) -> dict[str, dict[Any, Any]]:
|
||||
"""
|
||||
获取用户模式定义
|
||||
|
||||
返回指定网络的用户模式结构定义
|
||||
"""
|
||||
return get_user_schema(network)
|
||||
|
||||
@router.get("/getuser/")
|
||||
async def fastapi_get_user(network: str, user_name: str) -> dict[Any, Any]:
|
||||
@router.get("/getuser/", summary="获取单个用户", description="获取指定网络中的单个用户信息")
|
||||
async def fastapi_get_user(network: str = Query(..., description="管网名称(或数据库名称)"), user_name: str = Query(..., description="用户名")) -> dict[Any, Any]:
|
||||
"""
|
||||
获取用户信息
|
||||
|
||||
返回指定网络中指定用户名的详细信息
|
||||
"""
|
||||
return get_user(network, user_name)
|
||||
|
||||
@router.get("/getallusers/")
|
||||
async def fastapi_get_all_users(network: str) -> list[dict[Any, Any]]:
|
||||
@router.get("/getallusers/", summary="获取所有用户", description="获取指定网络的所有用户列表")
|
||||
async def fastapi_get_all_users(network: str = Query(..., description="管网名称(或数据库名称)")) -> list[dict[Any, Any]]:
|
||||
"""
|
||||
获取所有用户列表
|
||||
|
||||
返回指定网络中所有用户的信息
|
||||
"""
|
||||
return get_all_users(network)
|
||||
+31
-9
@@ -6,12 +6,15 @@ from app.api.v1.endpoints import (
|
||||
scada,
|
||||
extension,
|
||||
snapshots,
|
||||
data_query,
|
||||
# data_query,
|
||||
users,
|
||||
schemes,
|
||||
misc,
|
||||
risk,
|
||||
cache,
|
||||
leakage,
|
||||
burst_detection,
|
||||
burst_location,
|
||||
user_management, # 新增:用户管理
|
||||
audit, # 新增:审计日志
|
||||
meta,
|
||||
@@ -38,14 +41,21 @@ from app.api.v1.endpoints.components import (
|
||||
visuals,
|
||||
)
|
||||
|
||||
from app.infra.db.postgresql import router as postgresql_router
|
||||
from app.infra.db.timescaledb import router as timescaledb_router
|
||||
from app.api.v1.endpoints import project_data
|
||||
from app.api.v1.endpoints.timeseries import (
|
||||
realtime as ts_realtime,
|
||||
scheme as ts_scheme,
|
||||
scada as ts_scada,
|
||||
composite as ts_composite,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# Core Services
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||
api_router.include_router(
|
||||
user_management.router, prefix="/users", tags=["User Management"]
|
||||
) # 新增
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||
api_router.include_router(meta.router, tags=["Metadata"])
|
||||
api_router.include_router(project.router, tags=["Project"])
|
||||
@@ -75,18 +85,30 @@ api_router.include_router(visuals.router, tags=["Visuals"])
|
||||
|
||||
# Simulation & Data
|
||||
api_router.include_router(simulation.router, tags=["Simulation Control"])
|
||||
api_router.include_router(data_query.router, tags=["Data Query & InfluxDB"])
|
||||
api_router.include_router(scada.router, tags=["SCADA"])
|
||||
# api_router.include_router(data_query.router, tags=["Data Query & InfluxDB"])
|
||||
api_router.include_router(scada.router)
|
||||
api_router.include_router(snapshots.router, tags=["Snapshots"])
|
||||
api_router.include_router(users.router, tags=["Users"])
|
||||
api_router.include_router(schemes.router, tags=["Schemes"])
|
||||
api_router.include_router(misc.router, tags=["Misc"])
|
||||
api_router.include_router(risk.router, tags=["Risk"])
|
||||
api_router.include_router(cache.router, tags=["Cache"])
|
||||
api_router.include_router(leakage.router, prefix="/leakage", tags=["Leakage"])
|
||||
api_router.include_router(
|
||||
burst_detection.router, prefix="/burst-detection", tags=["Burst Detection"]
|
||||
)
|
||||
api_router.include_router(
|
||||
burst_location.router, prefix="/burst-location", tags=["Burst Location"]
|
||||
)
|
||||
|
||||
# Database Routers
|
||||
api_router.include_router(timescaledb_router, tags=["TimescaleDB"])
|
||||
api_router.include_router(postgresql_router, tags=["PostgreSQL"])
|
||||
# TimescaleDB Data Access
|
||||
api_router.include_router(ts_realtime.router, tags=["TimescaleDB - Realtime"])
|
||||
api_router.include_router(ts_scheme.router, tags=["TimescaleDB - Scheme"])
|
||||
api_router.include_router(ts_scada.router, tags=["TimescaleDB - SCADA"])
|
||||
api_router.include_router(ts_composite.router, tags=["TimescaleDB - Composite"])
|
||||
|
||||
# Project Data (PostgreSQL)
|
||||
api_router.include_router(project_data.router, tags=["Project Data"])
|
||||
|
||||
# Extension
|
||||
api_router.include_router(extension.router, tags=["Extension"])
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
This module is reserved for future implementation of audit logging and compliance features.
|
||||
|
||||
Current implementation of audit logging can be found in `app.core.audit`.
|
||||
Future expansion may include:
|
||||
- Comprehensive audit trails for all system actions
|
||||
- Compliance reporting (e.g., for industrial control systems)
|
||||
- Anomaly detection in user behavior
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt, JWTError
|
||||
from app.core.config import settings
|
||||
from app.domain.schemas.user import UserInDB, TokenPayload
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.infra.db.metadb.repositories.user_repository import UserRepository
|
||||
from app.infra.db.postgresql.database import Database
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
|
||||
|
||||
@@ -61,3 +61,43 @@ async def get_current_keycloak_sub(
|
||||
detail="Invalid subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
|
||||
async def get_current_keycloak_username(
|
||||
token: str | None = Depends(oauth2_optional),
|
||||
) -> str:
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||
else:
|
||||
key = settings.SECRET_KEY
|
||||
algorithms = [settings.ALGORITHM]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=algorithms,
|
||||
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||
)
|
||||
except JWTError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
username = payload.get("preferred_username") or payload.get("username")
|
||||
if not username:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing username claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
return str(username)
|
||||
|
||||
@@ -8,8 +8,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
from app.infra.db.metadb.database import get_metadata_session
|
||||
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
from app.infra.db.metadb.database import get_metadata_session
|
||||
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
DB_ROLE_BIZ_DATA = "biz_data"
|
||||
DB_ROLE_IOT_DATA = "iot_data"
|
||||
|
||||
+2
-2
@@ -66,8 +66,8 @@ async def log_audit_event(
|
||||
response_status: 响应状态码
|
||||
session: 元数据库会话(可选)
|
||||
"""
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
from app.infra.db.metadb.database import SessionLocal
|
||||
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
|
||||
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import quote_plus
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "TJWater Server"
|
||||
ENVIRONMENT: str = "production"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
NETWORK_NAME: str = "default_network"
|
||||
|
||||
# JWT 配置
|
||||
SECRET_KEY: str = (
|
||||
"your-secret-key-here-change-in-production-use-openssl-rand-hex-32"
|
||||
@@ -80,3 +84,71 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_pgconn_string(
|
||||
db_name: Optional[str] = None,
|
||||
db_host: Optional[str] = None,
|
||||
db_port: Optional[str] = None,
|
||||
db_user: Optional[str] = None,
|
||||
db_password: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Return PostgreSQL connection string in psycopg conninfo format."""
|
||||
resolved_db_name = db_name or settings.DB_NAME
|
||||
resolved_db_host = db_host or settings.DB_HOST
|
||||
resolved_db_port = db_port or settings.DB_PORT
|
||||
resolved_db_user = db_user or settings.DB_USER
|
||||
resolved_db_password = db_password or settings.DB_PASSWORD
|
||||
return (
|
||||
f"dbname={resolved_db_name} host={resolved_db_host} port={resolved_db_port} "
|
||||
f"user={resolved_db_user} password={resolved_db_password}"
|
||||
)
|
||||
|
||||
|
||||
def get_pg_config() -> dict:
|
||||
"""Return PostgreSQL configuration except password."""
|
||||
return {
|
||||
"name": settings.DB_NAME,
|
||||
"host": settings.DB_HOST,
|
||||
"port": settings.DB_PORT,
|
||||
"user": settings.DB_USER,
|
||||
}
|
||||
|
||||
|
||||
def get_pg_password() -> str:
|
||||
"""Return PostgreSQL password (use with care)."""
|
||||
return settings.DB_PASSWORD
|
||||
|
||||
|
||||
def get_timescaledb_pgconn_string(
|
||||
db_name: Optional[str] = None,
|
||||
db_host: Optional[str] = None,
|
||||
db_port: Optional[str] = None,
|
||||
db_user: Optional[str] = None,
|
||||
db_password: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Return TimescaleDB connection string in psycopg conninfo format."""
|
||||
resolved_db_name = db_name or settings.TIMESCALEDB_DB_NAME
|
||||
resolved_db_host = db_host or settings.TIMESCALEDB_DB_HOST
|
||||
resolved_db_port = db_port or settings.TIMESCALEDB_DB_PORT
|
||||
resolved_db_user = db_user or settings.TIMESCALEDB_DB_USER
|
||||
resolved_db_password = db_password or settings.TIMESCALEDB_DB_PASSWORD
|
||||
return (
|
||||
f"dbname={resolved_db_name} host={resolved_db_host} port={resolved_db_port} "
|
||||
f"user={resolved_db_user} password={resolved_db_password}"
|
||||
)
|
||||
|
||||
|
||||
def get_timescaledb_pg_config() -> dict:
|
||||
"""Return TimescaleDB configuration except password."""
|
||||
return {
|
||||
"name": settings.TIMESCALEDB_DB_NAME,
|
||||
"host": settings.TIMESCALEDB_DB_HOST,
|
||||
"port": settings.TIMESCALEDB_DB_PORT,
|
||||
"user": settings.TIMESCALEDB_DB_USER,
|
||||
}
|
||||
|
||||
|
||||
def get_timescaledb_pg_password() -> str:
|
||||
"""Return TimescaleDB password (use with care)."""
|
||||
return settings.TIMESCALEDB_DB_PASSWORD
|
||||
|
||||
+10
-6
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
@@ -8,6 +8,10 @@ from app.core.config import settings
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
@@ -22,9 +26,9 @@ def create_access_token(
|
||||
JWT token 字符串
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
expire = _utc_now() + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + timedelta(
|
||||
expire = _utc_now() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
@@ -32,7 +36,7 @@ def create_access_token(
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "access",
|
||||
"iat": datetime.now(),
|
||||
"iat": _utc_now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
@@ -50,13 +54,13 @@ def create_refresh_token(subject: Union[str, Any]) -> str:
|
||||
Returns:
|
||||
JWT refresh token 字符串
|
||||
"""
|
||||
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = _utc_now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "refresh",
|
||||
"iat": datetime.now(),
|
||||
"iat": _utc_now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
This module is reserved for future implementation of advanced cryptographic operations.
|
||||
|
||||
Current basic encryption (Fernet) and password hashing are implemented in `app.core.encryption` and `app.core.security`.
|
||||
Future expansion may include:
|
||||
- Asymmetric encryption (RSA/ECC) for secure communication
|
||||
- Key management and rotation services
|
||||
- Integration with Hardware Security Modules (HSM)
|
||||
- Digital signatures for data integrity verification
|
||||
"""
|
||||
|
||||
@@ -5,10 +5,10 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeoServerConfigResponse(BaseModel):
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_base_url: Optional[str] = None
|
||||
gs_admin_user: Optional[str] = None
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
default_extent: Optional[dict] = None
|
||||
srid: int
|
||||
|
||||
|
||||
@@ -16,18 +16,19 @@ class ProjectMetaResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
description: Optional[str] = None
|
||||
gs_workspace: str
|
||||
map_extent: Optional[dict] = None
|
||||
status: str
|
||||
project_role: str
|
||||
geoserver: Optional[GeoServerConfigResponse]
|
||||
geoserver: Optional[GeoServerConfigResponse] = None
|
||||
|
||||
|
||||
class ProjectSummaryResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
description: Optional[str] = None
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
|
||||
@@ -15,8 +15,8 @@ import logging
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
from app.infra.db.metadb.database import SessionLocal
|
||||
from app.infra.db.metadb.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,11 +55,26 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# 需要审计的HTTP方法
|
||||
AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]
|
||||
EXCLUDED_PATHS = {
|
||||
"/api/v1/meta/projects",
|
||||
"/meta/projects",
|
||||
"/api/v1/openproject/",
|
||||
"/openproject/",
|
||||
}
|
||||
EXCLUDED_PATH_PREFIXES = (
|
||||
)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 提取开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 流式 Copilot 请求前置排除,避免读取/改写 body 影响 SSE 生命周期
|
||||
if self._is_excluded_path(request.url.path):
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 1. 预判是否需要读取Body (针对写操作)
|
||||
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
|
||||
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
|
||||
@@ -68,13 +83,24 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
if should_capture_body:
|
||||
try:
|
||||
# 注意:读取 body 后需要重新设置,避免影响后续处理
|
||||
original_receive = request._receive
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_data = json.loads(body.decode())
|
||||
|
||||
# 重新构造请求以供后续使用
|
||||
# 重新构造请求以供后续使用:仅回放一次,后续回落原始 receive
|
||||
body_sent = False
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
nonlocal body_sent
|
||||
if not body_sent:
|
||||
body_sent = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": body,
|
||||
"more_body": False,
|
||||
}
|
||||
return await original_receive()
|
||||
|
||||
request._receive = receive
|
||||
except Exception as e:
|
||||
@@ -84,6 +110,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 决定是否审计
|
||||
if self._is_excluded_path(request.url.path):
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 检查方法
|
||||
is_audit_method = request.method in self.AUDIT_METHODS
|
||||
# 检查路径
|
||||
@@ -139,6 +170,11 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
return response
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
if path in self.EXCLUDED_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in self.EXCLUDED_PATH_PREFIXES)
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
@@ -155,6 +191,7 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
sub = None
|
||||
try:
|
||||
key = (
|
||||
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
@@ -166,17 +203,25 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else [settings.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=algorithms,
|
||||
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||
)
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
return None
|
||||
keycloak_id = UUID(sub)
|
||||
except (JWTError, ValueError):
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
async with SessionLocal() as session:
|
||||
repo = MetadataRepository(session)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
try:
|
||||
keycloak_id = UUID(sub)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
except ValueError:
|
||||
user = await repo.get_user_by_username(sub)
|
||||
if user and user.is_active:
|
||||
return user.id
|
||||
return None
|
||||
@@ -218,7 +263,25 @@ class AuditMiddleware(BaseHTTPMiddleware):
|
||||
if len(path_parts) >= 4:
|
||||
resource_type = path_parts[3].rstrip("s") # 移除复数s
|
||||
|
||||
if len(path_parts) >= 5 and path_parts[4].isdigit():
|
||||
if len(path_parts) >= 5 and path_parts[4]:
|
||||
resource_id = path_parts[4]
|
||||
|
||||
# 无路径ID时,尝试从查询参数提取业务ID
|
||||
if not resource_id:
|
||||
for key in (
|
||||
"id",
|
||||
"resource_id",
|
||||
"device_id",
|
||||
"device_ids",
|
||||
"element_id",
|
||||
"user_id",
|
||||
"project_id",
|
||||
"network",
|
||||
"name",
|
||||
):
|
||||
value = request.query_params.get(key)
|
||||
if value:
|
||||
resource_id = value
|
||||
break
|
||||
|
||||
return resource_type, resource_id
|
||||
|
||||
@@ -24,6 +24,17 @@ logger = logging.getLogger(__name__)
|
||||
class PgEngineEntry:
|
||||
engine: AsyncEngine
|
||||
sessionmaker: async_sessionmaker[AsyncSession]
|
||||
connection_url: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PoolEntry:
|
||||
pool: AsyncConnectionPool
|
||||
connection_url: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -35,8 +46,8 @@ class CacheKey:
|
||||
class ProjectConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
|
||||
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._ts_cache: Dict[CacheKey, PoolEntry] = OrderedDict()
|
||||
self._pg_raw_cache: Dict[CacheKey, PoolEntry] = OrderedDict()
|
||||
self._pg_lock = asyncio.Lock()
|
||||
self._ts_lock = asyncio.Lock()
|
||||
self._pg_raw_lock = asyncio.Lock()
|
||||
@@ -56,15 +67,29 @@ class ProjectConnectionManager:
|
||||
pool_max_size: int,
|
||||
) -> async_sessionmaker[AsyncSession]:
|
||||
async with self._pg_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_cache.get(key)
|
||||
if entry:
|
||||
self._pg_cache.move_to_end(key)
|
||||
return entry.sessionmaker
|
||||
|
||||
normalized_url = self._normalize_pg_url(connection_url)
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_cache.get(key)
|
||||
if entry:
|
||||
if (
|
||||
entry.connection_url == normalized_url
|
||||
and entry.pool_min_size == pool_min_size
|
||||
and entry.pool_max_size == pool_max_size
|
||||
):
|
||||
self._pg_cache.move_to_end(key)
|
||||
return entry.sessionmaker
|
||||
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Rebuilding PostgreSQL engine for project %s (%s) due to config change",
|
||||
project_id,
|
||||
db_role,
|
||||
)
|
||||
self._pg_cache.pop(key, None)
|
||||
|
||||
engine = create_async_engine(
|
||||
normalized_url,
|
||||
pool_size=pool_min_size,
|
||||
@@ -75,6 +100,9 @@ class ProjectConnectionManager:
|
||||
self._pg_cache[key] = PgEngineEntry(
|
||||
engine=engine,
|
||||
sessionmaker=sessionmaker,
|
||||
connection_url=normalized_url,
|
||||
pool_min_size=pool_min_size,
|
||||
pool_max_size=pool_max_size,
|
||||
)
|
||||
await self._evict_pg_if_needed()
|
||||
logger.info(
|
||||
@@ -91,14 +119,28 @@ class ProjectConnectionManager:
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._ts_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._ts_cache.get(key)
|
||||
if pool:
|
||||
self._ts_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._ts_cache.get(key)
|
||||
if entry:
|
||||
if (
|
||||
entry.connection_url == connection_url
|
||||
and entry.pool_min_size == pool_min_size
|
||||
and entry.pool_max_size == pool_max_size
|
||||
):
|
||||
self._ts_cache.move_to_end(key)
|
||||
return entry.pool
|
||||
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Rebuilding TimescaleDB pool for project %s (%s) due to config change",
|
||||
project_id,
|
||||
db_role,
|
||||
)
|
||||
self._ts_cache.pop(key, None)
|
||||
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
@@ -107,7 +149,12 @@ class ProjectConnectionManager:
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._ts_cache[key] = pool
|
||||
self._ts_cache[key] = PoolEntry(
|
||||
pool=pool,
|
||||
connection_url=connection_url,
|
||||
pool_min_size=pool_min_size,
|
||||
pool_max_size=pool_max_size,
|
||||
)
|
||||
await self._evict_ts_if_needed()
|
||||
logger.info(
|
||||
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
|
||||
@@ -123,14 +170,28 @@ class ProjectConnectionManager:
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._pg_raw_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._pg_raw_cache.get(key)
|
||||
if pool:
|
||||
self._pg_raw_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_raw_cache.get(key)
|
||||
if entry:
|
||||
if (
|
||||
entry.connection_url == connection_url
|
||||
and entry.pool_min_size == pool_min_size
|
||||
and entry.pool_max_size == pool_max_size
|
||||
):
|
||||
self._pg_raw_cache.move_to_end(key)
|
||||
return entry.pool
|
||||
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Rebuilding PostgreSQL pool for project %s (%s) due to config change",
|
||||
project_id,
|
||||
db_role,
|
||||
)
|
||||
self._pg_raw_cache.pop(key, None)
|
||||
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
@@ -139,7 +200,12 @@ class ProjectConnectionManager:
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._pg_raw_cache[key] = pool
|
||||
self._pg_raw_cache[key] = PoolEntry(
|
||||
pool=pool,
|
||||
connection_url=connection_url,
|
||||
pool_min_size=pool_min_size,
|
||||
pool_max_size=pool_max_size,
|
||||
)
|
||||
await self._evict_pg_raw_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
|
||||
@@ -158,8 +224,8 @@ class ProjectConnectionManager:
|
||||
|
||||
async def _evict_ts_if_needed(self) -> None:
|
||||
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
|
||||
key, pool = self._ts_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
key, entry = self._ts_cache.popitem(last=False)
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Evicted TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
@@ -168,8 +234,8 @@ class ProjectConnectionManager:
|
||||
|
||||
async def _evict_pg_raw_if_needed(self) -> None:
|
||||
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, pool = self._pg_raw_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
key, entry = self._pg_raw_cache.popitem(last=False)
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
@@ -188,8 +254,8 @@ class ProjectConnectionManager:
|
||||
self._pg_cache.clear()
|
||||
|
||||
async with self._ts_lock:
|
||||
for key, pool in list(self._ts_cache.items()):
|
||||
await pool.close()
|
||||
for key, entry in list(self._ts_cache.items()):
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Closed TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
@@ -198,8 +264,8 @@ class ProjectConnectionManager:
|
||||
self._ts_cache.clear()
|
||||
|
||||
async with self._pg_raw_lock:
|
||||
for key, pool in list(self._pg_raw_cache.items()):
|
||||
await pool.close()
|
||||
for key, entry in list(self._pg_raw_cache.items()):
|
||||
await entry.pool.close()
|
||||
logger.info(
|
||||
"Closed PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
|
||||
@@ -18,7 +18,7 @@ from dateutil import parser
|
||||
import psycopg
|
||||
import time
|
||||
import app.services.simulation as simulation
|
||||
from app.services.tjnetwork import *
|
||||
from app.services.tjnetwork import close_project, get_time, is_project_open, open_project
|
||||
import schedule
|
||||
import threading
|
||||
import app.services.globals as globals
|
||||
@@ -29,7 +29,7 @@ import pytz
|
||||
import app.infra.db.influxdb.info as influxdb_info
|
||||
import app.services.project_info as project_info
|
||||
import app.services.time_api as time_api
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
from app.core.config import get_pgconn_string
|
||||
|
||||
# influxdb数据库连接信息
|
||||
url = influxdb_info.url
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||
@@ -42,6 +42,7 @@ class Project(Base):
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
map_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
@@ -95,7 +96,9 @@ class UserProjectMembership(Base):
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user