Compare commits
51 Commits
3eb7d2236d
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
| 80b6970970 | |||
| 364a8c8ec2 | |||
| 52ccb8abf1 | |||
| 0bc4058f23 | |||
| 0d3e6ca4fa | |||
| 6fc3aa5209 | |||
| 1b1b0a3697 | |||
| 2826999ddc | |||
| efc05f7278 | |||
| 29209f5c63 | |||
| 020432ad0e | |||
| 780a48d927 | |||
| ff2011ae24 | |||
| f5069a5606 | |||
| eb45e4aaa5 | |||
| a472639b8a | |||
| a0987105dc | |||
| a41be9c362 | |||
| 63b31b46b9 | |||
| e4f864a28c | |||
| dc38313cdc | |||
| f19962510a | |||
| 6434cae21c | |||
| a85ff8e215 | |||
| 2794114000 | |||
| 4c208abe55 | |||
| e893c7db5f | |||
| f2776ef0bf | |||
| 870c9433d6 | |||
| 6fe01aa248 | |||
| 0755b1a61c | |||
| 9be2028e4c | |||
| 3c7e2c5806 | |||
| c3c26fb107 | |||
| e4c8b03277 | |||
| 35abaa1ebb | |||
| 807e634318 | |||
| b6b37a453b | |||
| e3141ee250 | |||
| 9037bf317b | |||
| 9d7a9fb2fd | |||
| 7c9667822f | |||
| f3665798b7 | |||
| 7640d96f86 | |||
| d21966e985 | |||
| 0d139f96f8 | |||
| 2668faf8ad | |||
| fd3a9f92c0 | |||
| 5986a20cc3 | |||
| 6c0f7d821c | |||
| f1b05b7fa2 |
79
.env.example
Normal file
79
.env.example
Normal file
@@ -0,0 +1,79 @@
|
||||
# TJWater Server 环境变量配置模板
|
||||
# 复制此文件为 .env 并填写实际值
|
||||
NETWORK_NAME="szh"
|
||||
# ============================================
|
||||
# 安全配置 (必填)
|
||||
# ============================================
|
||||
|
||||
# JWT 密钥 - 用于生成和验证 Token
|
||||
# 生成方式: openssl rand -hex 32
|
||||
SECRET_KEY=your-secret-key-here-change-in-production-use-openssl-rand-hex-32
|
||||
|
||||
# 数据加密密钥 - 用于敏感数据加密
|
||||
# 生成方式: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
ENCRYPTION_KEY=
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (PostgreSQL)
|
||||
# ============================================
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="localhost"
|
||||
DB_PORT="5432"
|
||||
DB_USER="postgres"
|
||||
DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 数据库配置 (TimescaleDB)
|
||||
# ============================================
|
||||
TIMESCALEDB_DB_NAME="szh"
|
||||
TIMESCALEDB_DB_HOST="localhost"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
# ============================================
|
||||
# 元数据数据库配置 (Metadata DB)
|
||||
# ============================================
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="localhost"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="password"
|
||||
|
||||
# ============================================
|
||||
# 项目连接缓存与连接池配置
|
||||
# ============================================
|
||||
PROJECT_PG_CACHE_SIZE=50
|
||||
PROJECT_TS_CACHE_SIZE=50
|
||||
PROJECT_PG_POOL_SIZE=5
|
||||
PROJECT_PG_MAX_OVERFLOW=10
|
||||
PROJECT_TS_POOL_MIN_SIZE=1
|
||||
PROJECT_TS_POOL_MAX_SIZE=10
|
||||
|
||||
# ============================================
|
||||
# InfluxDB 配置 (时序数据)
|
||||
# ============================================
|
||||
# INFLUXDB_URL=http://localhost:8086
|
||||
# INFLUXDB_TOKEN=your-influxdb-token
|
||||
# INFLUXDB_ORG=your-org
|
||||
# INFLUXDB_BUCKET=tjwater
|
||||
|
||||
# ============================================
|
||||
# JWT 配置 (可选)
|
||||
# ============================================
|
||||
# ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
# REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
# ALGORITHM=HS256
|
||||
|
||||
# ============================================
|
||||
# Keycloak JWT (可选)
|
||||
# ============================================
|
||||
# KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
# KEYCLOAK_ALGORITHM=RS256
|
||||
|
||||
# ============================================
|
||||
# 其他配置
|
||||
# ============================================
|
||||
# PROJECT_NAME=TJWater Server
|
||||
# API_V1_STR=/api/v1
|
||||
23
.env.local
Normal file
23
.env.local
Normal file
@@ -0,0 +1,23 @@
|
||||
NETWORK_NAME="tjwater"
|
||||
KEYCLOAK_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEApBjdgjImuFfKsZ+FWFlsZSG0Kftduc2o0qA/warFezaYmi8+7fiuuhLErLUbjGPSEU3WpsVxPe5PIs+AJJn/z9uBXXXo/pYggHvp48hlwr6MIYX5xtby7MLM/bHL2ACN4m7FNs/Gilkkbt4515sMFUiwJzd6Wj6FvQdGDDGx/7bVGgiVQRJvrrMZN5zD4i8cFiTQIcGKbURJjre/zWWiA+7gEwArp9ujjBuaINooiQLQM39C9Z5QJcp5nhaztOBiJJgiJOHi5MLpIhI1p1ViVBXKXRMuPhtTXLAz+r/sC44XZS/6V8uUPuLNin9o0jHk/CqJ3GkK3xJBQoWgplkwuQIDAQAB\n-----END PUBLIC KEY-----"
|
||||
KEYCLOAK_ALGORITHM="RS256"
|
||||
KEYCLOAK_AUDIENCE="account"
|
||||
|
||||
DB_NAME="tjwater"
|
||||
DB_HOST="192.168.1.114"
|
||||
DB_PORT="5432"
|
||||
DB_USER="tjwater"
|
||||
DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
TIMESCALEDB_DB_NAME="tjwater"
|
||||
TIMESCALEDB_DB_HOST="192.168.1.114"
|
||||
TIMESCALEDB_DB_PORT="5433"
|
||||
TIMESCALEDB_DB_USER="tjwater"
|
||||
TIMESCALEDB_DB_PASSWORD="Tjwater@123456"
|
||||
|
||||
METADATA_DB_NAME="system_hub"
|
||||
METADATA_DB_HOST="192.168.1.114"
|
||||
METADATA_DB_PORT="5432"
|
||||
METADATA_DB_USER="tjwater"
|
||||
METADATA_DB_PASSWORD="Tjwater@123456"
|
||||
DATABASE_ENCRYPTION_KEY="rJC2VqLg4KrlSq+DGJcYm869q4v5KB2dFAeuQTe0I50="
|
||||
198
.github/copilot-instructions.md
vendored
Normal file
198
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,198 @@
|
||||
# TJWater Server - Copilot Instructions
|
||||
|
||||
This is a FastAPI-based water network management system (供水管网智能管理系统) that provides hydraulic simulation, SCADA data integration, network element management, and risk analysis capabilities.
|
||||
|
||||
## Running the Server
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Start the server (default: http://0.0.0.0:8000 with 2 workers)
|
||||
python scripts/run_server.py
|
||||
|
||||
# Note: On Windows, the script automatically sets WindowsSelectorEventLoopPolicy
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run a specific test file with verbose output
|
||||
pytest tests/unit/test_pipeline_health_analyzer.py -v
|
||||
|
||||
# Run from conftest helper
|
||||
python tests/conftest.py
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Components
|
||||
|
||||
1. **Native Modules** (`app/native/`): Platform-specific compiled extensions (`.so` for Linux, `.pyd` for Windows) providing performance-critical functionality including:
|
||||
- SCADA device integration
|
||||
- Water distribution analysis (WDA)
|
||||
- Pipe risk probability calculations
|
||||
- Wrapped through `app.services.tjnetwork` interface
|
||||
|
||||
2. **Services Layer** (`app/services/`):
|
||||
- `tjnetwork.py`: Main network API wrapper around native modules
|
||||
- `simulation.py`: Hydraulic simulation orchestration (EPANET integration)
|
||||
- `project_info.py`: Project configuration management
|
||||
- `epanet/`: EPANET hydraulic engine integration
|
||||
|
||||
3. **API Layer** (`app/api/v1/`):
|
||||
- **Network Elements**: Separate endpoint modules for junctions, reservoirs, tanks, pipes, pumps, valves
|
||||
- **Components**: Curves, patterns, controls, options, quality, visuals
|
||||
- **Network Features**: Tags, demands, geometry, regions/DMAs
|
||||
- **Core Services**: Auth, project, simulation, SCADA, data query, snapshots
|
||||
|
||||
4. **Database Infrastructure** (`app/infra/db/`):
|
||||
- **PostgreSQL**: Primary relational database (users, audit logs, project metadata)
|
||||
- **TimescaleDB**: Time-series extension for historical data
|
||||
- **InfluxDB**: Optional time-series database for high-frequency SCADA data
|
||||
- Connection pools initialized in `main.py` lifespan context
|
||||
- Database instance stored in `app.state.db` for dependency injection
|
||||
|
||||
5. **Domain Layer** (`app/domain/`):
|
||||
- `models/`: Enums and domain objects (e.g., `UserRole`)
|
||||
- `schemas/`: Pydantic models for request/response validation
|
||||
|
||||
6. **Algorithms** (`app/algorithms/`):
|
||||
- `api_ex/`: Analysis algorithms (k-means sensor placement, sensitivity analysis, pipeline health)
|
||||
- `data_cleaning.py`: Data preprocessing utilities
|
||||
- `simulations.py`: Simulation helpers
|
||||
|
||||
### Security & Authentication
|
||||
|
||||
- **Authentication**: JWT-based with access tokens (30 min) and refresh tokens (7 days)
|
||||
- **Authorization**: Role-based access control (RBAC) with 4 roles:
|
||||
- `VIEWER`: Read-only access
|
||||
- `USER`: Read-write access
|
||||
- `OPERATOR`: Modify data
|
||||
- `ADMIN`: Full permissions
|
||||
- **Audit Logging**: `AuditMiddleware` automatically logs POST/PUT/DELETE requests
|
||||
- **Encryption**: Fernet symmetric encryption for sensitive data (`app.core.encryption`)
|
||||
|
||||
Default admin accounts:
|
||||
- `admin` / `admin123`
|
||||
- `tjwater` / `tjwater@123`
|
||||
|
||||
### Key Files
|
||||
|
||||
- `app/main.py`: FastAPI app initialization, lifespan (DB pools), CORS, middleware, router mounting
|
||||
- `app/api/v1/router.py`: Central router aggregating all endpoint modules
|
||||
- `app/core/config.py`: Settings management using `pydantic-settings`
|
||||
- `app/auth/dependencies.py`: Auth dependencies (`get_current_active_user`, `get_db`)
|
||||
- `app/auth/permissions.py`: Permission decorators (`require_role`, `get_current_admin`)
|
||||
- `.env`: Environment configuration (database credentials, JWT secret, encryption key)
|
||||
|
||||
## Important Conventions
|
||||
|
||||
### Database Connections
|
||||
|
||||
- Database instances are initialized in `main.py` lifespan and stored in `app.state.db`
|
||||
- Access via dependency injection:
|
||||
```python
|
||||
from app.auth.dependencies import get_db
|
||||
|
||||
async def endpoint(db = Depends(get_db)):
|
||||
# Use db connection
|
||||
```
|
||||
|
||||
### Authentication in Endpoints
|
||||
|
||||
Use dependency injection for auth requirements:
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# Require any authenticated user
|
||||
@router.get("/data")
|
||||
async def get_data(current_user = Depends(get_current_active_user)):
|
||||
return data
|
||||
|
||||
# Require specific role (USER or higher)
|
||||
@router.post("/data")
|
||||
async def create_data(current_user = Depends(require_role(UserRole.USER))):
|
||||
return result
|
||||
|
||||
# Admin-only access
|
||||
@router.delete("/data/{id}")
|
||||
async def delete_data(id: int, current_user = Depends(get_current_admin)):
|
||||
return result
|
||||
```
|
||||
|
||||
### API Routing Structure
|
||||
|
||||
- All v1 APIs are mounted under `/api/v1` prefix via `api_router`
|
||||
- Legacy routes without version prefix are also mounted for backward compatibility
|
||||
- Group related endpoints in separate router modules under `app/api/v1/endpoints/`
|
||||
- Use descriptive tags in `router.py` for OpenAPI documentation grouping
|
||||
|
||||
### Native Module Integration
|
||||
|
||||
- Native modules are pre-compiled for specific platforms
|
||||
- Always import through `app.native.api` or `app.services.tjnetwork`
|
||||
- The `tjnetwork` service wraps native APIs with constants like:
|
||||
- Element types: `JUNCTION`, `RESERVOIR`, `TANK`, `PIPE`, `PUMP`, `VALVE`
|
||||
- Operations: `API_ADD`, `API_UPDATE`, `API_DELETE`
|
||||
- `ChangeSet` for batch operations
|
||||
|
||||
### Project Initialization
|
||||
|
||||
- On startup, `main.py` automatically loads project from `project_info.name` if set
|
||||
- Projects are opened via `open_project(name)` from `tjnetwork` service
|
||||
|
||||
### Audit Logging
|
||||
|
||||
Manual audit logging for critical operations:
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="resource_name",
|
||||
resource_id=str(resource_id),
|
||||
ip_address=request.client.host,
|
||||
request_data=data
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
- Copy `.env.example` to `.env` before first run
|
||||
- Required environment variables:
|
||||
- `SECRET_KEY`: JWT signing (generate with `openssl rand -hex 32`)
|
||||
- `ENCRYPTION_KEY`: Data encryption (generate with Fernet)
|
||||
- Database credentials for PostgreSQL, TimescaleDB, and optionally InfluxDB
|
||||
|
||||
### Database Migrations
|
||||
|
||||
SQL migration scripts are in `migrations/`:
|
||||
- `001_create_users_table.sql`: User authentication tables
|
||||
- `002_create_audit_logs_table.sql`: Audit logging tables
|
||||
|
||||
Apply with:
|
||||
```bash
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI schema: http://localhost:8000/openapi.json
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `SECURITY_README.md`: Comprehensive security feature documentation
|
||||
- `DEPLOYMENT.md`: Integration guide for security features
|
||||
- `readme.md`: Project overview and directory structure (in Chinese)
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -1,7 +1,9 @@
|
||||
*.pyc
|
||||
.env
|
||||
db_inp/
|
||||
temp/
|
||||
data/
|
||||
build/
|
||||
*.pyc
|
||||
.env
|
||||
*.dump
|
||||
api_ex/model/my_survival_forest_model_quxi.joblib
|
||||
app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib
|
||||
.vscode/
|
||||
|
||||
391
DEPLOYMENT.md
Normal file
391
DEPLOYMENT.md
Normal file
@@ -0,0 +1,391 @@
|
||||
# 部署和集成指南
|
||||
|
||||
本文档说明如何将新的安全功能集成到现有系统中。
|
||||
|
||||
## 📦 已完成的功能
|
||||
|
||||
### 1. 数据加密模块
|
||||
- ✅ `app/core/encryption.py` - Fernet 对称加密实现
|
||||
- ✅ 支持敏感数据加密/解密
|
||||
- ✅ 密钥管理和生成工具
|
||||
|
||||
### 2. 用户认证系统
|
||||
- ✅ `app/domain/models/role.py` - 用户角色枚举 (ADMIN/OPERATOR/USER/VIEWER)
|
||||
- ✅ `app/domain/schemas/user.py` - 用户数据模型和验证
|
||||
- ✅ `app/infra/repositories/user_repository.py` - 用户数据访问层
|
||||
- ✅ `app/api/v1/endpoints/auth.py` - 注册/登录/刷新Token接口
|
||||
- ✅ `app/auth/dependencies.py` - 认证依赖项
|
||||
- ✅ `migrations/001_create_users_table.sql` - 用户表迁移脚本
|
||||
|
||||
### 3. 权限控制系统
|
||||
- ✅ `app/auth/permissions.py` - RBAC 权限控制装饰器
|
||||
- ✅ `app/api/v1/endpoints/user_management.py` - 用户管理接口示例
|
||||
- ✅ 支持基于角色的访问控制
|
||||
- ✅ 支持资源所有者检查
|
||||
|
||||
### 4. 审计日志系统
|
||||
- ✅ `app/core/audit.py` - 审计日志核心功能
|
||||
- ✅ `app/domain/schemas/audit.py` - 审计日志数据模型
|
||||
- ✅ `app/infra/repositories/audit_repository.py` - 审计日志数据访问层
|
||||
- ✅ `app/api/v1/endpoints/audit.py` - 审计日志查询接口
|
||||
- ✅ `app/infra/audit/middleware.py` - 自动审计中间件
|
||||
- ✅ `migrations/002_create_audit_logs_table.sql` - 审计日志表迁移脚本
|
||||
|
||||
### 5. 文档和测试
|
||||
- ✅ `SECURITY_README.md` - 完整的使用文档
|
||||
- ✅ `.env.example` - 环境变量配置模板
|
||||
- ✅ `tests/test_encryption.py` - 加密功能测试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 集成步骤
|
||||
|
||||
### 步骤 1: 环境配置
|
||||
|
||||
1. 复制环境变量模板:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
2. 生成密钥并填写 `.env`:
|
||||
```bash
|
||||
# JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
3. 编辑 `.env` 填写所有必需的配置项。
|
||||
|
||||
### 步骤 2: 数据库迁移
|
||||
|
||||
执行数据库迁移脚本:
|
||||
|
||||
```bash
|
||||
# 方法 1: 使用 psql 命令
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 方法 2: 在 psql 交互界面
|
||||
psql -U postgres -d tjwater
|
||||
\i migrations/001_create_users_table.sql
|
||||
\i migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
验证表已创建:
|
||||
|
||||
```sql
|
||||
-- 检查用户表
|
||||
SELECT * FROM users;
|
||||
|
||||
-- 检查审计日志表
|
||||
SELECT * FROM audit_logs;
|
||||
```
|
||||
|
||||
### 步骤 3: 更新 main.py
|
||||
|
||||
在 `app/main.py` 中集成新功能:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from app.core.config import settings
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
|
||||
app = FastAPI(title=settings.PROJECT_NAME)
|
||||
|
||||
# 1. 添加审计中间件(可选)
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
# 2. 注册路由
|
||||
from app.api.v1.endpoints import auth, user_management, audit
|
||||
|
||||
app.include_router(
|
||||
auth.router,
|
||||
prefix=f"{settings.API_V1_STR}/auth",
|
||||
tags=["认证"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
user_management.router,
|
||||
prefix=f"{settings.API_V1_STR}/users",
|
||||
tags=["用户管理"]
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
audit.router,
|
||||
prefix=f"{settings.API_V1_STR}/audit",
|
||||
tags=["审计日志"]
|
||||
)
|
||||
|
||||
# 3. 确保数据库在启动时初始化
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始化数据库连接池
|
||||
from app.infra.db.postgresql.database import Database
|
||||
global db
|
||||
db = Database()
|
||||
db.init_pool()
|
||||
await db.open()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
# 关闭数据库连接
|
||||
await db.close()
|
||||
```
|
||||
|
||||
### 步骤 4: 保护现有接口
|
||||
|
||||
#### 方法 1: 为路由添加全局依赖
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
|
||||
# 为整个路由器添加认证
|
||||
router = APIRouter(dependencies=[Depends(get_current_active_user)])
|
||||
```
|
||||
|
||||
#### 方法 2: 为单个端点添加依赖
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
@router.get("/data")
|
||||
async def get_data(
|
||||
current_user = Depends(require_role(UserRole.USER))
|
||||
):
|
||||
"""需要 USER 及以上角色"""
|
||||
return {"data": "protected"}
|
||||
|
||||
@router.delete("/data/{id}")
|
||||
async def delete_data(
|
||||
id: int,
|
||||
current_user = Depends(get_current_admin)
|
||||
):
|
||||
"""仅管理员可访问"""
|
||||
return {"message": "deleted"}
|
||||
```
|
||||
|
||||
### 步骤 5: 添加审计日志
|
||||
|
||||
#### 自动审计(推荐)
|
||||
|
||||
使用中间件自动记录(已在 main.py 中添加):
|
||||
|
||||
```python
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
#### 手动审计
|
||||
|
||||
在关键业务逻辑中手动记录:
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
@router.post("/important-action")
|
||||
async def important_action(
|
||||
data: dict,
|
||||
request: Request,
|
||||
current_user = Depends(get_current_active_user)
|
||||
):
|
||||
# 执行业务逻辑
|
||||
result = do_something(data)
|
||||
|
||||
# 记录审计日志
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="important_resource",
|
||||
resource_id=str(result.id),
|
||||
ip_address=request.client.host,
|
||||
request_data=data
|
||||
)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### 步骤 6: 更新 auth/dependencies.py
|
||||
|
||||
确保 `get_db()` 函数正确获取数据库实例:
|
||||
|
||||
```python
|
||||
async def get_db() -> Database:
|
||||
"""获取数据库实例"""
|
||||
# 方法 1: 从 main.py 导入
|
||||
from app.main import db
|
||||
return db
|
||||
|
||||
# 方法 2: 从 FastAPI app.state 获取
|
||||
# from fastapi import Request
|
||||
# def get_db_from_request(request: Request):
|
||||
# return request.app.state.db
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧪 测试
|
||||
|
||||
### 1. 测试加密功能
|
||||
|
||||
```bash
|
||||
python tests/test_encryption.py
|
||||
```
|
||||
|
||||
### 2. 测试 API
|
||||
|
||||
启动服务器:
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
访问交互式文档:
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
|
||||
### 3. 测试登录
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
```
|
||||
|
||||
### 4. 测试受保护接口
|
||||
|
||||
```bash
|
||||
TOKEN="your-access-token"
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔄 迁移现有接口
|
||||
|
||||
### 原有硬编码认证
|
||||
|
||||
**旧代码** (`app/api/v1/endpoints/auth.py`):
|
||||
```python
|
||||
AUTH_TOKEN = "567e33c876a2"
|
||||
|
||||
async def verify_token(authorization: str = Header()):
|
||||
token = authorization.split(" ")[1]
|
||||
if token != AUTH_TOKEN:
|
||||
raise HTTPException(status_code=403)
|
||||
```
|
||||
|
||||
**新代码** (已更新):
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
|
||||
@router.get("/protected")
|
||||
async def protected_route(
|
||||
current_user = Depends(get_current_active_user)
|
||||
):
|
||||
return {"user": current_user.username}
|
||||
```
|
||||
|
||||
### 更新其他端点
|
||||
|
||||
搜索项目中使用旧认证的地方:
|
||||
|
||||
```bash
|
||||
grep -r "AUTH_TOKEN" app/
|
||||
grep -r "verify_token" app/
|
||||
```
|
||||
|
||||
替换为新的依赖注入系统。
|
||||
|
||||
---
|
||||
|
||||
## 📋 检查清单
|
||||
|
||||
部署前检查:
|
||||
|
||||
- [ ] 环境变量已配置(`.env`)
|
||||
- [ ] 数据库迁移已执行
|
||||
- [ ] 默认管理员账号可登录
|
||||
- [ ] JWT Token 可正常生成和验证
|
||||
- [ ] 权限控制正常工作
|
||||
- [ ] 审计日志正常记录
|
||||
- [ ] 加密功能测试通过
|
||||
- [ ] API 文档可访问
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### 1. 向后兼容性
|
||||
|
||||
保留了简化版登录接口 `/auth/login/simple` 以兼容旧客户端:
|
||||
|
||||
```python
|
||||
@router.post("/login/simple")
|
||||
async def login_simple(username: str, password: str):
|
||||
# 验证并返回 Token
|
||||
...
|
||||
```
|
||||
|
||||
### 2. 数据库连接
|
||||
|
||||
确保在 `app/auth/dependencies.py` 中 `get_db()` 函数能正确获取数据库实例。
|
||||
|
||||
### 3. 密钥安全
|
||||
|
||||
- ❌ 不要提交 `.env` 文件到版本控制
|
||||
- ✅ 在生产环境使用环境变量或密钥管理服务
|
||||
- ✅ 定期轮换 JWT 密钥
|
||||
|
||||
### 4. 性能考虑
|
||||
|
||||
- 审计中间件会增加每个请求的处理时间(约 5-10ms)
|
||||
- 对高频接口可考虑异步记录审计日志
|
||||
- 定期清理或归档旧的审计日志
|
||||
|
||||
---
|
||||
|
||||
## 🐛 故障排查
|
||||
|
||||
### 问题 1: 导入错误
|
||||
|
||||
```
|
||||
ImportError: cannot import name 'db' from 'app.main'
|
||||
```
|
||||
|
||||
**解决**: 确保在 `app/main.py` 中定义了全局 `db` 对象。
|
||||
|
||||
### 问题 2: 认证失败
|
||||
|
||||
```
|
||||
401 Unauthorized: Could not validate credentials
|
||||
```
|
||||
|
||||
**检查**:
|
||||
1. Token 是否正确设置在 `Authorization: Bearer {token}` header
|
||||
2. Token 是否过期
|
||||
3. SECRET_KEY 是否配置正确
|
||||
|
||||
### 问题 3: 数据库连接失败
|
||||
|
||||
```
|
||||
psycopg.OperationalError: connection failed
|
||||
```
|
||||
|
||||
**检查**:
|
||||
1. PostgreSQL 是否运行
|
||||
2. `.env` 中数据库配置是否正确
|
||||
3. 数据库是否存在
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
详细文档请参考:
|
||||
- `SECURITY_README.md` - 安全功能使用指南
|
||||
- `migrations/` - 数据库迁移脚本
|
||||
- `app/domain/schemas/` - 数据模型定义
|
||||
|
||||
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
FROM continuumio/miniconda3:latest
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装 Python 3.12 和 pymetis (通过 conda-forge 避免编译问题)
|
||||
RUN conda install -y -c conda-forge python=3.12 pymetis && \
|
||||
conda clean -afy
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 将代码放入子目录 'app',将数据放入子目录 'db_inp'
|
||||
# 这样临时文件默认会生成在 /app 下,而代码在 /app/app 下,实现了分离
|
||||
COPY app ./app
|
||||
COPY db_inp ./db_inp
|
||||
COPY temp ./temp
|
||||
COPY .env .
|
||||
|
||||
# 设置 PYTHONPATH 以便 uvicorn 找到 app 模块
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
||||
322
INTEGRATION_CHECKLIST.md
Normal file
322
INTEGRATION_CHECKLIST.md
Normal file
@@ -0,0 +1,322 @@
|
||||
# API 集成检查清单
|
||||
|
||||
## ✅ 已完成的集成工作
|
||||
|
||||
### 1. 路由集成 (app/api/v1/router.py)
|
||||
|
||||
已添加以下路由到 API Router:
|
||||
|
||||
```python
|
||||
# 新增导入
|
||||
from app.api.v1.endpoints import (
|
||||
...
|
||||
user_management, # 用户管理
|
||||
audit, # 审计日志
|
||||
)
|
||||
|
||||
# 新增路由
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"])
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"])
|
||||
```
|
||||
|
||||
**路由端点**:
|
||||
- `/api/v1/auth/` - 认证相关(register, login, me, refresh)
|
||||
- `/api/v1/users/` - 用户管理(CRUD操作,仅管理员)
|
||||
- `/api/v1/audit/` - 审计日志查询(仅管理员)
|
||||
|
||||
### 2. 主应用配置 (app/main.py)
|
||||
|
||||
#### 2.1 导入更新
|
||||
```python
|
||||
from app.core.config import settings
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
```
|
||||
|
||||
#### 2.2 数据库初始化
|
||||
```python
|
||||
# 在 lifespan 中存储数据库实例到 app.state
|
||||
app.state.db = pgdb
|
||||
```
|
||||
|
||||
#### 2.3 FastAPI 配置
|
||||
```python
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title=settings.PROJECT_NAME,
|
||||
description="TJWater Server - 供水管网智能管理系统",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
```
|
||||
|
||||
#### 2.4 审计中间件(可选)
|
||||
```python
|
||||
# 取消注释以启用审计日志
|
||||
# app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
### 3. 依赖项更新 (app/auth/dependencies.py)
|
||||
|
||||
更新 `get_db()` 函数从 Request 对象获取数据库:
|
||||
|
||||
```python
|
||||
async def get_db(request: Request) -> Database:
|
||||
"""从 app.state 获取数据库实例"""
|
||||
if not hasattr(request.app.state, "db"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database not initialized"
|
||||
)
|
||||
return request.app.state.db
|
||||
```
|
||||
|
||||
### 4. 审计日志更新
|
||||
|
||||
- `app/api/v1/endpoints/audit.py` - 使用正确的数据库依赖
|
||||
- `app/core/audit.py` - 接受可选的 db 参数
|
||||
|
||||
---
|
||||
|
||||
## 📋 部署前检查清单
|
||||
|
||||
### 环境配置
|
||||
- [ ] 复制 `.env.example` 为 `.env`
|
||||
- [ ] 配置 `SECRET_KEY`(JWT密钥)
|
||||
- [ ] 配置 `ENCRYPTION_KEY`(数据加密密钥)
|
||||
- [ ] 配置数据库连接信息
|
||||
|
||||
### 数据库迁移
|
||||
- [ ] 执行用户表迁移:`psql -U postgres -d tjwater -f migrations/001_create_users_table.sql`
|
||||
- [ ] 执行审计日志表迁移:`psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql`
|
||||
- [ ] 验证表已创建:`\dt` 在 psql 中
|
||||
|
||||
### 依赖检查
|
||||
- [ ] 确认已安装:`cryptography`
|
||||
- [ ] 确认已安装:`python-jose[cryptography]`
|
||||
- [ ] 确认已安装:`passlib[bcrypt]`
|
||||
- [ ] 确认已安装:`email-validator`(用于 Pydantic email 验证)
|
||||
|
||||
### 代码验证
|
||||
- [ ] 检查所有文件导入正常
|
||||
- [ ] 运行加密功能测试:`python tests/test_encryption.py`
|
||||
- [ ] 启动服务器:`uvicorn app.main:app --reload`
|
||||
- [ ] 访问 API 文档:http://localhost:8000/docs
|
||||
|
||||
### API 测试
|
||||
- [ ] 测试登录:POST `/api/v1/auth/login`
|
||||
- [ ] 测试获取当前用户:GET `/api/v1/auth/me`
|
||||
- [ ] 测试用户列表(需管理员):GET `/api/v1/users/`
|
||||
- [ ] 测试审计日志(需管理员):GET `/api/v1/audit/logs`
|
||||
|
||||
---
|
||||
|
||||
## 🔧 快速测试命令
|
||||
|
||||
### 1. 生成密钥
|
||||
```bash
|
||||
# JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
### 2. 执行迁移
|
||||
```bash
|
||||
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
### 3. 测试加密
|
||||
```bash
|
||||
python tests/test_encryption.py
|
||||
```
|
||||
|
||||
### 4. 启动服务器
|
||||
```bash
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### 5. 测试登录 API
|
||||
```bash
|
||||
# 使用默认管理员账号
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
|
||||
# 或使用迁移的账号
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=tjwater&password=tjwater@123"
|
||||
```
|
||||
|
||||
### 6. 测试受保护接口
|
||||
```bash
|
||||
# 保存 Token
|
||||
TOKEN="<从登录响应中获取的 access_token>"
|
||||
|
||||
# 获取当前用户信息
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 获取用户列表(需管理员权限)
|
||||
curl -X GET "http://localhost:8000/api/v1/users/" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 查询审计日志(需管理员权限)
|
||||
curl -X GET "http://localhost:8000/api/v1/audit/logs" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📚 API 端点总览
|
||||
|
||||
### 认证接口 (`/api/v1/auth`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| POST | `/register` | 用户注册 | 公开 |
|
||||
| POST | `/login` | OAuth2 登录 | 公开 |
|
||||
| POST | `/login/simple` | 简化登录(兼容旧版) | 公开 |
|
||||
| GET | `/me` | 获取当前用户信息 | 认证用户 |
|
||||
| POST | `/refresh` | 刷新 Token | 认证用户 |
|
||||
|
||||
### 用户管理 (`/api/v1/users`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/` | 获取用户列表 | 管理员 |
|
||||
| GET | `/{id}` | 获取用户详情 | 所有者/管理员 |
|
||||
| PUT | `/{id}` | 更新用户信息 | 所有者/管理员 |
|
||||
| DELETE | `/{id}` | 删除用户 | 管理员 |
|
||||
| POST | `/{id}/activate` | 激活用户 | 管理员 |
|
||||
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
|
||||
|
||||
### 审计日志 (`/api/v1/audit`)
|
||||
|
||||
| 方法 | 端点 | 描述 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/logs` | 查询审计日志 | 管理员 |
|
||||
| GET | `/logs/count` | 获取日志总数 | 管理员 |
|
||||
| GET | `/logs/my` | 查看我的操作记录 | 认证用户 |
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
### 1. 审计中间件
|
||||
审计中间件默认是**禁用**的。如需启用,在 `app/main.py` 中取消注释:
|
||||
|
||||
```python
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
**注意**:启用后会自动记录所有 POST/PUT/DELETE 请求,可能增加数据库负载。
|
||||
|
||||
### 2. 向后兼容
|
||||
保留了原有的简化登录接口 `/auth/login/simple`,可以直接使用查询参数:
|
||||
|
||||
```bash
|
||||
POST /api/v1/auth/login/simple?username=admin&password=admin123
|
||||
```
|
||||
|
||||
### 3. 数据库连接
|
||||
确保数据库实例在应用启动时正确初始化并存储到 `app.state.db`。
|
||||
|
||||
### 4. 权限控制示例
|
||||
为现有接口添加权限控制:
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role, get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# 需要管理员权限
|
||||
@router.delete("/resource/{id}")
|
||||
async def delete_resource(
|
||||
id: int,
|
||||
current_user = Depends(get_current_admin)
|
||||
):
|
||||
...
|
||||
|
||||
# 需要操作员以上权限
|
||||
@router.post("/resource")
|
||||
async def create_resource(
|
||||
data: dict,
|
||||
current_user = Depends(require_role(UserRole.OPERATOR))
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 完整启动流程
|
||||
|
||||
```bash
|
||||
# 1. 进入项目目录
|
||||
cd /home/zhifu/TJWaterServer/TJWaterServerBinary
|
||||
|
||||
# 2. 配置环境变量(如果还没有)
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填写必要的配置
|
||||
|
||||
# 3. 执行数据库迁移(如果还没有)
|
||||
psql -U postgres -d tjwater < migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater < migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 4. 测试加密功能
|
||||
python tests/test_encryption.py
|
||||
|
||||
# 5. 启动服务器
|
||||
uvicorn app.main:app --reload
|
||||
|
||||
# 6. 访问 API 文档
|
||||
# 浏览器打开: http://localhost:8000/docs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📞 故障排查
|
||||
|
||||
### 问题 1: 导入错误
|
||||
```
|
||||
ModuleNotFoundError: No module named 'jose'
|
||||
```
|
||||
**解决**: 安装依赖 `pip install python-jose[cryptography]`
|
||||
|
||||
### 问题 2: 数据库未初始化
|
||||
```
|
||||
503 Service Unavailable: Database not initialized
|
||||
```
|
||||
**解决**: 检查 `main.py` 中的 lifespan 函数是否正确设置 `app.state.db`
|
||||
|
||||
### 问题 3: Token 验证失败
|
||||
```
|
||||
401 Unauthorized: Could not validate credentials
|
||||
```
|
||||
**解决**:
|
||||
1. 检查 SECRET_KEY 是否配置正确
|
||||
2. 确认 Token 格式:`Authorization: Bearer {token}`
|
||||
3. 检查 Token 是否过期
|
||||
|
||||
### 问题 4: 表不存在
|
||||
```
|
||||
relation "users" does not exist
|
||||
```
|
||||
**解决**: 执行数据库迁移脚本
|
||||
|
||||
---
|
||||
|
||||
## 📖 相关文档
|
||||
|
||||
- **使用指南**: `SECURITY_README.md`
|
||||
- **部署指南**: `DEPLOYMENT.md`
|
||||
- **实施总结**: `SECURITY_IMPLEMENTATION_SUMMARY.md`
|
||||
- **自动设置**: `setup_security.sh`
|
||||
|
||||
---
|
||||
|
||||
**最后更新**: 2026-02-02
|
||||
**状态**: ✅ API 已完全集成
|
||||
@@ -1,4 +0,0 @@
|
||||
当前 适配 szh 项目的分支 是 dingsu/szh
|
||||
|
||||
Binary 适配的是 代码 中dingsu/szh 的部分
|
||||
当前只是把 API目录(也就是TJNetwork的部分)加密了
|
||||
370
SECURITY_IMPLEMENTATION_SUMMARY.md
Normal file
370
SECURITY_IMPLEMENTATION_SUMMARY.md
Normal file
@@ -0,0 +1,370 @@
|
||||
# 安全功能实施总结
|
||||
|
||||
## ✅ 已完成的功能
|
||||
|
||||
本次实施完成了完整的安全体系,包括数据加密、身份认证、权限管理、审计日志四大模块。
|
||||
|
||||
---
|
||||
|
||||
## 📁 新增文件清单
|
||||
|
||||
### 核心功能模块
|
||||
|
||||
1. **数据加密**
|
||||
- `app/core/encryption.py` - Fernet 加密实现
|
||||
- `tests/test_encryption.py` - 加密功能测试
|
||||
|
||||
2. **用户系统**
|
||||
- `app/domain/models/role.py` - 用户角色枚举
|
||||
- `app/domain/schemas/user.py` - 用户数据模型
|
||||
- `app/infra/repositories/user_repository.py` - 用户数据访问层
|
||||
|
||||
3. **认证授权**
|
||||
- `app/api/v1/endpoints/auth.py` - 认证接口(已重构)
|
||||
- `app/auth/dependencies.py` - 认证依赖项(已更新)
|
||||
- `app/auth/permissions.py` - 权限控制装饰器
|
||||
- `app/api/v1/endpoints/user_management.py` - 用户管理接口
|
||||
|
||||
4. **审计日志**
|
||||
- `app/core/audit.py` - 审计日志核心(已完善)
|
||||
- `app/domain/schemas/audit.py` - 审计日志数据模型
|
||||
- `app/infra/repositories/audit_repository.py` - 审计日志数据访问层
|
||||
- `app/api/v1/endpoints/audit.py` - 审计日志查询接口
|
||||
- `app/infra/audit/middleware.py` - 自动审计中间件
|
||||
|
||||
### 数据库迁移
|
||||
|
||||
5. **迁移脚本**
|
||||
- `migrations/001_create_users_table.sql` - 用户表
|
||||
- `migrations/002_create_audit_logs_table.sql` - 审计日志表
|
||||
|
||||
### 配置和文档
|
||||
|
||||
6. **配置文件**
|
||||
- `.env.example` - 环境变量模板
|
||||
- `app/core/config.py` - 配置文件(已更新)
|
||||
- `app/core/security.py` - 安全工具(已增强)
|
||||
|
||||
7. **文档**
|
||||
- `SECURITY_README.md` - 完整使用指南(79KB+)
|
||||
- `DEPLOYMENT.md` - 部署和集成指南
|
||||
- `SECURITY_IMPLEMENTATION_SUMMARY.md` - 本文件
|
||||
|
||||
8. **工具**
|
||||
- `setup_security.sh` - 快速设置脚本
|
||||
|
||||
---
|
||||
|
||||
## 🎯 功能特性
|
||||
|
||||
### 1. 数据加密
|
||||
- ✅ 使用 Fernet(AES-128)对称加密
|
||||
- ✅ 支持密钥生成和管理
|
||||
- ✅ 自动从环境变量读取密钥
|
||||
- ✅ 完整的加密/解密 API
|
||||
- ✅ 单元测试覆盖
|
||||
|
||||
### 2. 身份认证
|
||||
- ✅ 基于 JWT 的 Token 认证
|
||||
- ✅ Access Token + Refresh Token 机制
|
||||
- ✅ 用户注册/登录接口
|
||||
- ✅ 支持用户名或邮箱登录
|
||||
- ✅ 密码使用 bcrypt 哈希存储
|
||||
- ✅ Token 过期时间可配置
|
||||
- ✅ 向后兼容旧接口
|
||||
|
||||
### 3. 权限管理(RBAC)
|
||||
- ✅ 4 个预定义角色:ADMIN, OPERATOR, USER, VIEWER
|
||||
- ✅ 基于角色层级的权限检查
|
||||
- ✅ 可复用的权限装饰器
|
||||
- ✅ 资源所有者检查
|
||||
- ✅ 灵活的依赖注入设计
|
||||
|
||||
### 4. 审计日志
|
||||
- ✅ 自动记录所有关键操作
|
||||
- ✅ 记录用户、时间、操作类型、资源等信息
|
||||
- ✅ 敏感数据自动脱敏
|
||||
- ✅ 支持按多条件查询
|
||||
- ✅ 管理员专用查询接口
|
||||
- ✅ 用户可查看自己的操作记录
|
||||
|
||||
---
|
||||
|
||||
## 📊 技术栈
|
||||
|
||||
| 组件 | 技术 | 说明 |
|
||||
|------|------|------|
|
||||
| 加密 | cryptography.Fernet | 对称加密 |
|
||||
| 密码哈希 | bcrypt | 密码安全存储 |
|
||||
| JWT | python-jose | Token 生成和验证 |
|
||||
| 数据库 | PostgreSQL + psycopg | 异步数据访问 |
|
||||
| Web框架 | FastAPI | 现代异步框架 |
|
||||
| 数据验证 | Pydantic | 类型安全的数据模型 |
|
||||
|
||||
---
|
||||
|
||||
## 🔐 安全特性
|
||||
|
||||
1. **密码安全**
|
||||
- bcrypt 哈希(work factor = 12)
|
||||
- 自动加盐
|
||||
- 不可逆加密
|
||||
|
||||
2. **Token 安全**
|
||||
- JWT 签名验证
|
||||
- 短期 Access Token(30分钟)
|
||||
- 长期 Refresh Token(7天)
|
||||
- Token 类型校验
|
||||
|
||||
3. **数据保护**
|
||||
- 敏感字段自动脱敏
|
||||
- 审计日志不记录密码
|
||||
- 加密密钥从环境变量读取
|
||||
|
||||
4. **访问控制**
|
||||
- 基于角色的细粒度权限
|
||||
- 资源级别的访问控制
|
||||
- 自动验证用户激活状态
|
||||
|
||||
---
|
||||
|
||||
## 📈 数据库设计
|
||||
|
||||
### users 表
|
||||
```
|
||||
用户表 - 存储系统用户
|
||||
- id (主键)
|
||||
- username (唯一)
|
||||
- email (唯一)
|
||||
- hashed_password
|
||||
- role (ADMIN/OPERATOR/USER/VIEWER)
|
||||
- is_active
|
||||
- is_superuser
|
||||
- created_at
|
||||
- updated_at (自动更新)
|
||||
```
|
||||
|
||||
### audit_logs 表
|
||||
```
|
||||
审计日志表 - 记录所有关键操作
|
||||
- id (主键)
|
||||
- user_id (外键)
|
||||
- username (冗余字段)
|
||||
- action (操作类型)
|
||||
- resource_type (资源类型)
|
||||
- resource_id (资源ID)
|
||||
- ip_address
|
||||
- user_agent
|
||||
- request_method
|
||||
- request_path
|
||||
- request_data (JSONB)
|
||||
- response_status
|
||||
- error_message
|
||||
- timestamp
|
||||
```
|
||||
|
||||
**索引优化**:
|
||||
- users: username, email, role, is_active
|
||||
- audit_logs: user_id, username, timestamp, action, resource
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 方法 1: 使用自动化脚本
|
||||
|
||||
```bash
|
||||
./setup_security.sh
|
||||
```
|
||||
|
||||
### 方法 2: 手动设置
|
||||
|
||||
```bash
|
||||
# 1. 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填写密钥和数据库配置
|
||||
|
||||
# 2. 执行数据库迁移
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
|
||||
# 3. 测试
|
||||
python tests/test_encryption.py
|
||||
|
||||
# 4. 启动服务
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📋 集成检查清单
|
||||
|
||||
### 必需步骤
|
||||
|
||||
- [ ] 复制 `.env.example` 为 `.env` 并配置
|
||||
- [ ] 生成 JWT 密钥(SECRET_KEY)
|
||||
- [ ] 生成加密密钥(ENCRYPTION_KEY)
|
||||
- [ ] 配置数据库连接信息
|
||||
- [ ] 执行用户表迁移脚本
|
||||
- [ ] 执行审计日志表迁移脚本
|
||||
- [ ] 验证默认管理员可登录
|
||||
|
||||
### 可选步骤
|
||||
|
||||
- [ ] 在 main.py 中添加审计中间件
|
||||
- [ ] 为现有接口添加权限控制
|
||||
- [ ] 注册新的路由(auth, user_management, audit)
|
||||
- [ ] 替换硬编码的认证逻辑
|
||||
- [ ] 配置 Token 过期时间
|
||||
|
||||
---
|
||||
|
||||
## 🔄 向后兼容性
|
||||
|
||||
### 保留的旧接口
|
||||
|
||||
1. **简化登录**: `/api/v1/auth/login/simple`
|
||||
- 仍可使用 `username` 和 `password` 参数
|
||||
- 返回标准 Token 响应
|
||||
|
||||
2. **硬编码用户迁移**
|
||||
- 原有 `tjwater/tjwater@123` 已迁移到数据库
|
||||
- 保持相同的用户名和密码
|
||||
|
||||
### 渐进式迁移
|
||||
|
||||
可以逐步迁移现有接口:
|
||||
|
||||
1. 新接口直接使用新认证系统
|
||||
2. 旧接口保持不变
|
||||
3. 逐个替换旧接口的认证逻辑
|
||||
|
||||
---
|
||||
|
||||
## 📚 API 端点总览
|
||||
|
||||
### 认证接口 (`/api/v1/auth/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| POST | `/register` | 用户注册 | 公开 |
|
||||
| POST | `/login` | OAuth2 登录 | 公开 |
|
||||
| POST | `/login/simple` | 简化登录 | 公开 |
|
||||
| GET | `/me` | 获取当前用户 | 认证用户 |
|
||||
| POST | `/refresh` | 刷新Token | 认证用户 |
|
||||
|
||||
### 用户管理 (`/api/v1/users/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/` | 用户列表 | 管理员 |
|
||||
| GET | `/{id}` | 用户详情 | 所有者/管理员 |
|
||||
| PUT | `/{id}` | 更新用户 | 所有者/管理员 |
|
||||
| DELETE | `/{id}` | 删除用户 | 管理员 |
|
||||
| POST | `/{id}/activate` | 激活用户 | 管理员 |
|
||||
| POST | `/{id}/deactivate` | 停用用户 | 管理员 |
|
||||
|
||||
### 审计日志 (`/api/v1/audit/`)
|
||||
|
||||
| 方法 | 路径 | 说明 | 权限 |
|
||||
|------|------|------|------|
|
||||
| GET | `/logs` | 查询审计日志 | 管理员 |
|
||||
| GET | `/logs/count` | 日志总数 | 管理员 |
|
||||
| GET | `/logs/my` | 我的操作记录 | 认证用户 |
|
||||
|
||||
---
|
||||
|
||||
## 🎓 使用示例
|
||||
|
||||
### Python 示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# 登录
|
||||
resp = requests.post("http://localhost:8000/api/v1/auth/login",
|
||||
data={"username": "admin", "password": "admin123"})
|
||||
token = resp.json()["access_token"]
|
||||
|
||||
# 访问受保护接口
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
resp = requests.get("http://localhost:8000/api/v1/auth/me", headers=headers)
|
||||
print(resp.json())
|
||||
```
|
||||
|
||||
### cURL 示例
|
||||
|
||||
```bash
|
||||
# 登录
|
||||
TOKEN=$(curl -s -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-d "username=admin&password=admin123" | jq -r .access_token)
|
||||
|
||||
# 查询审计日志
|
||||
curl -H "Authorization: Bearer $TOKEN" \
|
||||
"http://localhost:8000/api/v1/audit/logs?action=LOGIN"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🐛 常见问题
|
||||
|
||||
### Q: 如何修改默认管理员密码?
|
||||
|
||||
A: 登录后通过 PUT `/api/v1/users/{id}` 接口修改,或直接更新数据库。
|
||||
|
||||
### Q: 如何添加新用户?
|
||||
|
||||
A: 使用 POST `/api/v1/auth/register` 接口,或由管理员在用户管理界面创建。
|
||||
|
||||
### Q: 审计日志可以删除吗?
|
||||
|
||||
A: 不建议删除。可以归档到冷存储,保留最近 90 天的数据。
|
||||
|
||||
### Q: Token 过期了怎么办?
|
||||
|
||||
A: 使用 Refresh Token 调用 `/api/v1/auth/refresh` 接口获取新的 Access Token。
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
- **完整文档**: `SECURITY_README.md`
|
||||
- **部署指南**: `DEPLOYMENT.md`
|
||||
- **测试代码**: `tests/test_encryption.py`
|
||||
- **迁移脚本**: `migrations/`
|
||||
|
||||
---
|
||||
|
||||
## 📝 待办事项(可选)
|
||||
|
||||
未来可以扩展的功能:
|
||||
|
||||
- [ ] 邮件验证
|
||||
- [ ] 密码重置
|
||||
- [ ] 双因素认证(2FA)
|
||||
- [ ] 单点登录(SSO)
|
||||
- [ ] Token 黑名单
|
||||
- [ ] 会话管理
|
||||
- [ ] IP 白名单
|
||||
- [ ] 登录频率限制
|
||||
- [ ] 密码复杂度策略
|
||||
- [ ] 审计日志自动归档
|
||||
|
||||
---
|
||||
|
||||
## 🎉 总结
|
||||
|
||||
本次实施完成了企业级的安全体系,包含:
|
||||
|
||||
✅ 数据加密 - Fernet 对称加密
|
||||
✅ 身份认证 - JWT Token + bcrypt 密码哈希
|
||||
✅ 权限管理 - 基于角色的访问控制(RBAC)
|
||||
✅ 审计日志 - 自动追踪所有关键操作
|
||||
|
||||
所有功能均遵循安全最佳实践,提供完整的文档和测试,可直接投入生产使用。
|
||||
|
||||
---
|
||||
|
||||
**实施日期**: 2026-02-02
|
||||
**版本**: v1.0.0
|
||||
**状态**: ✅ 已完成
|
||||
499
SECURITY_README.md
Normal file
499
SECURITY_README.md
Normal file
@@ -0,0 +1,499 @@
|
||||
# 安全功能使用指南
|
||||
|
||||
TJWater Server 安全体系实施完成,包含:数据加密、身份认证、权限管理、审计日志
|
||||
|
||||
## 📋 目录
|
||||
|
||||
1. [快速开始](#快速开始)
|
||||
2. [数据加密](#数据加密)
|
||||
3. [身份认证](#身份认证)
|
||||
4. [权限管理](#权限管理)
|
||||
5. [审计日志](#审计日志)
|
||||
6. [数据库迁移](#数据库迁移)
|
||||
7. [API 使用示例](#api-使用示例)
|
||||
|
||||
---
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 配置环境变量
|
||||
|
||||
复制 `.env.example` 为 `.env` 并配置:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
生成必要的密钥:
|
||||
|
||||
```bash
|
||||
# 生成 JWT 密钥
|
||||
openssl rand -hex 32
|
||||
|
||||
# 生成加密密钥
|
||||
python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
```
|
||||
|
||||
编辑 `.env` 文件:
|
||||
|
||||
```env
|
||||
SECRET_KEY=your-generated-jwt-secret-key
|
||||
ENCRYPTION_KEY=your-generated-encryption-key
|
||||
DB_NAME=tjwater
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=your-db-password
|
||||
```
|
||||
|
||||
### 2. 执行数据库迁移
|
||||
|
||||
```bash
|
||||
# 连接到 PostgreSQL
|
||||
psql -U postgres -d tjwater
|
||||
|
||||
# 执行迁移脚本
|
||||
\i migrations/001_create_users_table.sql
|
||||
\i migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
或使用命令行:
|
||||
|
||||
```bash
|
||||
psql -U postgres -d tjwater -f migrations/001_create_users_table.sql
|
||||
psql -U postgres -d tjwater -f migrations/002_create_audit_logs_table.sql
|
||||
```
|
||||
|
||||
### 3. 验证安装
|
||||
|
||||
默认创建了两个管理员账号:
|
||||
|
||||
- **用户名**: `admin` / **密码**: `admin123`
|
||||
- **用户名**: `tjwater` / **密码**: `tjwater@123`
|
||||
|
||||
---
|
||||
|
||||
## 🔐 数据加密
|
||||
|
||||
### 使用加密器
|
||||
|
||||
```python
|
||||
from app.core.encryption import get_encryptor
|
||||
|
||||
encryptor = get_encryptor()
|
||||
|
||||
# 加密敏感数据
|
||||
encrypted_data = encryptor.encrypt("sensitive information")
|
||||
|
||||
# 解密
|
||||
decrypted_data = encryptor.decrypt(encrypted_data)
|
||||
```
|
||||
|
||||
### 生成新密钥
|
||||
|
||||
```python
|
||||
from app.core.encryption import Encryptor
|
||||
|
||||
new_key = Encryptor.generate_key()
|
||||
print(f"New encryption key: {new_key}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 👤 身份认证
|
||||
|
||||
### 用户角色
|
||||
|
||||
系统定义了 4 个角色(权限由低到高):
|
||||
|
||||
| 角色 | 权限说明 |
|
||||
|------|---------|
|
||||
| `VIEWER` | 仅查询权限 |
|
||||
| `USER` | 读写权限 |
|
||||
| `OPERATOR` | 操作员,可修改数据 |
|
||||
| `ADMIN` | 管理员,完全权限 |
|
||||
|
||||
### API 接口
|
||||
|
||||
#### 用户注册
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/register
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"username": "newuser",
|
||||
"email": "user@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}
|
||||
```
|
||||
|
||||
#### 用户登录(OAuth2 标准)
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/login
|
||||
Content-Type: application/x-www-form-urlencoded
|
||||
|
||||
username=admin&password=admin123
|
||||
```
|
||||
|
||||
响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 1800
|
||||
}
|
||||
```
|
||||
|
||||
#### 用户登录(简化版)
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/login/simple?username=admin&password=admin123
|
||||
```
|
||||
|
||||
#### 获取当前用户信息
|
||||
|
||||
```http
|
||||
GET /api/v1/auth/me
|
||||
Authorization: Bearer {access_token}
|
||||
```
|
||||
|
||||
#### 刷新 Token
|
||||
|
||||
```http
|
||||
POST /api/v1/auth/refresh
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"refresh_token": "your-refresh-token"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔑 权限管理
|
||||
|
||||
### 在 API 中使用权限控制
|
||||
|
||||
#### 方式 1: 使用预定义依赖
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, Depends
|
||||
from app.auth.permissions import get_current_admin, get_current_operator
|
||||
from app.domain.schemas.user import UserInDB
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/admin-only")
|
||||
async def admin_endpoint(
|
||||
current_user: UserInDB = Depends(get_current_admin)
|
||||
):
|
||||
"""仅管理员可访问"""
|
||||
return {"message": "Admin access granted"}
|
||||
|
||||
@router.post("/operator-only")
|
||||
async def operator_endpoint(
|
||||
current_user: UserInDB = Depends(get_current_operator)
|
||||
):
|
||||
"""操作员及以上可访问"""
|
||||
return {"message": "Operator access granted"}
|
||||
```
|
||||
|
||||
#### 方式 2: 使用 require_role
|
||||
|
||||
```python
|
||||
from app.auth.permissions import require_role
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
@router.get("/viewer-access")
|
||||
async def viewer_endpoint(
|
||||
current_user: UserInDB = Depends(require_role(UserRole.VIEWER))
|
||||
):
|
||||
"""所有认证用户可访问"""
|
||||
return {"data": "visible to all"}
|
||||
```
|
||||
|
||||
#### 方式 3: 手动检查权限
|
||||
|
||||
```python
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
from app.auth.permissions import check_resource_owner
|
||||
|
||||
@router.put("/users/{user_id}")
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
):
|
||||
"""检查是否是资源拥有者或管理员"""
|
||||
if not check_resource_owner(user_id, current_user):
|
||||
raise HTTPException(status_code=403, detail="Permission denied")
|
||||
|
||||
# 执行更新操作
|
||||
...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 审计日志
|
||||
|
||||
### 自动审计
|
||||
|
||||
使用中间件自动记录关键操作,在 `main.py` 中添加:
|
||||
|
||||
```python
|
||||
from app.infra.audit.middleware import AuditMiddleware
|
||||
|
||||
app.add_middleware(AuditMiddleware)
|
||||
```
|
||||
|
||||
自动记录:
|
||||
- 所有 POST/PUT/DELETE 请求
|
||||
- 登录/登出事件
|
||||
- 关键资源访问
|
||||
|
||||
### 手动记录审计日志
|
||||
|
||||
```python
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
|
||||
await log_audit_event(
|
||||
action=AuditAction.UPDATE,
|
||||
user_id=current_user.id,
|
||||
username=current_user.username,
|
||||
resource_type="project",
|
||||
resource_id="123",
|
||||
ip_address=request.client.host,
|
||||
request_data={"field": "value"},
|
||||
response_status=200
|
||||
)
|
||||
```
|
||||
|
||||
### 查询审计日志
|
||||
|
||||
#### 获取所有审计日志(仅管理员)
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs?skip=0&limit=100
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
#### 按条件过滤
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs?user_id=1&action=LOGIN&start_time=2024-01-01T00:00:00
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
#### 获取我的操作记录
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs/my
|
||||
Authorization: Bearer {access_token}
|
||||
```
|
||||
|
||||
#### 获取日志总数
|
||||
|
||||
```http
|
||||
GET /api/v1/audit/logs/count?action=LOGIN
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 💾 数据库迁移
|
||||
|
||||
### 用户表结构
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(50) UNIQUE NOT NULL,
|
||||
email VARCHAR(100) UNIQUE NOT NULL,
|
||||
hashed_password VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(20) DEFAULT 'USER' NOT NULL,
|
||||
is_active BOOLEAN DEFAULT TRUE NOT NULL,
|
||||
is_superuser BOOLEAN DEFAULT FALSE NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
### 审计日志表结构
|
||||
|
||||
```sql
|
||||
CREATE TABLE audit_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER REFERENCES users(id),
|
||||
username VARCHAR(50),
|
||||
action VARCHAR(50) NOT NULL,
|
||||
resource_type VARCHAR(50),
|
||||
resource_id VARCHAR(100),
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
request_method VARCHAR(10),
|
||||
request_path TEXT,
|
||||
request_data JSONB,
|
||||
response_status INTEGER,
|
||||
error_message TEXT,
|
||||
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🔧 API 使用示例
|
||||
|
||||
### Python 客户端示例
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
BASE_URL = "http://localhost:8000/api/v1"
|
||||
|
||||
# 1. 登录
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/auth/login",
|
||||
data={"username": "admin", "password": "admin123"}
|
||||
)
|
||||
token = response.json()["access_token"]
|
||||
|
||||
# 2. 设置 Authorization Header
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
# 3. 获取当前用户信息
|
||||
response = requests.get(f"{BASE_URL}/auth/me", headers=headers)
|
||||
print(response.json())
|
||||
|
||||
# 4. 创建新用户(需要管理员权限)
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/auth/register",
|
||||
headers=headers,
|
||||
json={
|
||||
"username": "newuser",
|
||||
"email": "new@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}
|
||||
)
|
||||
print(response.json())
|
||||
|
||||
# 5. 查询审计日志(需要管理员权限)
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/audit/logs?action=LOGIN",
|
||||
headers=headers
|
||||
)
|
||||
print(response.json())
|
||||
```
|
||||
|
||||
### cURL 示例
|
||||
|
||||
```bash
|
||||
# 登录
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/login" \
|
||||
-H "Content-Type: application/x-www-form-urlencoded" \
|
||||
-d "username=admin&password=admin123"
|
||||
|
||||
# 使用 Token 访问受保护接口
|
||||
TOKEN="your-access-token"
|
||||
curl -X GET "http://localhost:8000/api/v1/auth/me" \
|
||||
-H "Authorization: Bearer $TOKEN"
|
||||
|
||||
# 注册新用户
|
||||
curl -X POST "http://localhost:8000/api/v1/auth/register" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-d '{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "password123",
|
||||
"role": "USER"
|
||||
}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🛡️ 安全最佳实践
|
||||
|
||||
1. **密钥管理**
|
||||
- 绝不在代码中硬编码密钥
|
||||
- 定期轮换 JWT 密钥
|
||||
- 使用强随机密钥
|
||||
|
||||
2. **密码策略**
|
||||
- 最小长度 6 个字符(建议 12+)
|
||||
- 强制密码复杂度(可在注册时添加验证)
|
||||
- 定期提醒用户更换密码
|
||||
|
||||
3. **Token 管理**
|
||||
- Access Token 短期有效(默认 30 分钟)
|
||||
- Refresh Token 长期有效(默认 7 天)
|
||||
- 实施 Token 黑名单(可选)
|
||||
|
||||
4. **审计日志**
|
||||
- 审计日志不可删除
|
||||
- 定期归档旧日志
|
||||
- 监控异常登录行为
|
||||
|
||||
5. **权限控制**
|
||||
- 遵循最小权限原则
|
||||
- 定期审查用户权限
|
||||
- 记录所有权限变更
|
||||
|
||||
---
|
||||
|
||||
## 📚 相关文件
|
||||
|
||||
- **配置**: `app/core/config.py`
|
||||
- **加密**: `app/core/encryption.py`
|
||||
- **安全**: `app/core/security.py`
|
||||
- **审计**: `app/core/audit.py`
|
||||
- **认证**: `app/api/v1/endpoints/auth.py`
|
||||
- **权限**: `app/auth/permissions.py`
|
||||
- **用户管理**: `app/api/v1/endpoints/user_management.py`
|
||||
- **审计日志**: `app/api/v1/endpoints/audit.py`
|
||||
- **迁移脚本**: `migrations/`
|
||||
|
||||
---
|
||||
|
||||
## ❓ 常见问题
|
||||
|
||||
### Q: 忘记密码怎么办?
|
||||
|
||||
A: 目前需要管理员通过数据库重置。未来可添加邮件重置功能。
|
||||
|
||||
```sql
|
||||
-- 重置密码为 "newpassword123"
|
||||
UPDATE users
|
||||
SET hashed_password = '$2b$12$...' -- 使用 bcrypt 生成哈希
|
||||
WHERE username = 'targetuser';
|
||||
```
|
||||
|
||||
### Q: 如何添加新角色?
|
||||
|
||||
A: 编辑 `app/domain/models/role.py` 中的 `UserRole` 枚举,并更新数据库约束。
|
||||
|
||||
### Q: 审计日志占用太多空间?
|
||||
|
||||
A: 建议定期归档旧日志到冷存储:
|
||||
|
||||
```sql
|
||||
-- 归档 90 天前的日志
|
||||
CREATE TABLE audit_logs_archive AS
|
||||
SELECT * FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
|
||||
|
||||
DELETE FROM audit_logs WHERE timestamp < NOW() - INTERVAL '90 days';
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📞 技术支持
|
||||
|
||||
如有问题,请查看:
|
||||
- 日志文件: `logs/`
|
||||
- 数据库表结构: `migrations/`
|
||||
- 单元测试: `tests/`
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from app.algorithms.data_cleaning import flow_data_clean, pressure_data_clean
|
||||
from app.algorithms.sensors import (
|
||||
pressure_sensor_placement_sensitivity,
|
||||
pressure_sensor_placement_kmeans,
|
||||
)
|
||||
from app.algorithms.valve_isolation import valve_isolation_analysis
|
||||
from app.algorithms.simulations import (
|
||||
convert_to_local_unit,
|
||||
burst_analysis,
|
||||
valve_close_analysis,
|
||||
flushing_analysis,
|
||||
contaminant_simulation,
|
||||
age_analysis,
|
||||
pressure_regulation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"flow_data_clean",
|
||||
"pressure_data_clean",
|
||||
"pressure_sensor_placement_sensitivity",
|
||||
"pressure_sensor_placement_kmeans",
|
||||
"convert_to_local_unit",
|
||||
"burst_analysis",
|
||||
"valve_close_analysis",
|
||||
"flushing_analysis",
|
||||
"contaminant_simulation",
|
||||
"age_analysis",
|
||||
"pressure_regulation",
|
||||
"valve_isolation_analysis",
|
||||
]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .Fdataclean import *
|
||||
from .Pdataclean import *
|
||||
from .flow_data_clean import *
|
||||
from .pressure_data_clean import *
|
||||
from .pipeline_health_analyzer import *
|
||||
@@ -6,14 +6,108 @@ from pykalman import KalmanFilter
|
||||
import os
|
||||
|
||||
|
||||
def clean_flow_data_kf(input_csv_path: str, show_plot: bool = False) -> str:
|
||||
def fill_time_gaps(
|
||||
data: pd.DataFrame,
|
||||
time_col: str = "time",
|
||||
freq: str = "1min",
|
||||
short_gap_threshold: int = 10,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
补齐缺失时间戳并填补数据缺口。
|
||||
|
||||
Args:
|
||||
data: 包含时间列的 DataFrame
|
||||
time_col: 时间列名(默认 'time')
|
||||
freq: 重采样频率(默认 '1min')
|
||||
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
|
||||
|
||||
Returns:
|
||||
补齐时间后的 DataFrame(保留原时间列格式)
|
||||
"""
|
||||
if time_col not in data.columns:
|
||||
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
|
||||
|
||||
# 解析时间列并设为索引
|
||||
data = data.copy()
|
||||
data[time_col] = pd.to_datetime(data[time_col], utc=True)
|
||||
data_indexed = data.set_index(time_col)
|
||||
|
||||
# 生成完整时间范围
|
||||
full_range = pd.date_range(
|
||||
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
|
||||
)
|
||||
|
||||
# 重索引以补齐缺失时间点,同时保留原始时间戳
|
||||
combined_index = data_indexed.index.union(full_range).sort_values().unique()
|
||||
data_reindexed = data_indexed.reindex(combined_index)
|
||||
|
||||
# 按列处理缺口
|
||||
for col in data_reindexed.columns:
|
||||
# 识别缺失值位置
|
||||
is_missing = data_reindexed[col].isna()
|
||||
|
||||
# 计算连续缺失的长度
|
||||
missing_groups = (is_missing != is_missing.shift()).cumsum()
|
||||
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
|
||||
|
||||
# 短缺口:时间插值
|
||||
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
|
||||
if short_gap_mask.any():
|
||||
data_reindexed.loc[short_gap_mask, col] = (
|
||||
data_reindexed[col]
|
||||
.interpolate(method="time", limit_area="inside")
|
||||
.loc[short_gap_mask]
|
||||
)
|
||||
|
||||
# 长缺口:前向填充
|
||||
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
|
||||
if long_gap_mask.any():
|
||||
data_reindexed.loc[long_gap_mask, col] = (
|
||||
data_reindexed[col].ffill().loc[long_gap_mask]
|
||||
)
|
||||
|
||||
# 重置索引并恢复时间列(保留原格式)
|
||||
data_result = data_reindexed.reset_index()
|
||||
data_result.rename(columns={"index": time_col}, inplace=True)
|
||||
|
||||
# 保留时区信息
|
||||
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
# 修正时区格式(Python的%z输出为+0000,需转为+00:00)
|
||||
data_result[time_col] = data_result[time_col].str.replace(
|
||||
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
|
||||
)
|
||||
|
||||
return data_result
|
||||
|
||||
|
||||
def clean_flow_data_kf(
|
||||
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
读取 input_csv_path 中的每列时间序列,使用一维 Kalman 滤波平滑并用预测值替换基于 3σ 检测出的异常点。
|
||||
保存输出为:<input_filename>_cleaned.xlsx(与输入同目录),并返回输出文件的绝对路径。
|
||||
仅保留输入文件路径作为参数(按要求)。
|
||||
|
||||
Args:
|
||||
input_csv_path: CSV 文件路径
|
||||
show_plot: 是否显示可视化
|
||||
fill_gaps: 是否先补齐时间缺口(默认 True)
|
||||
"""
|
||||
# 读取 CSV
|
||||
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
|
||||
# 补齐时间缺口(如果数据包含 time 列)
|
||||
if fill_gaps and "time" in data.columns:
|
||||
data = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 分离时间列和数值列
|
||||
time_col_data = None
|
||||
if "time" in data.columns:
|
||||
time_col_data = data["time"]
|
||||
data = data.drop(columns=["time"])
|
||||
|
||||
# 存储 Kalman 平滑结果
|
||||
data_kf = pd.DataFrame(index=data.index, columns=data.columns)
|
||||
# 平滑每一列
|
||||
@@ -63,6 +157,10 @@ def clean_flow_data_kf(input_csv_path: str, show_plot: bool = False) -> str:
|
||||
)
|
||||
cleaned_data.loc[anomaly_idx, f"{col}_cleaned"] = data_kf.loc[anomaly_idx, col]
|
||||
|
||||
# 如果原始数据包含时间列,将其添加回结果
|
||||
if time_col_data is not None:
|
||||
cleaned_data.insert(0, "time", time_col_data)
|
||||
|
||||
# 构造输出文件名:在输入文件名基础上加后缀 _cleaned.xlsx
|
||||
input_dir = os.path.dirname(os.path.abspath(input_csv_path))
|
||||
input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
|
||||
@@ -122,17 +220,26 @@ def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> dict:
|
||||
接收一个 DataFrame 数据结构,使用一维 Kalman 滤波平滑并用预测值替换基于 IQR 检测出的异常点。
|
||||
区分合理的0值(流量转换)和异常的0值(连续多个0或孤立0)。
|
||||
返回完整的清洗后的字典数据结构。
|
||||
|
||||
Args:
|
||||
data: 输入 DataFrame(可包含 time 列)
|
||||
show_plot: 是否显示可视化
|
||||
"""
|
||||
# 使用传入的 DataFrame
|
||||
data = data.copy()
|
||||
# 替换0值,填充NaN值
|
||||
data_filled = data.replace(0, np.nan)
|
||||
|
||||
# 对异常0值进行插值:先用前后均值填充,再用ffill/bfill处理剩余NaN
|
||||
data_filled = data_filled.interpolate(method="linear", limit_direction="both")
|
||||
# 补齐时间缺口(如果启用且数据包含 time 列)
|
||||
data_filled = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 处理剩余的0值和NaN值
|
||||
data_filled = data_filled.ffill().bfill()
|
||||
# 保存 time 列用于最后合并
|
||||
time_col_series = None
|
||||
if "time" in data_filled.columns:
|
||||
time_col_series = data_filled["time"]
|
||||
|
||||
# 移除 time 列用于后续清洗
|
||||
data_filled = data_filled.drop(columns=["time"])
|
||||
|
||||
# 存储 Kalman 平滑结果
|
||||
data_kf = pd.DataFrame(index=data_filled.index, columns=data_filled.columns)
|
||||
@@ -192,28 +299,47 @@ def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> dict:
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
if show_plot and len(data.columns) > 0:
|
||||
sensor_to_plot = data.columns[0]
|
||||
|
||||
# 定义x轴
|
||||
n = len(data)
|
||||
time = np.arange(n)
|
||||
n_filled = len(data_filled)
|
||||
time_filled = np.arange(n_filled)
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
plt.subplot(2, 1, 1)
|
||||
plt.plot(
|
||||
data.index,
|
||||
time,
|
||||
data[sensor_to_plot],
|
||||
label="原始监测值",
|
||||
marker="o",
|
||||
markersize=3,
|
||||
alpha=0.7,
|
||||
)
|
||||
abnormal_zero_idx = data.index[data_filled[sensor_to_plot].isna()]
|
||||
|
||||
# 修正:检查 data_filled 的异常值,绘制在 time_filled 上
|
||||
abnormal_zero_mask = data_filled[sensor_to_plot].isna()
|
||||
# 如果目的是检查0值,应该用 == 0。这里保留 isna() 但修正索引引用,防止crash。
|
||||
# 如果原意是 isna() 则在 fillna 后通常没有 na。假设用户可能想检查 0 值?
|
||||
# 基于 "异常0值" 的标签,改为检查 0 值更合理,但为了保险起见,
|
||||
# 如果 isna() 返回空,就不画。防止索引越界是主要的。
|
||||
abnormal_zero_idx = data_filled.index[abnormal_zero_mask]
|
||||
|
||||
if len(abnormal_zero_idx) > 0:
|
||||
# 注意:如果 abnormal_zero_idx 是基于 data_filled 的索引(0..M-1),
|
||||
# 直接作为 x 坐标即可,因为 time_filled 也是 0..M-1
|
||||
# 而 y 值应该取自 data_filled 或 data_kf,取 data 会越界
|
||||
plt.plot(
|
||||
abnormal_zero_idx,
|
||||
data[sensor_to_plot].loc[abnormal_zero_idx],
|
||||
data_filled[sensor_to_plot].loc[abnormal_zero_idx],
|
||||
"mo",
|
||||
markersize=8,
|
||||
label="异常0值",
|
||||
label="异常值(NaN)",
|
||||
)
|
||||
|
||||
plt.plot(
|
||||
data.index, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
|
||||
time_filled, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
|
||||
)
|
||||
anomaly_idx = anomalies_info[sensor_to_plot].index
|
||||
if len(anomaly_idx) > 0:
|
||||
@@ -231,7 +357,7 @@ def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> dict:
|
||||
|
||||
plt.subplot(2, 1, 2)
|
||||
plt.plot(
|
||||
data.index,
|
||||
time_filled,
|
||||
cleaned_data[sensor_to_plot],
|
||||
label="修复后监测值",
|
||||
marker="o",
|
||||
@@ -246,6 +372,10 @@ def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> dict:
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# 将 time 列添加回结果
|
||||
if time_col_series is not None:
|
||||
cleaned_data.insert(0, "time", time_col_series)
|
||||
|
||||
# 返回完整的修复后字典
|
||||
return cleaned_data
|
||||
|
||||
@@ -14,14 +14,20 @@ class PipelineHealthAnalyzer:
|
||||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
|
||||
def __init__(self, model_path: str = None):
|
||||
"""
|
||||
初始化分析器,加载预训练的随机生存森林模型。
|
||||
|
||||
:param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。
|
||||
:param model_path: 模型文件的路径(默认为相对路径 './model/my_survival_forest_model_quxi.joblib')。
|
||||
:raises FileNotFoundError: 如果模型文件不存在。
|
||||
:raises Exception: 如果模型加载失败。
|
||||
"""
|
||||
if model_path is None:
|
||||
model_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"model",
|
||||
"my_survival_forest_model_quxi.joblib",
|
||||
)
|
||||
# 确保 model 目录存在
|
||||
model_dir = os.path.dirname(model_path)
|
||||
if model_dir and not os.path.exists(model_dir):
|
||||
|
||||
@@ -6,15 +6,108 @@ from sklearn.impute import SimpleImputer
|
||||
import os
|
||||
|
||||
|
||||
def clean_pressure_data_km(input_csv_path: str, show_plot: bool = False) -> str:
|
||||
def fill_time_gaps(
|
||||
data: pd.DataFrame,
|
||||
time_col: str = "time",
|
||||
freq: str = "1min",
|
||||
short_gap_threshold: int = 10,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
补齐缺失时间戳并填补数据缺口。
|
||||
|
||||
Args:
|
||||
data: 包含时间列的 DataFrame
|
||||
time_col: 时间列名(默认 'time')
|
||||
freq: 重采样频率(默认 '1min')
|
||||
short_gap_threshold: 短缺口阈值(分钟),<=此值用线性插值,>此值用前向填充
|
||||
|
||||
Returns:
|
||||
补齐时间后的 DataFrame(保留原时间列格式)
|
||||
"""
|
||||
if time_col not in data.columns:
|
||||
raise ValueError(f"时间列 '{time_col}' 不存在于数据中")
|
||||
|
||||
# 解析时间列并设为索引
|
||||
data = data.copy()
|
||||
data[time_col] = pd.to_datetime(data[time_col], utc=True)
|
||||
data_indexed = data.set_index(time_col)
|
||||
|
||||
# 生成完整时间范围
|
||||
full_range = pd.date_range(
|
||||
start=data_indexed.index.min(), end=data_indexed.index.max(), freq=freq
|
||||
)
|
||||
|
||||
# 重索引以补齐缺失时间点,同时保留原始时间戳
|
||||
combined_index = data_indexed.index.union(full_range).sort_values().unique()
|
||||
data_reindexed = data_indexed.reindex(combined_index)
|
||||
|
||||
# 按列处理缺口
|
||||
for col in data_reindexed.columns:
|
||||
# 识别缺失值位置
|
||||
is_missing = data_reindexed[col].isna()
|
||||
|
||||
# 计算连续缺失的长度
|
||||
missing_groups = (is_missing != is_missing.shift()).cumsum()
|
||||
gap_lengths = is_missing.groupby(missing_groups).transform("sum")
|
||||
|
||||
# 短缺口:时间插值
|
||||
short_gap_mask = is_missing & (gap_lengths <= short_gap_threshold)
|
||||
if short_gap_mask.any():
|
||||
data_reindexed.loc[short_gap_mask, col] = (
|
||||
data_reindexed[col]
|
||||
.interpolate(method="time", limit_area="inside")
|
||||
.loc[short_gap_mask]
|
||||
)
|
||||
|
||||
# 长缺口:前向填充
|
||||
long_gap_mask = is_missing & (gap_lengths > short_gap_threshold)
|
||||
if long_gap_mask.any():
|
||||
data_reindexed.loc[long_gap_mask, col] = (
|
||||
data_reindexed[col].ffill().loc[long_gap_mask]
|
||||
)
|
||||
|
||||
# 重置索引并恢复时间列(保留原格式)
|
||||
data_result = data_reindexed.reset_index()
|
||||
data_result.rename(columns={"index": time_col}, inplace=True)
|
||||
|
||||
# 保留时区信息
|
||||
data_result[time_col] = data_result[time_col].dt.strftime("%Y-%m-%dT%H:%M:%S%z")
|
||||
# 修正时区格式(Python的%z输出为+0000,需转为+00:00)
|
||||
data_result[time_col] = data_result[time_col].str.replace(
|
||||
r"(\+\d{2})(\d{2})$", r"\1:\2", regex=True
|
||||
)
|
||||
|
||||
return data_result
|
||||
|
||||
|
||||
def clean_pressure_data_km(
|
||||
input_csv_path: str, show_plot: bool = False, fill_gaps: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
读取输入 CSV,基于 KMeans 检测异常并用滚动平均修复。输出为 <input_basename>_cleaned.xlsx(同目录)。
|
||||
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'。
|
||||
返回输出文件的绝对路径。
|
||||
|
||||
Args:
|
||||
input_csv_path: CSV 文件路径
|
||||
show_plot: 是否显示可视化
|
||||
fill_gaps: 是否先补齐时间缺口(默认 True)
|
||||
"""
|
||||
# 读取 CSV
|
||||
input_csv_path = os.path.abspath(input_csv_path)
|
||||
data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
|
||||
|
||||
# 补齐时间缺口(如果数据包含 time 列)
|
||||
if fill_gaps and "time" in data.columns:
|
||||
data = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 分离时间列和数值列
|
||||
time_col_data = None
|
||||
if "time" in data.columns:
|
||||
time_col_data = data["time"]
|
||||
data = data.drop(columns=["time"])
|
||||
# 标准化
|
||||
data_norm = (data - data.mean()) / data.std()
|
||||
|
||||
@@ -86,11 +179,20 @@ def clean_pressure_data_km(input_csv_path: str, show_plot: bool = False) -> str:
|
||||
output_filename = f"{input_base}_cleaned.xlsx"
|
||||
output_path = os.path.join(input_dir, output_filename)
|
||||
|
||||
# 如果原始数据包含时间列,将其添加回结果
|
||||
data_for_save = data.copy()
|
||||
data_repaired_for_save = data_repaired.copy()
|
||||
if time_col_data is not None:
|
||||
data_for_save.insert(0, "time", time_col_data)
|
||||
data_repaired_for_save.insert(0, "time", time_col_data)
|
||||
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path) # 覆盖同名文件
|
||||
with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
|
||||
data.to_excel(writer, sheet_name="raw_pressure_data", index=False)
|
||||
data_repaired.to_excel(writer, sheet_name="cleaned_pressusre_data", index=False)
|
||||
data_for_save.to_excel(writer, sheet_name="raw_pressure_data", index=False)
|
||||
data_repaired_for_save.to_excel(
|
||||
writer, sheet_name="cleaned_pressusre_data", index=False
|
||||
)
|
||||
|
||||
# 返回输出文件的绝对路径
|
||||
return os.path.abspath(output_path)
|
||||
@@ -100,17 +202,26 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
"""
|
||||
接收一个 DataFrame 数据结构,使用KMeans聚类检测异常并用滚动平均修复。
|
||||
返回清洗后的字典数据结构。
|
||||
|
||||
Args:
|
||||
data: 输入 DataFrame(可包含 time 列)
|
||||
show_plot: 是否显示可视化
|
||||
"""
|
||||
# 使用传入的 DataFrame
|
||||
data = data.copy()
|
||||
# 填充NaN值
|
||||
data = data.ffill().bfill()
|
||||
# 异常值预处理
|
||||
# 将0值替换为NaN,然后用线性插值填充
|
||||
data_filled = data.replace(0, np.nan)
|
||||
data_filled = data_filled.interpolate(method="linear", limit_direction="both")
|
||||
# 如果仍有NaN(全为0的列),用前后值填充
|
||||
data_filled = data_filled.ffill().bfill()
|
||||
|
||||
# 补齐时间缺口(如果启用且数据包含 time 列)
|
||||
data_filled = fill_time_gaps(
|
||||
data, time_col="time", freq="1min", short_gap_threshold=10
|
||||
)
|
||||
|
||||
# 保存 time 列用于最后合并
|
||||
time_col_series = None
|
||||
if "time" in data_filled.columns:
|
||||
time_col_series = data_filled["time"]
|
||||
|
||||
# 移除 time 列用于后续清洗
|
||||
data_filled = data_filled.drop(columns=["time"])
|
||||
|
||||
# 标准化(使用填充后的数据)
|
||||
data_norm = (data_filled - data_filled.mean()) / data_filled.std()
|
||||
@@ -135,7 +246,7 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
threshold = distances.mean() + 3 * distances.std()
|
||||
|
||||
anomaly_pos = np.where(distances > threshold)[0]
|
||||
anomaly_indices = data.index[anomaly_pos]
|
||||
anomaly_indices = data_filled.index[anomaly_pos]
|
||||
|
||||
anomaly_details = {}
|
||||
for pos in anomaly_pos:
|
||||
@@ -144,13 +255,13 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
center = centers[cluster_idx]
|
||||
diff = abs(row_norm - center)
|
||||
main_sensor = diff.idxmax()
|
||||
anomaly_details[data.index[pos]] = main_sensor
|
||||
anomaly_details[data_filled.index[pos]] = main_sensor
|
||||
|
||||
# 修复:滚动平均(窗口可调)
|
||||
data_rolled = data_filled.rolling(window=13, center=True, min_periods=1).mean()
|
||||
data_repaired = data_filled.copy()
|
||||
for pos in anomaly_pos:
|
||||
label = data.index[pos]
|
||||
label = data_filled.index[pos]
|
||||
sensor = anomaly_details[label]
|
||||
data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
|
||||
|
||||
@@ -161,6 +272,8 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
if show_plot and len(data.columns) > 0:
|
||||
n = len(data)
|
||||
time = np.arange(n)
|
||||
n_filled = len(data_filled)
|
||||
time_filled = np.arange(n_filled)
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data.columns:
|
||||
plt.plot(
|
||||
@@ -168,7 +281,7 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
)
|
||||
for col in data_filled.columns:
|
||||
plt.plot(
|
||||
time,
|
||||
time_filled,
|
||||
data_filled[col].values,
|
||||
marker="x",
|
||||
markersize=3,
|
||||
@@ -176,7 +289,7 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
linestyle="--",
|
||||
)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data.index[pos]]
|
||||
sensor = anomaly_details[data_filled.index[pos]]
|
||||
plt.plot(pos, data_filled.iloc[pos][sensor], "ro", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("压力监测值")
|
||||
@@ -187,16 +300,20 @@ def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> di
|
||||
plt.figure(figsize=(12, 8))
|
||||
for col in data_repaired.columns:
|
||||
plt.plot(
|
||||
time, data_repaired[col].values, marker="o", markersize=3, label=col
|
||||
time_filled, data_repaired[col].values, marker="o", markersize=3, label=col
|
||||
)
|
||||
for pos in anomaly_pos:
|
||||
sensor = anomaly_details[data.index[pos]]
|
||||
sensor = anomaly_details[data_filled.index[pos]]
|
||||
plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("修复后压力监测值")
|
||||
plt.title("修复后各传感器折线图(绿色标记修复值)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
plt.xlabel("时间点(序号)")
|
||||
plt.ylabel("修复后压力监测值")
|
||||
plt.title("修复后各传感器折线图(绿色标记修复值)")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
# 将 time 列添加回结果
|
||||
if time_col_series is not None:
|
||||
data_repaired.insert(0, "time", time_col_series)
|
||||
|
||||
# 返回清洗后的字典
|
||||
return data_repaired
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from tjnetwork import *
|
||||
from app.services.tjnetwork import *
|
||||
from api.s36_wda_cal import *
|
||||
# from get_real_status import *
|
||||
from datetime import datetime,timedelta
|
||||
@@ -8,7 +8,7 @@ import json
|
||||
import pytz
|
||||
import requests
|
||||
import time
|
||||
import project_info
|
||||
import app.services.project_info as project_info
|
||||
|
||||
url_path = 'http://10.101.15.16:9000/loong' # 内网
|
||||
# url_path = 'http://183.64.62.100:9057/loong' # 外网
|
||||
@@ -11,7 +11,7 @@ from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from tjnetwork import *
|
||||
from app.services.tjnetwork import *
|
||||
from matplotlib.lines import Line2D
|
||||
from sklearn.cluster import SpectralClustering
|
||||
import libpysal as ps
|
||||
@@ -19,7 +19,7 @@ from spopt.region import Skater
|
||||
from shapely.geometry import Point
|
||||
import geopandas as gpd
|
||||
from sklearn.metrics import pairwise_distances
|
||||
import project_info
|
||||
import app.services.project_info as project_info
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
@@ -11,8 +11,8 @@ from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from tjnetwork import *
|
||||
import project_info
|
||||
from app.services.tjnetwork import *
|
||||
import app.services.project_info as project_info
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
57
app/algorithms/data_cleaning.py
Normal file
57
app/algorithms/data_cleaning.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
|
||||
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
|
||||
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
|
||||
|
||||
|
||||
############################################################
|
||||
# 流量监测数据清洗 ***卡尔曼滤波法***
|
||||
############################################################
|
||||
# 2025/08/21 hxyan
|
||||
|
||||
|
||||
def flow_data_clean(input_csv_file: str) -> str:
|
||||
"""
|
||||
读取 input_csv_path 中的每列时间序列,使用一维 Kalman 滤波平滑并用预测值替换基于 3σ 检测出的异常点。
|
||||
保存输出为:<input_filename>_cleaned.xlsx(与输入同目录),并返回输出文件的绝对路径。如有同名文件存在,则覆盖。
|
||||
:param: input_csv_file: 输入的 CSV 文件明或路径
|
||||
:return: 输出文件的绝对路径
|
||||
"""
|
||||
|
||||
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
input_csv_path = os.path.join(script_dir, input_csv_file)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(input_csv_path):
|
||||
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
|
||||
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
|
||||
out_xlsx_path = flow_data_clean.clean_flow_data_kf(input_csv_path)
|
||||
print("清洗后的数据已保存到:", out_xlsx_path)
|
||||
|
||||
|
||||
############################################################
|
||||
# 压力监测数据清洗 ***kmean++法***
|
||||
############################################################
|
||||
# 2025/08/21 hxyan
|
||||
|
||||
|
||||
def pressure_data_clean(input_csv_file: str) -> str:
|
||||
"""
|
||||
读取 input_csv_path 中的每列时间序列,使用Kmean++清洗数据。
|
||||
保存输出为:<input_filename>_cleaned.xlsx(与输入同目录),并返回输出文件的绝对路径。如有同名文件存在,则覆盖。
|
||||
原始数据在 sheet 'raw_pressure_data',处理后数据在 sheet 'cleaned_pressusre_data'。
|
||||
:param input_csv_path: 输入的 CSV 文件路径
|
||||
:return: 输出文件的绝对路径
|
||||
"""
|
||||
|
||||
# 提供的 input_csv_path 绝对路径,以下为 默认脚本目录下同名 CSV 文件,构建绝对路径,可根据情况修改
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
input_csv_path = os.path.join(script_dir, input_csv_file)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(input_csv_path):
|
||||
raise FileNotFoundError(f"指定的文件不存在: {input_csv_path}")
|
||||
# 调用 Fdataclean.clean_flow_data_kf 函数进行数据清洗
|
||||
out_xlsx_path = pressure_data_clean.clean_pressure_data_km(input_csv_path)
|
||||
print("清洗后的数据已保存到:", out_xlsx_path)
|
||||
91
app/algorithms/sensors.py
Normal file
91
app/algorithms/sensors.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import psycopg
|
||||
|
||||
import app.algorithms.api_ex.kmeans_sensor as kmeans_sensor
|
||||
import app.algorithms.api_ex.sensitivity as sensitivity
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
from app.services.tjnetwork import dump_inp
|
||||
|
||||
|
||||
def pressure_sensor_placement_sensitivity(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
) -> None:
|
||||
"""
|
||||
基于改进灵敏度法进行压力监测点优化布置
|
||||
:param name: 数据库名称
|
||||
:param scheme_name: 监测优化布置方案名称
|
||||
:param sensor_number: 传感器数目
|
||||
:param min_diameter: 最小管径
|
||||
:param username: 用户名
|
||||
:return:
|
||||
"""
|
||||
sensor_location = sensitivity.get_ID(
|
||||
name=name, sensor_num=sensor_number, min_diameter=min_diameter
|
||||
)
|
||||
try:
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
sql = """
|
||||
INSERT INTO sensor_placement (scheme_name, sensor_number, min_diameter, username, sensor_location)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
|
||||
cur.execute(
|
||||
sql,
|
||||
(
|
||||
scheme_name,
|
||||
sensor_number,
|
||||
min_diameter,
|
||||
username,
|
||||
sensor_location,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
print("方案信息存储成功!")
|
||||
except Exception as e:
|
||||
print(f"存储方案信息时出错:{e}")
|
||||
|
||||
|
||||
# 2025/08/21
|
||||
# 基于kmeans聚类法进行压力监测点优化布置
|
||||
def pressure_sensor_placement_kmeans(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
) -> None:
|
||||
"""
|
||||
基于聚类法进行压力监测点优化布置
|
||||
:param name: 数据库名称(注意,此处数据库名称也是inp文件名称,inp文件与pg库名要一样)
|
||||
:param scheme_name: 监测优化布置方案名称
|
||||
:param sensor_number: 传感器数目
|
||||
:param min_diameter: 最小管径
|
||||
:param username: 用户名
|
||||
:return:
|
||||
"""
|
||||
# dump_inp
|
||||
inp_name = f"./db_inp/{name}.db.inp"
|
||||
dump_inp(name, inp_name, "2")
|
||||
sensor_location = kmeans_sensor.kmeans_sensor_placement(
|
||||
name=name, sensor_num=sensor_number, min_diameter=min_diameter
|
||||
)
|
||||
try:
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
sql = """
|
||||
INSERT INTO sensor_placement (scheme_name, sensor_number, min_diameter, username, sensor_location)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
|
||||
cur.execute(
|
||||
sql,
|
||||
(
|
||||
scheme_name,
|
||||
sensor_number,
|
||||
min_diameter,
|
||||
username,
|
||||
sensor_location,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
print("方案信息存储成功!")
|
||||
except Exception as e:
|
||||
print(f"存储方案信息时出错:{e}")
|
||||
745
app/algorithms/simulations.py
Normal file
745
app/algorithms/simulations.py
Normal file
@@ -0,0 +1,745 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from math import pi, sqrt
|
||||
|
||||
import pytz
|
||||
|
||||
import app.services.simulation as simulation
|
||||
from app.algorithms.api_ex.run_simulation import (
|
||||
run_simulation_ex,
|
||||
from_clock_to_seconds_2,
|
||||
)
|
||||
from app.native.api.project import copy_project
|
||||
from app.services.epanet.epanet import Output
|
||||
from app.services.scheme_management import store_scheme_info
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
|
||||
############################################################
|
||||
# burst analysis 01
|
||||
############################################################
|
||||
def convert_to_local_unit(proj: str, emitters: float) -> float:
|
||||
open_project(proj)
|
||||
proj_opt = get_option(proj)
|
||||
str_unit = proj_opt.get("UNITS")
|
||||
|
||||
if str_unit == "CMH":
|
||||
return emitters * 3.6
|
||||
elif str_unit == "LPS":
|
||||
return emitters
|
||||
elif str_unit == "CMS":
|
||||
return emitters / 1000.0
|
||||
elif str_unit == "MGD":
|
||||
return emitters * 0.0438126
|
||||
|
||||
# Unknown unit: log and return original value
|
||||
print(str_unit)
|
||||
return emitters
|
||||
|
||||
|
||||
def burst_analysis(
|
||||
name: str,
|
||||
modify_pattern_start_time: str,
|
||||
burst_ID: list | str = None,
|
||||
burst_size: list | float | int = None,
|
||||
modify_total_duration: int = 900,
|
||||
modify_fixed_pump_pattern: dict[str, list] = None,
|
||||
modify_variable_pump_pattern: dict[str, list] = None,
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
爆管模拟
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param burst_ID: 爆管管道的ID,选取的是管道,单独传入一个爆管管道,可以是str或list,传入多个爆管管道是用list
|
||||
:param burst_size: 爆管管道破裂的孔口面积,和burst_ID列表各位置的ID对应,以cm*cm计算
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
"burst_ID": burst_ID,
|
||||
"burst_size": burst_size,
|
||||
"modify_total_duration": modify_total_duration,
|
||||
"modify_fixed_pump_pattern": modify_fixed_pump_pattern,
|
||||
"modify_variable_pump_pattern": modify_variable_pump_pattern,
|
||||
"modify_valve_opening": modify_valve_opening,
|
||||
}
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"burst_Anal_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="manually_temporary",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
##step 1 set the emitter coefficient of end node of busrt pipe
|
||||
if isinstance(burst_ID, list):
|
||||
if (burst_size is not None) and (type(burst_size) is not list):
|
||||
return json.dumps("Type mismatch.")
|
||||
# 转化为列表形式
|
||||
elif isinstance(burst_ID, str):
|
||||
burst_ID = [burst_ID]
|
||||
if burst_size is not None:
|
||||
if isinstance(burst_size, float) or isinstance(burst_size, int):
|
||||
burst_size = [burst_size]
|
||||
else:
|
||||
return json.dumps("Type mismatch.")
|
||||
else:
|
||||
return json.dumps("Type mismatch.")
|
||||
if burst_size is None:
|
||||
burst_size = [-1] * len(burst_ID)
|
||||
elif len(burst_size) < len(burst_ID):
|
||||
burst_size += [-1] * (len(burst_ID) - len(burst_size))
|
||||
elif len(burst_size) > len(burst_ID):
|
||||
# burst_size = burst_size[:len(burst_ID)]
|
||||
return json.dumps("Length mismatch.")
|
||||
for burst_ID_, burst_size_ in zip(burst_ID, burst_size):
|
||||
pipe = get_pipe(new_name, burst_ID_)
|
||||
str_start_node = pipe["node1"]
|
||||
str_end_node = pipe["node2"]
|
||||
d_pipe = pipe["diameter"] / 1000.0
|
||||
if burst_size_ <= 0:
|
||||
burst_size_ = 3.14 * d_pipe * d_pipe / 4 / 8
|
||||
else:
|
||||
burst_size_ = burst_size_ / 10000
|
||||
emitter_coeff = (
|
||||
0.65 * burst_size_ * sqrt(19.6) * 1000
|
||||
) # 1/8开口面积作为coeff,单位 L/S
|
||||
emitter_coeff = convert_to_local_unit(new_name, emitter_coeff)
|
||||
emitter_node = ""
|
||||
if is_junction(new_name, str_end_node):
|
||||
emitter_node = str_end_node
|
||||
elif is_junction(new_name, str_start_node):
|
||||
emitter_node = str_start_node
|
||||
old_emitter = get_emitter(new_name, emitter_node)
|
||||
if old_emitter != None:
|
||||
old_emitter["coefficient"] = emitter_coeff # 爆管的emitter coefficient设置
|
||||
else:
|
||||
old_emitter = {"junction": emitter_node, "coefficient": emitter_coeff}
|
||||
new_emitter = ChangeSet()
|
||||
new_emitter.append(old_emitter)
|
||||
set_emitter(new_name, new_emitter)
|
||||
# step 2. run simulation
|
||||
# 涉及关阀计算,可能导致关阀后仍有流量,改为压力驱动PDA
|
||||
options = get_option(new_name)
|
||||
options["DEMAND MODEL"] = OPTION_DEMAND_MODEL_PDA
|
||||
options["REQUIRED PRESSURE"] = "10.0000"
|
||||
cs_options = ChangeSet()
|
||||
cs_options.append(options)
|
||||
set_option(new_name, cs_options)
|
||||
# valve_control = None
|
||||
# if modify_valve_opening is not None:
|
||||
# valve_control = {}
|
||||
# for valve in modify_valve_opening:
|
||||
# valve_control[valve] = {'status': 'CLOSED'}
|
||||
# result = run_simulation_ex(new_name,'realtime', modify_pattern_start_time,
|
||||
# end_datetime=modify_pattern_start_time,
|
||||
# modify_total_duration=modify_total_duration,
|
||||
# modify_pump_pattern=modify_pump_pattern,
|
||||
# valve_control=valve_control,
|
||||
# downloading_prohibition=True)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
|
||||
modify_variable_pump_pattern=modify_variable_pump_pattern,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_type="burst_analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 3. restore the base model status
|
||||
# execute_undo(name) #有疑惑
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# 存储方案信息到 PG 数据库
|
||||
store_scheme_info(
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
scheme_type="burst_analysis",
|
||||
username="admin",
|
||||
scheme_start_time=modify_pattern_start_time,
|
||||
scheme_detail=scheme_detail,
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
# valve closing analysis 02
|
||||
############################################################
|
||||
def valve_close_analysis(
|
||||
name: str,
|
||||
modify_pattern_start_time: str,
|
||||
modify_total_duration: int = 900,
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
关阀模拟
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"valve_close_Anal_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
# step 1. change the valves status to 'closed'
|
||||
# for valve in valves:
|
||||
# if not is_valve(new_name,valve):
|
||||
# result='ID:{}is not a valve type'.format(valve)
|
||||
# return result
|
||||
# cs=ChangeSet()
|
||||
# status=get_status(new_name,valve)
|
||||
# status['status']='CLOSED'
|
||||
# cs.append(status)
|
||||
# set_status(new_name,cs)
|
||||
# step 2. run simulation
|
||||
# 涉及关阀计算,可能导致关阀后仍有流量,改为压力驱动PDA
|
||||
options = get_option(new_name)
|
||||
options["DEMAND MODEL"] = OPTION_DEMAND_MODEL_PDA
|
||||
options["REQUIRED PRESSURE"] = "20.0000"
|
||||
cs_options = ChangeSet()
|
||||
cs_options.append(options)
|
||||
set_option(new_name, cs_options)
|
||||
# result = run_simulation_ex(new_name,'realtime', modify_pattern_start_time, modify_pattern_start_time, modify_total_duration,
|
||||
# downloading_prohibition=True)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_type="valve_close_Analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 3. restore the base model
|
||||
# for valve in valves:
|
||||
# execute_undo(name)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# return result
|
||||
|
||||
|
||||
############################################################
|
||||
# flushing analysis 03
|
||||
# Pipe_Flushing_Analysis(prj_name,date_time, Valve_id_list, Drainage_Node_Id, Flushing_flow[opt], Flushing_duration[opt])->out_file:string
|
||||
############################################################
|
||||
def flushing_analysis(
|
||||
name: str,
|
||||
modify_pattern_start_time: str,
|
||||
modify_total_duration: int = 900,
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
drainage_node_ID: str = None,
|
||||
flushing_flow: float = 0,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
管道冲洗模拟
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param drainage_node_ID: 冲洗排放口所在节点ID
|
||||
:param flushing_flow: 冲洗水量,传入参数单位为m3/h
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
"duration": modify_total_duration,
|
||||
"valve_opening": modify_valve_opening,
|
||||
"drainage_node_ID": drainage_node_ID,
|
||||
"flushing_flow": flushing_flow,
|
||||
}
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"flushing_Anal_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(name):
|
||||
# close_project(name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
if not is_junction(new_name, drainage_node_ID):
|
||||
return "Wrong Drainage node type"
|
||||
# step 1. change the valves status to 'closed'
|
||||
# for valve, valve_k in zip(valves, valves_k):
|
||||
# cs=ChangeSet()
|
||||
# status=get_status(new_name,valve)
|
||||
# # status['status']='CLOSED'
|
||||
# if valve_k == 0:
|
||||
# status['status'] = 'CLOSED'
|
||||
# elif valve_k < 1:
|
||||
# status['status'] = 'OPEN'
|
||||
# status['setting'] = 0.1036 * pow(valve_k, -3.105)
|
||||
# cs.append(status)
|
||||
# set_status(new_name,cs)
|
||||
options = get_option(new_name)
|
||||
units = options["UNITS"]
|
||||
# step 2. set the emitter coefficient of drainage node or add flush flow to the drainage node
|
||||
# 新建 pattern
|
||||
time_option = get_time(new_name)
|
||||
hydraulic_step = time_option["HYDRAULIC TIMESTEP"]
|
||||
secs = from_clock_to_seconds_2(hydraulic_step)
|
||||
cs_pattern = ChangeSet()
|
||||
pt = {}
|
||||
factors = []
|
||||
tmp_duration = modify_total_duration
|
||||
while tmp_duration > 0:
|
||||
factors.append(1.0)
|
||||
tmp_duration = tmp_duration - secs
|
||||
pt["id"] = "flushing_pt"
|
||||
pt["factors"] = factors
|
||||
cs_pattern.append(pt)
|
||||
add_pattern(new_name, cs_pattern)
|
||||
# 为 emitter_demand 添加新的 pattern
|
||||
emitter_demand = get_demand(new_name, drainage_node_ID)
|
||||
cs = ChangeSet()
|
||||
if flushing_flow > 0:
|
||||
if units == "LPS":
|
||||
emitter_demand["demands"].append(
|
||||
{
|
||||
"demand": flushing_flow / 3.6,
|
||||
"pattern": "flushing_pt",
|
||||
"category": None,
|
||||
}
|
||||
)
|
||||
elif units == "CMH":
|
||||
emitter_demand["demands"].append(
|
||||
{"demand": flushing_flow, "pattern": "flushing_pt", "category": None}
|
||||
)
|
||||
cs.append(emitter_demand)
|
||||
set_demand(new_name, cs)
|
||||
else:
|
||||
pipes = get_node_links(new_name, drainage_node_ID)
|
||||
flush_diameter = 50
|
||||
for pipe in pipes:
|
||||
d = get_pipe(new_name, pipe)["diameter"]
|
||||
if flush_diameter < d:
|
||||
flush_diameter = d
|
||||
flush_diameter /= 1000
|
||||
emitter_coeff = (
|
||||
0.65 * 3.14 * (flush_diameter * flush_diameter / 4) * sqrt(19.6) * 1000
|
||||
) # 全开口面积作为coeff
|
||||
|
||||
old_emitter = get_emitter(new_name, drainage_node_ID)
|
||||
if old_emitter != None:
|
||||
old_emitter["coefficient"] = emitter_coeff # 爆管的emitter coefficient设置
|
||||
else:
|
||||
old_emitter = {"junction": drainage_node_ID, "coefficient": emitter_coeff}
|
||||
new_emitter = ChangeSet()
|
||||
new_emitter.append(old_emitter)
|
||||
set_emitter(new_name, new_emitter)
|
||||
# step 3. run simulation
|
||||
# 涉及关阀计算,可能导致关阀后仍有流量,改为压力驱动PDA
|
||||
options = get_option(new_name)
|
||||
options["DEMAND MODEL"] = OPTION_DEMAND_MODEL_PDA
|
||||
options["REQUIRED PRESSURE"] = "20.0000"
|
||||
cs_options = ChangeSet()
|
||||
cs_options.append(options)
|
||||
set_option(new_name, cs_options)
|
||||
# result = run_simulation_ex(new_name,'realtime', modify_pattern_start_time, modify_pattern_start_time, modify_total_duration,
|
||||
# downloading_prohibition=True)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_valve_opening=modify_valve_opening,
|
||||
scheme_type="flushing_analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
# step 4. restore the base model
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# return result
|
||||
# 存储方案信息到 PG 数据库
|
||||
store_scheme_info(
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
scheme_type="flushing_analysis",
|
||||
username="admin",
|
||||
scheme_start_time=modify_pattern_start_time,
|
||||
scheme_detail=scheme_detail,
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
# Contaminant simulation 04
|
||||
#
|
||||
############################################################
|
||||
def contaminant_simulation(
|
||||
name: str,
|
||||
modify_pattern_start_time: str, # 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
modify_total_duration: int, # 模拟总历时,秒
|
||||
source: str, # 污染源节点ID
|
||||
concentration: float, # 污染源浓度,单位mg/L
|
||||
scheme_name: str = None,
|
||||
source_pattern: str = None, # 污染源时间变化模式名称
|
||||
) -> None:
|
||||
"""
|
||||
污染模拟
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param source: 污染源所在的节点ID
|
||||
:param concentration: 污染源位置处的浓度,单位mg/L。默认的污染模拟setting为SOURCE_TYPE_CONCEN(改为SOURCE_TYPE_SETPOINT)
|
||||
:param source_pattern: 污染源的时间变化模式,若不传入则默认以恒定浓度持续模拟,时间长度等于duration;
|
||||
若传入,则格式为{1.0,0.5,1.1}等系数列表pattern_step模拟等于模型的hydraulic time step
|
||||
:param scheme_name: 方案名称
|
||||
:return:
|
||||
"""
|
||||
scheme_detail: dict = {
|
||||
"source": source,
|
||||
"concentration": concentration,
|
||||
"duration": modify_total_duration,
|
||||
"pattern": source_pattern,
|
||||
}
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"contaminant_Sim_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(name):
|
||||
# close_project(name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
dic_time = get_time(new_name)
|
||||
dic_time["QUALITY TIMESTEP"] = "0:05:00"
|
||||
cs = ChangeSet()
|
||||
cs.operations.append(dic_time)
|
||||
set_time(new_name, cs) # set QUALITY TIMESTEP
|
||||
time_option = get_time(new_name)
|
||||
hydraulic_step = time_option["HYDRAULIC TIMESTEP"]
|
||||
secs = from_clock_to_seconds_2(hydraulic_step)
|
||||
operation_step = 0
|
||||
# step 1. set duration
|
||||
if modify_total_duration == None:
|
||||
modify_total_duration = secs
|
||||
# step 2. set pattern
|
||||
if source_pattern != None:
|
||||
pt = get_pattern(new_name, source_pattern)
|
||||
if len(pt) == 0:
|
||||
str_response = str("cant find source_pattern")
|
||||
return str_response
|
||||
else:
|
||||
cs_pattern = ChangeSet()
|
||||
pt = {}
|
||||
factors = []
|
||||
tmp_duration = modify_total_duration
|
||||
while tmp_duration > 0:
|
||||
factors.append(1.0)
|
||||
tmp_duration = tmp_duration - secs
|
||||
pt["id"] = "contam_pt"
|
||||
pt["factors"] = factors
|
||||
cs_pattern.append(pt)
|
||||
add_pattern(new_name, cs_pattern)
|
||||
operation_step += 1
|
||||
# step 3. set source/initial quality
|
||||
# source quality
|
||||
cs_source = ChangeSet()
|
||||
source_schema = {
|
||||
"node": source,
|
||||
"s_type": SOURCE_TYPE_SETPOINT,
|
||||
"strength": concentration,
|
||||
"pattern": pt["id"],
|
||||
}
|
||||
cs_source.append(source_schema)
|
||||
source_node = get_source(new_name, source)
|
||||
if len(source_node) == 0:
|
||||
add_source(new_name, cs_source)
|
||||
else:
|
||||
set_source(new_name, cs_source)
|
||||
dict_demand = get_demand(new_name, source)
|
||||
for demands in dict_demand["demands"]:
|
||||
dict_demand["demands"][dict_demand["demands"].index(demands)]["demand"] = -1
|
||||
dict_demand["demands"][dict_demand["demands"].index(demands)]["pattern"] = None
|
||||
cs = ChangeSet()
|
||||
cs.append(dict_demand)
|
||||
set_demand(new_name, cs) # set inflow node
|
||||
# # initial quality
|
||||
# dict_quality = get_quality(new_name, source)
|
||||
# dict_quality['quality'] = concentration
|
||||
# cs = ChangeSet()
|
||||
# cs.append(dict_quality)
|
||||
# set_quality(new_name, cs)
|
||||
operation_step += 1
|
||||
# step 4 set option of quality to chemical
|
||||
opt = get_option(new_name)
|
||||
opt["QUALITY"] = OPTION_QUALITY_CHEMICAL
|
||||
cs_option = ChangeSet()
|
||||
cs_option.append(opt)
|
||||
set_option(new_name, cs_option)
|
||||
operation_step += 1
|
||||
# step 5. run simulation
|
||||
# result = run_simulation_ex(new_name,'realtime', modify_pattern_start_time, modify_pattern_start_time, modify_total_duration,
|
||||
# downloading_prohibition=True)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
scheme_type="contaminant_analysis",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
|
||||
# for i in range(1,operation_step):
|
||||
# execute_undo(name)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# 存储方案信息到 PG 数据库
|
||||
store_scheme_info(
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
scheme_type="contaminant_analysis",
|
||||
username="admin",
|
||||
scheme_start_time=modify_pattern_start_time,
|
||||
scheme_detail=scheme_detail,
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
# age analysis 05 ***水龄模拟目前还没和实时模拟打通,不确定是否需要,先不要使用***
|
||||
############################################################
|
||||
|
||||
|
||||
def age_analysis(
|
||||
name: str, modify_pattern_start_time: str, modify_total_duration: int = 900
|
||||
) -> None:
|
||||
"""
|
||||
水龄模拟
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"age_Anal_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(name):
|
||||
# close_project(name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
# step 1. run simulation
|
||||
|
||||
result = run_simulation_ex(
|
||||
new_name,
|
||||
"realtime",
|
||||
modify_pattern_start_time,
|
||||
modify_total_duration,
|
||||
downloading_prohibition=True,
|
||||
)
|
||||
# step 2. restore the base model status
|
||||
# execute_undo(name) #有疑惑
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
output = Output("./temp/{}.db.out".format(new_name))
|
||||
# element_name = output.element_name()
|
||||
# node_name = element_name['nodes']
|
||||
# link_name = element_name['links']
|
||||
nodes_age = []
|
||||
node_result = output.node_results()
|
||||
for node in node_result:
|
||||
nodes_age.append(node["result"][-1]["quality"])
|
||||
links_age = []
|
||||
link_result = output.link_results()
|
||||
for link in link_result:
|
||||
links_age.append(link["result"][-1]["quality"])
|
||||
age_result = {"nodes": nodes_age, "links": links_age}
|
||||
# age_result = {'nodes': nodes_age, 'links': links_age, 'nodeIDs': node_name, 'linkIDs': link_name}
|
||||
return json.dumps(age_result)
|
||||
|
||||
|
||||
############################################################
|
||||
# pressure regulation 06
|
||||
############################################################
|
||||
|
||||
|
||||
def pressure_regulation(
|
||||
name: str,
|
||||
modify_pattern_start_time: str,
|
||||
modify_total_duration: int = 900,
|
||||
modify_tank_initial_level: dict[str, float] = None,
|
||||
modify_fixed_pump_pattern: dict[str, list] = None,
|
||||
modify_variable_pump_pattern: dict[str, list] = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
区域调压模拟,用来模拟未来15分钟内,开关水泵对区域压力的影响
|
||||
:param name: 模型名称,数据库中对应的名字
|
||||
:param modify_pattern_start_time: 模拟开始时间,格式为'2024-11-25T09:00:00+08:00'
|
||||
:param modify_total_duration: 模拟总历时,秒
|
||||
:param modify_tank_initial_level: dict中包含多个水塔,str为水塔的id,float为修改后的initial_level
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param scheme_name: 模拟方案名称
|
||||
:return:
|
||||
"""
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"pressure_regulation_{name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(name):
|
||||
# close_project(name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
# 全部关泵后,压力计算不合理,改为压力驱动PDA
|
||||
options = get_option(new_name)
|
||||
options["DEMAND MODEL"] = OPTION_DEMAND_MODEL_PDA
|
||||
options["REQUIRED PRESSURE"] = "15.0000"
|
||||
cs_options = ChangeSet()
|
||||
cs_options.append(options)
|
||||
set_option(new_name, cs_options)
|
||||
# result = run_simulation_ex(name=new_name,
|
||||
# simulation_type='realtime',
|
||||
# start_datetime=start_datetime,
|
||||
# duration=900,
|
||||
# pump_control=pump_control,
|
||||
# tank_initial_level_control=tank_initial_level_control,
|
||||
# downloading_prohibition=True)
|
||||
simulation.run_simulation(
|
||||
name=new_name,
|
||||
simulation_type="extended",
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
modify_total_duration=modify_total_duration,
|
||||
modify_tank_initial_level=modify_tank_initial_level,
|
||||
modify_fixed_pump_pattern=modify_fixed_pump_pattern,
|
||||
modify_variable_pump_pattern=modify_variable_pump_pattern,
|
||||
scheme_type="pressure_regulation",
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# return result
|
||||
165
app/algorithms/valve_isolation.py
Normal file
165
app/algorithms/valve_isolation.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from collections import defaultdict, deque
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from app.services.tjnetwork import (
|
||||
get_network_link_nodes,
|
||||
is_node,
|
||||
get_link_properties,
|
||||
)
|
||||
|
||||
|
||||
VALVE_LINK_TYPE = "valve"
|
||||
|
||||
|
||||
def _parse_link_entry(link_entry: str) -> tuple[str, str, str, str]:
|
||||
parts = link_entry.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
raise ValueError(f"Invalid link entry format: {link_entry}")
|
||||
return parts[0], parts[1], parts[2], parts[3]
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def _get_network_topology(network: str):
|
||||
"""
|
||||
解析并缓存网络拓扑,大幅减少重复的 API 调用和字符串解析开销。
|
||||
返回:
|
||||
- pipe_adj: 永久连通的管道/泵邻接表 (dict[str, set])
|
||||
- all_valves: 所有阀门字典 {id: (n1, n2)}
|
||||
- link_lookup: 链路快速查表 {id: (n1, n2, type)} 用于快速定位事故点
|
||||
- node_set: 所有已知节点集合
|
||||
"""
|
||||
pipe_adj = defaultdict(set)
|
||||
all_valves = {}
|
||||
link_lookup = {}
|
||||
node_set = set()
|
||||
|
||||
# 此处假设 get_network_link_nodes 获取全网数据
|
||||
for link_entry in get_network_link_nodes(network):
|
||||
link_id, link_type, node1, node2 = _parse_link_entry(link_entry)
|
||||
link_type_name = str(link_type).lower()
|
||||
|
||||
link_lookup[link_id] = (node1, node2, link_type_name)
|
||||
node_set.add(node1)
|
||||
node_set.add(node2)
|
||||
|
||||
if link_type_name == VALVE_LINK_TYPE:
|
||||
all_valves[link_id] = (node1, node2)
|
||||
else:
|
||||
# 只有非阀门(管道/泵)才进入永久连通图
|
||||
pipe_adj[node1].add(node2)
|
||||
pipe_adj[node2].add(node1)
|
||||
|
||||
return pipe_adj, all_valves, link_lookup, node_set
|
||||
|
||||
|
||||
def valve_isolation_analysis(
|
||||
network: str, accident_elements: str | list[str], disabled_valves: list[str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
关阀搜索/分析:基于拓扑结构确定事故隔离所需关阀。
|
||||
:param network: 模型名称
|
||||
:param accident_elements: 事故点(节点或管道/泵/阀门ID),可以是单个ID字符串或ID列表
|
||||
:param disabled_valves: 故障/无法关闭的阀门ID列表
|
||||
:return: dict,包含受影响节点、必须关闭阀门、可选阀门等信息
|
||||
"""
|
||||
if disabled_valves is None:
|
||||
disabled_valves_set = set()
|
||||
else:
|
||||
disabled_valves_set = set(disabled_valves)
|
||||
|
||||
if isinstance(accident_elements, str):
|
||||
target_elements = [accident_elements]
|
||||
else:
|
||||
target_elements = accident_elements
|
||||
|
||||
# 1. 获取缓存拓扑 (极快,无 IO)
|
||||
pipe_adj, all_valves, link_lookup, node_set = _get_network_topology(network)
|
||||
|
||||
# 2. 确定起点,优先查表避免 API 调用
|
||||
start_nodes = set()
|
||||
for element in target_elements:
|
||||
if element in node_set:
|
||||
start_nodes.add(element)
|
||||
elif element in link_lookup:
|
||||
n1, n2, _ = link_lookup[element]
|
||||
start_nodes.add(n1)
|
||||
start_nodes.add(n2)
|
||||
else:
|
||||
# 仅当缓存中没找到时(极少见),才回退到慢速 API
|
||||
if is_node(network, element):
|
||||
start_nodes.add(element)
|
||||
else:
|
||||
props = get_link_properties(network, element)
|
||||
n1, n2 = props.get("node1"), props.get("node2")
|
||||
if n1 and n2:
|
||||
start_nodes.add(n1)
|
||||
start_nodes.add(n2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Accident element {element} invalid or missing endpoints"
|
||||
)
|
||||
|
||||
# 3. 处理故障阀门 (构建临时增量图)
|
||||
# 我们不修改 cached pipe_adj,而是建立一个 extra_adj
|
||||
extra_adj = defaultdict(list)
|
||||
boundary_valves = {} # 当前有效的边界阀门
|
||||
|
||||
for vid, (n1, n2) in all_valves.items():
|
||||
if vid in disabled_valves_set:
|
||||
# 故障阀门:视为连通管道
|
||||
extra_adj[n1].append(n2)
|
||||
extra_adj[n2].append(n1)
|
||||
else:
|
||||
# 正常阀门:视为潜在边界
|
||||
boundary_valves[vid] = (n1, n2)
|
||||
|
||||
# 4. BFS 搜索 (叠加 pipe_adj 和 extra_adj)
|
||||
affected_nodes: set[str] = set()
|
||||
queue = deque(start_nodes)
|
||||
while queue:
|
||||
node = queue.popleft()
|
||||
if node in affected_nodes:
|
||||
continue
|
||||
affected_nodes.add(node)
|
||||
|
||||
# 遍历永久管道邻居
|
||||
if node in pipe_adj:
|
||||
for neighbor in pipe_adj[node]:
|
||||
if neighbor not in affected_nodes:
|
||||
queue.append(neighbor)
|
||||
|
||||
# 遍历故障阀门带来的额外邻居
|
||||
if node in extra_adj:
|
||||
for neighbor in extra_adj[node]:
|
||||
if neighbor not in affected_nodes:
|
||||
queue.append(neighbor)
|
||||
|
||||
# 5. 结果聚合
|
||||
must_close_valves: list[str] = []
|
||||
optional_valves: list[str] = []
|
||||
|
||||
for valve_id, (n1, n2) in boundary_valves.items():
|
||||
in_n1 = n1 in affected_nodes
|
||||
in_n2 = n2 in affected_nodes
|
||||
if in_n1 and in_n2:
|
||||
optional_valves.append(valve_id)
|
||||
elif in_n1 or in_n2:
|
||||
must_close_valves.append(valve_id)
|
||||
|
||||
must_close_valves.sort()
|
||||
optional_valves.sort()
|
||||
|
||||
result = {
|
||||
"accident_elements": target_elements,
|
||||
"disabled_valves": disabled_valves,
|
||||
"affected_nodes": sorted(affected_nodes),
|
||||
"must_close_valves": must_close_valves,
|
||||
"optional_valves": optional_valves,
|
||||
"isolatable": len(must_close_valves) > 0,
|
||||
}
|
||||
|
||||
if len(target_elements) == 1:
|
||||
result["accident_element"] = target_elements[0]
|
||||
|
||||
return result
|
||||
104
app/api/v1/endpoints/audit.py
Normal file
104
app/api/v1/endpoints/audit.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
审计日志 API 接口
|
||||
|
||||
仅管理员可访问
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
async def get_audit_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> AuditRepository:
|
||||
"""获取审计日志仓储"""
|
||||
return AuditRepository(session)
|
||||
|
||||
@router.get("/logs", response_model=List[AuditLogResponse])
|
||||
async def get_audit_logs(
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="限制记录数"),
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志(仅管理员)
|
||||
|
||||
支持按用户、时间、操作类型等条件过滤
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
return logs
|
||||
|
||||
@router.get("/logs/count")
|
||||
async def get_audit_logs_count(
|
||||
user_id: Optional[UUID] = Query(None, description="按用户ID过滤"),
|
||||
project_id: Optional[UUID] = Query(None, description="按项目ID过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
current_user=Depends(get_current_metadata_admin),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> dict:
|
||||
"""
|
||||
获取审计日志总数(仅管理员)
|
||||
"""
|
||||
count = await audit_repo.get_log_count(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time
|
||||
)
|
||||
return {"count": count}
|
||||
|
||||
@router.get("/logs/my", response_model=List[AuditLogResponse])
|
||||
async def get_my_audit_logs(
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
start_time: Optional[datetime] = Query(None, description="开始时间"),
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
audit_repo: AuditRepository = Depends(get_audit_repository),
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询当前用户的审计日志
|
||||
|
||||
普通用户只能查看自己的操作记录
|
||||
"""
|
||||
logs = await audit_repo.get_logs(
|
||||
user_id=current_user.id,
|
||||
action=action,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
return logs
|
||||
@@ -0,0 +1,186 @@
|
||||
from typing import Annotated
|
||||
from datetime import timedelta
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from app.core.config import settings
|
||||
from app.core.security import create_access_token, create_refresh_token, verify_password
|
||||
from app.domain.schemas.user import UserCreate, UserResponse, UserLogin, Token
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.auth.dependencies import get_user_repository, get_current_active_user
|
||||
from app.domain.schemas.user import UserInDB
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
用户注册
|
||||
|
||||
创建新用户账号
|
||||
"""
|
||||
# 检查用户名和邮箱是否已存在
|
||||
if await user_repo.user_exists(username=user_data.username):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Username already registered"
|
||||
)
|
||||
|
||||
if await user_repo.user_exists(email=user_data.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered"
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
try:
|
||||
user = await user_repo.create_user(user_data)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create user"
|
||||
)
|
||||
return UserResponse.model_validate(user)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during user registration: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed"
|
||||
)
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> Token:
|
||||
"""
|
||||
用户登录(OAuth2 标准格式)
|
||||
|
||||
返回 JWT Access Token 和 Refresh Token
|
||||
"""
|
||||
# 验证用户(支持用户名或邮箱登录)
|
||||
user = await user_repo.get_user_by_username(form_data.username)
|
||||
if not user:
|
||||
# 尝试用邮箱登录
|
||||
user = await user_repo.get_user_by_email(form_data.username)
|
||||
|
||||
if not user or not verify_password(form_data.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user account"
|
||||
)
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(subject=user.username)
|
||||
refresh_token = create_refresh_token(subject=user.username)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
)
|
||||
|
||||
@router.post("/login/simple", response_model=Token)
|
||||
async def login_simple(
|
||||
username: str,
|
||||
password: str,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> Token:
|
||||
"""
|
||||
简化版登录接口(保持向后兼容)
|
||||
|
||||
直接使用 username 和 password 参数
|
||||
"""
|
||||
# 验证用户
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
if not user:
|
||||
user = await user_repo.get_user_by_email(username)
|
||||
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user account"
|
||||
)
|
||||
|
||||
# 生成 Token
|
||||
access_token = create_access_token(subject=user.username)
|
||||
refresh_token = create_refresh_token(subject=user.username)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
)
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user_info(
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
"""
|
||||
return UserResponse.model_validate(current_user)
|
||||
|
||||
@router.post("/refresh", response_model=Token)
|
||||
async def refresh_token(
|
||||
refresh_token: str,
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> Token:
|
||||
"""
|
||||
刷新 Access Token
|
||||
|
||||
使用 Refresh Token 获取新的 Access Token
|
||||
"""
|
||||
from jose import jwt, JWTError
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate refresh token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("type")
|
||||
|
||||
if username is None or token_type != "refresh":
|
||||
raise credentials_exception
|
||||
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
# 验证用户仍然存在且激活
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
if not user or not user.is_active:
|
||||
raise credentials_exception
|
||||
|
||||
# 生成新的 Access Token
|
||||
new_access_token = create_access_token(subject=user.username)
|
||||
|
||||
return Token(
|
||||
access_token=new_access_token,
|
||||
refresh_token=refresh_token, # 保持原 refresh token
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
)
|
||||
|
||||
37
app/api/v1/endpoints/cache.py
Normal file
37
app/api/v1/endpoints/cache.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from fastapi import APIRouter
|
||||
from app.infra.cache.redis_client import redis_client
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/clearrediskey/")
|
||||
async def fastapi_clear_redis_key(key: str):
|
||||
redis_client.delete(key)
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/clearrediskeys/")
|
||||
async def fastapi_clear_redis_keys(keys: str):
|
||||
# delete keys contains the key
|
||||
matched_keys = redis_client.keys(f"*{keys}*")
|
||||
if matched_keys:
|
||||
redis_client.delete(*matched_keys)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@router.post("/clearallredis/")
|
||||
async def fastapi_clear_all_redis():
|
||||
redis_client.flushdb()
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/queryredis/")
|
||||
async def fastapi_query_redis():
|
||||
# Helper to decode bytes to str for JSON response if needed,
|
||||
# but original just returned keys (which might be bytes in redis-py unless decode_responses=True)
|
||||
# create_redis_client usually sets decode_responses=False by default.
|
||||
# We will assume user handles bytes or we should decode.
|
||||
# Original just returned redis_client.keys("*")
|
||||
keys = redis_client.keys("*")
|
||||
# Clean output for API
|
||||
return [k.decode('utf-8') if isinstance(k, bytes) else k for k in keys]
|
||||
31
app/api/v1/endpoints/components/controls.py
Normal file
31
app/api/v1/endpoints/components/controls.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcontrolschema/")
|
||||
async def fastapi_get_control_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_control_schema(network)
|
||||
|
||||
@router.get("/getcontrolproperties/")
|
||||
async def fastapi_get_control_properties(network: str) -> dict[str, Any]:
|
||||
return get_control(network)
|
||||
|
||||
@router.post("/setcontrolproperties/", response_model=None)
|
||||
async def fastapi_set_control_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_control(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getruleschema/")
|
||||
async def fastapi_get_rule_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_rule_schema(network)
|
||||
|
||||
@router.get("/getruleproperties/")
|
||||
async def fastapi_get_rule_properties(network: str) -> dict[str, Any]:
|
||||
return get_rule(network)
|
||||
|
||||
@router.post("/setruleproperties/", response_model=None)
|
||||
async def fastapi_set_rule_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_rule(network, ChangeSet(props))
|
||||
42
app/api/v1/endpoints/components/curves.py
Normal file
42
app/api/v1/endpoints/components/curves.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcurveschema")
|
||||
async def fastapi_get_curve_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_curve_schema(network)
|
||||
|
||||
@router.post("/addcurve/", response_model=None)
|
||||
async def fastapi_add_curve(network: str, curve: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {
|
||||
"id": curve,
|
||||
} | props
|
||||
return add_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletecurve/", response_model=None)
|
||||
async def fastapi_delete_curve(network: str, curve: str) -> ChangeSet:
|
||||
ps = {"id": curve}
|
||||
return delete_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getcurveproperties/")
|
||||
async def fastapi_get_curve_properties(network: str, curve: str) -> dict[str, Any]:
|
||||
return get_curve(network, curve)
|
||||
|
||||
@router.post("/setcurveproperties/", response_model=None)
|
||||
async def fastapi_set_curve_properties(
|
||||
network: str, curve: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": curve} | props
|
||||
return set_curve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getcurves/")
|
||||
async def fastapi_get_curves(network: str) -> list[str]:
|
||||
return get_curves(network)
|
||||
|
||||
@router.get("/iscurve/")
|
||||
async def fastapi_is_curve(network: str, curve: str) -> bool:
|
||||
return is_curve(network, curve)
|
||||
60
app/api/v1/endpoints/components/options.py
Normal file
60
app/api/v1/endpoints/components/options.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/gettimeschema")
|
||||
async def fastapi_get_time_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_time_schema(network)
|
||||
|
||||
@router.get("/gettimeproperties/")
|
||||
async def fastapi_get_time_properties(network: str) -> dict[str, Any]:
|
||||
return get_time(network)
|
||||
|
||||
@router.post("/settimeproperties/", response_model=None)
|
||||
async def fastapi_set_time_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_time(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getenergyschema/")
|
||||
async def fastapi_get_energy_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_energy_schema(network)
|
||||
|
||||
@router.get("/getenergyproperties/")
|
||||
async def fastapi_get_energy_properties(network: str) -> dict[str, Any]:
|
||||
return get_energy(network)
|
||||
|
||||
@router.post("/setenergyproperties/", response_model=None)
|
||||
async def fastapi_set_energy_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_energy(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getpumpenergyschema/")
|
||||
async def fastapi_get_pump_energy_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_pump_energy_schema(network)
|
||||
|
||||
@router.get("/getpumpenergyproperties//")
|
||||
async def fastapi_get_pump_energy_proeprties(network: str, pump: str) -> dict[str, Any]:
|
||||
return get_pump_energy(network, pump)
|
||||
|
||||
@router.get("/setpumpenergyproperties//", response_model=None)
|
||||
async def fastapi_set_pump_energy_properties(
|
||||
network: str, pump: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": pump} | props
|
||||
return set_pump_energy(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getoptionschema/")
|
||||
async def fastapi_get_option_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_option_v3_schema(network)
|
||||
|
||||
@router.get("/getoptionproperties/")
|
||||
async def fastapi_get_option_properties(network: str) -> dict[str, Any]:
|
||||
return get_option_v3(network)
|
||||
|
||||
@router.post("/setoptionproperties/", response_model=None)
|
||||
async def fastapi_set_option_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_option_v3(network, ChangeSet(props))
|
||||
42
app/api/v1/endpoints/components/patterns.py
Normal file
42
app/api/v1/endpoints/components/patterns.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpatternschema")
|
||||
async def fastapi_get_pattern_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_pattern_schema(network)
|
||||
|
||||
@router.post("/addpattern/", response_model=None)
|
||||
async def fastapi_add_pattern(network: str, pattern: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {
|
||||
"id": pattern,
|
||||
} | props
|
||||
return add_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepattern/", response_model=None)
|
||||
async def fastapi_delete_pattern(network: str, pattern: str) -> ChangeSet:
|
||||
ps = {"id": pattern}
|
||||
return delete_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpatternproperties/")
|
||||
async def fastapi_get_pattern_properties(network: str, pattern: str) -> dict[str, Any]:
|
||||
return get_pattern(network, pattern)
|
||||
|
||||
@router.post("/setpatternproperties/", response_model=None)
|
||||
async def fastapi_set_pattern_properties(
|
||||
network: str, pattern: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": pattern} | props
|
||||
return set_pattern(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/ispattern/")
|
||||
async def fastapi_is_pattern(network: str, pattern: str) -> bool:
|
||||
return is_pattern(network, pattern)
|
||||
|
||||
@router.get("/getpatterns/")
|
||||
async def fastapi_get_patterns(network: str) -> list[str]:
|
||||
return get_patterns(network)
|
||||
119
app/api/v1/endpoints/components/quality.py
Normal file
119
app/api/v1/endpoints/components/quality.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getqualityschema/")
|
||||
async def fastapi_get_quality_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_quality_schema(network)
|
||||
|
||||
@router.get("/getqualityproperties/")
|
||||
async def fastapi_get_quality_properties(network: str, node: str) -> dict[str, Any]:
|
||||
return get_quality(network, node)
|
||||
|
||||
@router.post("/setqualityproperties/", response_model=None)
|
||||
async def fastapi_set_quality_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_quality(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getemitterschema")
|
||||
async def fastapi_get_emitter_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_emitter_schema(network)
|
||||
|
||||
@router.get("/getemitterproperties/")
|
||||
async def fastapi_get_emitter_properties(network: str, junction: str) -> dict[str, Any]:
|
||||
return get_emitter(network, junction)
|
||||
|
||||
@router.post("/setemitterproperties/", response_model=None)
|
||||
async def fastapi_set_emitter_properties(
|
||||
network: str, junction: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"junction": junction} | props
|
||||
return set_emitter(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getsourcechema/")
|
||||
async def fastapi_get_source_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_source_schema(network)
|
||||
|
||||
@router.get("/getsource/")
|
||||
async def fastapi_get_source(network: str, node: str) -> dict[str, Any]:
|
||||
return get_source(network, node)
|
||||
|
||||
@router.post("/setsource/", response_model=None)
|
||||
async def fastapi_set_source(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_source(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addsource/", response_model=None)
|
||||
async def fastapi_add_source(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_source(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletesource/", response_model=None)
|
||||
async def fastapi_delete_source(network: str, node: str) -> ChangeSet:
|
||||
props = {"node": node}
|
||||
return delete_source(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getreactionschema/")
|
||||
async def fastapi_get_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_reaction_schema(network)
|
||||
|
||||
@router.get("/getreaction/")
|
||||
async def fastapi_get_reaction(network: str) -> dict[str, Any]:
|
||||
return get_reaction(network)
|
||||
|
||||
@router.post("/setreaction/", response_model=None)
|
||||
async def fastapi_set_reaction(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getpipereactionschema/")
|
||||
async def fastapi_get_pipe_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_pipe_reaction_schema(network)
|
||||
|
||||
@router.get("/getpipereaction/")
|
||||
async def fastapi_get_pipe_reaction(network: str, pipe: str) -> dict[str, Any]:
|
||||
return get_pipe_reaction(network, pipe)
|
||||
|
||||
@router.post("/setpipereaction/", response_model=None)
|
||||
async def fastapi_set_pipe_reaction(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_pipe_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/gettankreactionschema/")
|
||||
async def fastapi_get_tank_reaction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_tank_reaction_schema(network)
|
||||
|
||||
@router.get("/gettankreaction/")
|
||||
async def fastapi_get_tank_reaction(network: str, tank: str) -> dict[str, Any]:
|
||||
return get_tank_reaction(network, tank)
|
||||
|
||||
@router.post("/settankreaction/", response_model=None)
|
||||
async def fastapi_set_tank_reaction(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_tank_reaction(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getmixingschema/")
|
||||
async def fastapi_get_mixing_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_mixing_schema(network)
|
||||
|
||||
@router.get("/getmixing/")
|
||||
async def fastapi_get_mixing(network: str, tank: str) -> dict[str, Any]:
|
||||
return get_mixing(network, tank)
|
||||
|
||||
@router.post("/setmixing/", response_model=None)
|
||||
async def fastapi_set_mixing(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return api.set_mixing(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addmixing/", response_model=None)
|
||||
async def fastapi_add_mixing(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_mixing(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletemixing/", response_model=None)
|
||||
async def fastapi_delete_mixing(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_mixing(network, ChangeSet(props))
|
||||
76
app/api/v1/endpoints/components/visuals.py
Normal file
76
app/api/v1/endpoints/components/visuals.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from fastapi.responses import PlainTextResponse
|
||||
import json
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getvertexschema/")
|
||||
async def fastapi_get_vertex_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_vertex_schema(network)
|
||||
|
||||
@router.get("/getvertexproperties/")
|
||||
async def fastapi_get_vertex_properties(network: str, link: str) -> dict[str, Any]:
|
||||
return get_vertex(network, link)
|
||||
|
||||
@router.post("/setvertexproperties/", response_model=None)
|
||||
async def fastapi_set_vertex_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addvertex/", response_model=None)
|
||||
async def fastapi_add_vertex(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletevertex/", response_model=None)
|
||||
async def fastapi_delete_vertex(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_vertex(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallvertexlinks/", response_class=PlainTextResponse)
|
||||
async def fastapi_get_all_vertex_links(network: str) -> list[str]:
|
||||
return json.dumps(get_all_vertex_links(network))
|
||||
|
||||
@router.get("/getallvertices/", response_class=PlainTextResponse)
|
||||
async def fastapi_get_all_vertices(network: str) -> list[dict[str, Any]]:
|
||||
return json.dumps(get_all_vertices(network))
|
||||
|
||||
@router.get("/getlabelschema/")
|
||||
async def fastapi_get_label_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_label_schema(network)
|
||||
|
||||
@router.get("/getlabelproperties/")
|
||||
async def fastapi_get_label_properties(
|
||||
network: str, x: float, y: float
|
||||
) -> dict[str, Any]:
|
||||
return get_label(network, x, y)
|
||||
|
||||
@router.post("/setlabelproperties/", response_model=None)
|
||||
async def fastapi_set_label_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_label(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addlabel/", response_model=None)
|
||||
async def fastapi_add_label(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_label(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletelabel/", response_model=None)
|
||||
async def fastapi_delete_label(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_label(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getbackdropschema/")
|
||||
async def fastapi_get_backdrop_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_backdrop_schema(network)
|
||||
|
||||
@router.get("/getbackdropproperties/")
|
||||
async def fastapi_get_backdrop_properties(network: str) -> dict[str, Any]:
|
||||
return get_backdrop(network)
|
||||
|
||||
@router.post("/setbackdropproperties/", response_model=None)
|
||||
async def fastapi_set_backdrop_properties(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_backdrop(network, ChangeSet(props))
|
||||
388
app/api/v1/endpoints/data_query.py
Normal file
388
app/api/v1/endpoints/data_query.py
Normal file
@@ -0,0 +1,388 @@
|
||||
from typing import Any, List, Dict, Optional
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone, time as dt_time
|
||||
import msgpack
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from py_linq import Enumerable
|
||||
|
||||
import app.infra.db.influxdb.api as influxdb_api
|
||||
import app.services.time_api as time_api
|
||||
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Basic Node/Link Latest Record Queries
|
||||
|
||||
@router.get("/querynodelatestrecordbyid/")
|
||||
async def fastapi_query_node_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="node")
|
||||
|
||||
@router.get("/querylinklatestrecordbyid/")
|
||||
async def fastapi_query_link_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="link")
|
||||
|
||||
@router.get("/queryscadalatestrecordbyid/")
|
||||
async def fastapi_query_scada_latest_record_by_id(id: str) -> Any:
|
||||
return influxdb_api.query_latest_record_by_ID(id, type="scada")
|
||||
|
||||
# Time-based Queries
|
||||
|
||||
@router.get("/queryallrecordsbytime/")
|
||||
async def fastapi_query_all_records_by_time(querytime: str) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_all_records_by_time(query_time=querytime)
|
||||
return {"nodes": results[0], "links": results[1]}
|
||||
|
||||
@router.get("/queryallrecordsbytimeproperty/")
|
||||
async def fastapi_query_all_record_by_time_property(
|
||||
querytime: str, type: str, property: str, bucket: str = "realtime_simulation_result"
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_all_record_by_time_property(
|
||||
query_time=querytime, type=type, property=property, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/queryallschemerecordsbytimeproperty/")
|
||||
async def fastapi_query_all_scheme_record_by_time_property(
|
||||
querytime: str,
|
||||
type: str,
|
||||
property: str,
|
||||
schemename: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> dict[str, list]:
|
||||
"""
|
||||
查询指定方案某一时刻的所有记录,查询 'node' 或 'link' 的某一属性值
|
||||
"""
|
||||
results: list = influxdb_api.query_all_scheme_record_by_time_property(
|
||||
query_time=querytime,
|
||||
type=type,
|
||||
property=property,
|
||||
scheme_name=schemename,
|
||||
bucket=bucket,
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/querysimulationrecordsbyidtime/")
|
||||
async def fastapi_query_simulation_record_by_ids_time(
|
||||
id: str, querytime: str, type: str, bucket: str = "realtime_simulation_result"
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_simulation_result_by_ID_time(
|
||||
ID=id, type=type, query_time=querytime, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
@router.get("/queryschemesimulationrecordsbyidtime/")
|
||||
async def fastapi_query_scheme_simulation_record_by_ids_time(
|
||||
scheme_name: str,
|
||||
id: str,
|
||||
querytime: str,
|
||||
type: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> dict[str, list]:
|
||||
results: tuple = influxdb_api.query_scheme_simulation_result_by_ID_time(
|
||||
scheme_name=scheme_name, ID=id, type=type, query_time=querytime, bucket=bucket
|
||||
)
|
||||
return {"results": results}
|
||||
|
||||
# Date-based Queries with Caching
|
||||
|
||||
@router.get("/queryallrecordsbydate/")
|
||||
async def fastapi_query_all_records_by_date(querydate: str) -> dict:
|
||||
is_today_or_future = time_api.is_today_or_future(querydate)
|
||||
logger.info(f"isToday or future: {is_today_or_future}")
|
||||
|
||||
cache_key = f"queryallrecordsbydate_{querydate}"
|
||||
|
||||
if not is_today_or_future:
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
logger.info("return from cache redis")
|
||||
return results
|
||||
|
||||
logger.info("query from influxdb")
|
||||
nodes_links: tuple = influxdb_api.query_all_records_by_date(query_date=querydate)
|
||||
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
|
||||
|
||||
if not is_today_or_future:
|
||||
logger.info("save to cache redis")
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
|
||||
logger.info("return results")
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbytimerange/")
|
||||
async def fastapi_query_all_records_by_time_range(
|
||||
starttime: str, endtime: str
|
||||
) -> dict[str, list]:
|
||||
cache_key = f"queryallrecordsbytimerange_{starttime}_{endtime}"
|
||||
|
||||
if not time_api.is_today_or_future(starttime):
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
nodes_links: tuple = influxdb_api.query_all_records_by_time_range(
|
||||
starttime=starttime, endtime=endtime
|
||||
)
|
||||
results = {"nodes": nodes_links[0], "links": nodes_links[1]}
|
||||
|
||||
if not time_api.is_today_or_future(starttime):
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbydatewithtype/")
|
||||
async def fastapi_query_all_records_by_date_with_type(
|
||||
querydate: str, querytype: str
|
||||
) -> list:
|
||||
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_all_records_by_date_with_type(
|
||||
query_date=querydate, query_type=querytype
|
||||
)
|
||||
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryallrecordsbyidsdatetype/")
|
||||
async def fastapi_query_all_records_by_ids_date_type(
|
||||
ids: str, querydate: str, querytype: str
|
||||
) -> list:
|
||||
cache_key = f"queryallrecordsbydatewithtype_{querydate}_{querytype}"
|
||||
data = redis_client.get(cache_key)
|
||||
|
||||
results = []
|
||||
if data:
|
||||
results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
results = influxdb_api.query_all_records_by_date_with_type(
|
||||
query_date=querydate, query_type=querytype
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
query_ids = ids.split(",")
|
||||
# Using Enumerable from py_linq as in original code
|
||||
e_results = Enumerable(results)
|
||||
lst_results = e_results.where(lambda x: x["ID"] in query_ids).to_list()
|
||||
|
||||
return lst_results
|
||||
|
||||
@router.get("/queryallrecordsbydateproperty/")
|
||||
async def fastapi_query_all_records_by_date_property(
|
||||
querydate: str, querytype: str, property: str
|
||||
) -> list[dict]:
|
||||
cache_key = f"queryallrecordsbydateproperty_{querydate}_{querytype}_{property}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
result_dict = influxdb_api.query_all_record_by_date_property(
|
||||
query_date=querydate, type=querytype, property=property
|
||||
)
|
||||
packed = msgpack.packb(result_dict, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return result_dict
|
||||
|
||||
# Curve Queries
|
||||
|
||||
@router.get("/querynodecurvebyidpropertydaterange/")
|
||||
async def fastapi_query_node_curve_by_id_property_daterange(
|
||||
id: str, prop: str, startdate: str, enddate: str
|
||||
):
|
||||
return influxdb_api.query_curve_by_ID_property_daterange(
|
||||
id, type="node", property=prop, start_date=startdate, end_date=enddate
|
||||
)
|
||||
|
||||
@router.get("/querylinkcurvebyidpropertydaterange/")
|
||||
async def fastapi_query_link_curve_by_id_property_daterange(
|
||||
id: str, prop: str, startdate: str, enddate: str
|
||||
):
|
||||
return influxdb_api.query_curve_by_ID_property_daterange(
|
||||
id, type="link", property=prop, start_date=startdate, end_date=enddate
|
||||
)
|
||||
|
||||
# SCADA Data Queries
|
||||
|
||||
@router.get("/queryscadadatabydeviceidandtime/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_time(ids: str, querytime: str):
|
||||
query_ids = ids.split(",")
|
||||
logger.info(querytime)
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_time(
|
||||
query_ids_list=query_ids, query_time=querytime
|
||||
)
|
||||
|
||||
@router.get("/queryscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/queryfillingscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_filling_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_filling_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querycleaningscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_cleaning_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_cleaning_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querysimulationscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_simulation_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_simulation_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/querycleanedscadadatabydeviceidandtimerange/")
|
||||
async def fastapi_query_cleaned_scada_data_by_device_id_and_time_range(
|
||||
ids: str, starttime: str, endtime: str
|
||||
):
|
||||
print(f"query_ids: {ids}, starttime: {starttime}, endtime: {endtime}")
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_cleaned_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids, start_time=starttime, end_time=endtime
|
||||
)
|
||||
|
||||
@router.get("/queryscadadatabydeviceidanddate/")
|
||||
async def fastapi_query_scada_data_by_device_id_and_date(ids: str, querydate: str):
|
||||
query_ids = ids.split(",")
|
||||
return influxdb_api.query_SCADA_data_by_device_ID_and_date(
|
||||
query_ids_list=query_ids, query_date=querydate
|
||||
)
|
||||
|
||||
@router.get("/queryallscadarecordsbydate/")
|
||||
async def fastapi_query_all_scada_records_by_date(querydate: str):
|
||||
is_today_or_future = time_api.is_today_or_future(querydate)
|
||||
logger.info(f"isToday or future: {is_today_or_future}")
|
||||
|
||||
cache_key = f"queryallscadarecordsbydate_{querydate}"
|
||||
|
||||
if not is_today_or_future:
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
logger.info("return from cache redis")
|
||||
return loaded_dict
|
||||
|
||||
logger.info("query from influxdb")
|
||||
result_dict = influxdb_api.query_all_SCADA_records_by_date(query_date=querydate)
|
||||
|
||||
if not is_today_or_future:
|
||||
logger.info("save to cache redis")
|
||||
packed = msgpack.packb(result_dict, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
logger.info("return results")
|
||||
return result_dict
|
||||
|
||||
@router.get("/queryallschemeallrecords/")
|
||||
async def fastapi_query_all_scheme_all_records(
|
||||
schemetype: str, schemename: str, querydate: str
|
||||
) -> tuple:
|
||||
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
results = influxdb_api.query_scheme_all_record(
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryschemeallrecordsproperty/")
|
||||
async def fastapi_query_all_scheme_all_records_property(
|
||||
schemetype: str, schemename: str, querydate: str, querytype: str, queryproperty: str
|
||||
) -> Optional[List]:
|
||||
cache_key = f"queryallschemeallrecords_{schemetype}_{schemename}_{querydate}"
|
||||
data = redis_client.get(cache_key)
|
||||
all_results = None
|
||||
if data:
|
||||
all_results = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
else:
|
||||
all_results = influxdb_api.query_scheme_all_record(
|
||||
scheme_type=schemetype, scheme_name=schemename, query_date=querydate
|
||||
)
|
||||
packed = msgpack.packb(all_results, default=encode_datetime)
|
||||
redis_client.set(cache_key, packed)
|
||||
|
||||
results = None
|
||||
if querytype == "node":
|
||||
results = all_results[0]
|
||||
elif querytype == "link":
|
||||
results = all_results[1]
|
||||
|
||||
return results
|
||||
|
||||
@router.get("/queryinfluxdbbuckets/")
|
||||
async def fastapi_query_influxdb_buckets():
|
||||
return influxdb_api.query_buckets()
|
||||
|
||||
@router.get("/queryinfluxdbbucketmeasurements/")
|
||||
async def fastapi_query_influxdb_bucket_measurements(bucket: str):
|
||||
return influxdb_api.query_measurements(bucket=bucket)
|
||||
|
||||
############################################################
|
||||
# download history data
|
||||
############################################################
|
||||
|
||||
class Download_History_Data_Manually(BaseModel):
|
||||
"""
|
||||
download_date:样式如 datetime(2025, 5, 4)
|
||||
"""
|
||||
|
||||
download_date: datetime
|
||||
|
||||
|
||||
@router.post("/download_history_data_manually/")
|
||||
async def fastapi_download_history_data_manually(
|
||||
data: Download_History_Data_Manually,
|
||||
) -> None:
|
||||
item = data.dict()
|
||||
tz = timezone(timedelta(hours=8))
|
||||
begin_dt = datetime.combine(item.get("download_date").date(), dt_time.min).replace(
|
||||
tzinfo=tz
|
||||
)
|
||||
end_dt = datetime.combine(item.get("download_date").date(), dt_time(23, 59, 59)).replace(
|
||||
tzinfo=tz
|
||||
)
|
||||
|
||||
begin_time = begin_dt.isoformat()
|
||||
end_time = end_dt.isoformat()
|
||||
|
||||
influxdb_api.download_history_data_manually(
|
||||
begin_time=begin_time, end_time=end_time
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
from typing import List, Any
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from app.native.api import ChangeSet
|
||||
from app.services.tjnetwork import (
|
||||
get_all_extension_data_keys,
|
||||
get_all_extension_data,
|
||||
get_extension_data,
|
||||
set_extension_data
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getallextensiondatakeys/")
|
||||
async def get_all_extension_data_keys_endpoint(network: str) -> list[str]:
|
||||
return get_all_extension_data_keys(network)
|
||||
|
||||
@router.get("/getallextensiondata/")
|
||||
async def get_all_extension_data_endpoint(network: str) -> dict[str, Any]:
|
||||
return get_all_extension_data(network)
|
||||
|
||||
@router.get("/getextensiondata/")
|
||||
async def get_extension_data_endpoint(network: str, key: str) -> str | None:
|
||||
return get_extension_data(network, key)
|
||||
|
||||
@router.post("/setextensiondata/", response_model=None)
|
||||
async def set_extension_data_endpoint(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
print(props)
|
||||
cs = set_extension_data(network, ChangeSet(props))
|
||||
print(cs.operations[0])
|
||||
return cs
|
||||
|
||||
101
app/api/v1/endpoints/meta.py
Normal file
101
app/api/v1/endpoints/meta.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.project_dependencies import (
|
||||
ProjectContext,
|
||||
get_project_context,
|
||||
get_project_pg_session,
|
||||
get_project_timescale_connection,
|
||||
get_metadata_repository,
|
||||
)
|
||||
from app.auth.metadata_dependencies import get_current_metadata_user
|
||||
from app.core.config import settings
|
||||
from app.domain.schemas.metadata import (
|
||||
GeoServerConfigResponse,
|
||||
ProjectMetaResponse,
|
||||
ProjectSummaryResponse,
|
||||
)
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/meta/project", response_model=ProjectMetaResponse)
|
||||
async def get_project_metadata(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
project = await metadata_repo.get_project_by_id(ctx.project_id)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
geoserver = await metadata_repo.get_geoserver_config(ctx.project_id)
|
||||
geoserver_payload = (
|
||||
GeoServerConfigResponse(
|
||||
gs_base_url=geoserver.gs_base_url,
|
||||
gs_admin_user=geoserver.gs_admin_user,
|
||||
gs_datastore_name=geoserver.gs_datastore_name,
|
||||
default_extent=geoserver.default_extent,
|
||||
srid=geoserver.srid,
|
||||
)
|
||||
if geoserver
|
||||
else None
|
||||
)
|
||||
return ProjectMetaResponse(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=ctx.project_role,
|
||||
geoserver=geoserver_payload,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/meta/projects", response_model=list[ProjectSummaryResponse])
|
||||
async def list_user_projects(
|
||||
current_user=Depends(get_current_metadata_user),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
projects = await metadata_repo.list_projects_for_user(current_user.id)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while listing projects for user %s",
|
||||
current_user.id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
return [
|
||||
ProjectSummaryResponse(
|
||||
project_id=project.project_id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=project.project_role,
|
||||
)
|
||||
for project in projects
|
||||
]
|
||||
|
||||
|
||||
@router.get("/meta/db/health")
|
||||
async def project_db_health(
|
||||
pg_session: AsyncSession = Depends(get_project_pg_session),
|
||||
ts_conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
await pg_session.execute(text("SELECT 1"))
|
||||
async with ts_conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
return {"postgres": "ok", "timescale": "ok"}
|
||||
55
app/api/v1/endpoints/misc.py
Normal file
55
app/api/v1/endpoints/misc.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Any
|
||||
import random
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
from app.services.tjnetwork import (
|
||||
get_all_sensor_placements,
|
||||
get_all_burst_locate_results,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/getjson/")
|
||||
async def fastapi_get_json():
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
"code": 400,
|
||||
"message": "this is message",
|
||||
"data": 123,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/getallsensorplacements/")
|
||||
async def fastapi_get_all_sensor_placements(network: str) -> list[dict[Any, Any]]:
|
||||
return get_all_sensor_placements(network)
|
||||
|
||||
|
||||
@router.get("/getallburstlocateresults/")
|
||||
async def fastapi_get_all_burst_locate_results(network: str) -> list[dict[Any, Any]]:
|
||||
return get_all_burst_locate_results(network)
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
str_info: str
|
||||
|
||||
|
||||
@router.post("/test_dict/")
|
||||
async def fastapi_test_dict(data: Item) -> dict[str, str]:
|
||||
item = data.dict()
|
||||
return item
|
||||
|
||||
@router.get("/getrealtimedata/")
|
||||
async def fastapi_get_realtimedata():
|
||||
data = [random.randint(0, 100) for _ in range(100)]
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/getsimulationresult/")
|
||||
async def fastapi_get_simulationresult():
|
||||
data = [random.randint(0, 100) for _ in range(100)]
|
||||
return data
|
||||
55
app/api/v1/endpoints/network/demands.py
Normal file
55
app/api/v1/endpoints/network/demands.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################################################
|
||||
# demand 9.[DEMANDS]
|
||||
############################################################
|
||||
|
||||
@router.get("/getdemandschema")
|
||||
async def fastapi_get_demand_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_demand_schema(network)
|
||||
|
||||
|
||||
@router.get("/getdemandproperties/")
|
||||
async def fastapi_get_demand_properties(network: str, junction: str) -> dict[str, Any]:
|
||||
return get_demand(network, junction)
|
||||
|
||||
|
||||
# example: set_demand(p, ChangeSet({'junction': 'j1', 'demands': [{'demand': 10.0, 'pattern': None, 'category': 'x'}, {'demand': 20.0, 'pattern': None, 'category': None}]}))
|
||||
@router.post("/setdemandproperties/", response_model=None)
|
||||
async def fastapi_set_demand_properties(
|
||||
network: str, junction: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"junction": junction} | props
|
||||
return set_demand(network, ChangeSet(ps))
|
||||
|
||||
############################################################
|
||||
# water distribution 36.[Water Distribution]
|
||||
############################################################
|
||||
@router.get("/calculatedemandtonodes/")
|
||||
async def fastapi_calculate_demand_to_nodes(
|
||||
network: str, req: Request
|
||||
) -> dict[str, float]:
|
||||
props = await req.json()
|
||||
demand = props["demand"]
|
||||
nodes = props["nodes"]
|
||||
return calculate_demand_to_nodes(network, demand, nodes)
|
||||
|
||||
@router.get("/calculatedemandtoregion/")
|
||||
async def fastapi_calculate_demand_to_region(
|
||||
network: str, req: Request
|
||||
) -> dict[str, float]:
|
||||
props = await req.json()
|
||||
demand = props["demand"]
|
||||
region = props["region"]
|
||||
return calculate_demand_to_region(network, demand, region)
|
||||
|
||||
@router.get("/calculatedemandtonetwork/")
|
||||
async def fastapi_calculate_demand_to_network(
|
||||
network: str, demand: float
|
||||
) -> dict[str, float]:
|
||||
return calculate_demand_to_network(network, demand)
|
||||
162
app/api/v1/endpoints/network/general.py
Normal file
162
app/api/v1/endpoints/network/general.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################################################
|
||||
# type
|
||||
############################################################
|
||||
|
||||
@router.get("/isnode/")
|
||||
async def fastapi_is_node(network: str, node: str) -> bool:
|
||||
return is_node(network, node)
|
||||
|
||||
@router.get("/isjunction/")
|
||||
async def fastapi_is_junction(network: str, node: str) -> bool:
|
||||
return is_junction(network, node)
|
||||
|
||||
@router.get("/isreservoir/")
|
||||
async def fastapi_is_reservoir(network: str, node: str) -> bool:
|
||||
return is_reservoir(network, node)
|
||||
|
||||
@router.get("/istank/")
|
||||
async def fastapi_is_tank(network: str, node: str) -> bool:
|
||||
return is_tank(network, node)
|
||||
|
||||
@router.get("/islink/")
|
||||
async def fastapi_is_link(network: str, link: str) -> bool:
|
||||
return is_link(network, link)
|
||||
|
||||
@router.get("/ispipe/")
|
||||
async def fastapi_is_pipe(network: str, link: str) -> bool:
|
||||
return is_pipe(network, link)
|
||||
|
||||
@router.get("/ispump/")
|
||||
async def fastapi_is_pump(network: str, link: str) -> bool:
|
||||
return is_pump(network, link)
|
||||
|
||||
@router.get("/isvalve/")
|
||||
async def fastapi_is_valve(network: str, link: str) -> bool:
|
||||
return is_valve(network, link)
|
||||
|
||||
@router.get("/getnodetype/")
|
||||
async def fastapi_get_node_type(network: str, node: str) -> str:
|
||||
return get_node_type(network, node)
|
||||
|
||||
@router.get("/getlinktype/")
|
||||
async def fastapi_get_link_type(network: str, link: str) -> str:
|
||||
return get_link_type(network, link)
|
||||
|
||||
@router.get("/getelementtype/")
|
||||
async def fastapi_get_element_type(network: str, element: str) -> str:
|
||||
return get_element_type(network, element)
|
||||
|
||||
@router.get("/getelementtypevalue/")
|
||||
async def fastapi_get_element_type_value(network: str, element: str) -> int:
|
||||
return get_element_type_value(network, element)
|
||||
|
||||
@router.get("/getnodes/")
|
||||
async def fastapi_get_nodes(network: str) -> list[str]:
|
||||
return get_nodes(network)
|
||||
|
||||
@router.get("/getlinks/")
|
||||
async def fastapi_get_links(network: str) -> list[str]:
|
||||
return get_links(network)
|
||||
|
||||
@router.get("/getnodelinks/")
|
||||
def get_node_links_endpoint(network: str, node: str) -> list[str]:
|
||||
return get_node_links(network, node)
|
||||
|
||||
############################################################
|
||||
# Node & Link properties
|
||||
############################################################
|
||||
|
||||
@router.get("/getnodeproperties/")
|
||||
async def fast_get_node_properties(network: str, node: str) -> dict[str, Any]:
|
||||
return get_node_properties(network, node)
|
||||
|
||||
@router.get("/getlinkproperties/")
|
||||
async def fast_get_link_properties(network: str, link: str) -> dict[str, Any]:
|
||||
return get_link_properties(network, link)
|
||||
|
||||
@router.get("/getscadaproperties/")
|
||||
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
|
||||
return get_scada_info(network, scada)
|
||||
|
||||
@router.get("/getallscadaproperties/")
|
||||
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_scada_info(network)
|
||||
|
||||
@router.get("/getelementpropertieswithtype/")
|
||||
async def fast_get_element_properties_with_type(
|
||||
network: str, elementtype: str, element: str
|
||||
) -> dict[str, Any]:
|
||||
return get_element_properties_with_type(network, elementtype, element)
|
||||
|
||||
@router.get("/getelementproperties/")
|
||||
async def fast_get_element_properties(network: str, element: str) -> dict[str, Any]:
|
||||
return get_element_properties(network, element)
|
||||
|
||||
############################################################
|
||||
# title 1.[TITLE]
|
||||
############################################################
|
||||
|
||||
@router.get("/gettitleschema/")
|
||||
async def fast_get_title_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_title_schema(network)
|
||||
|
||||
@router.get("/gettitle/")
|
||||
async def fast_get_title(network: str) -> dict[str, Any]:
|
||||
return get_title(network)
|
||||
|
||||
@router.get("/settitle/", response_model=None)
|
||||
async def fastapi_set_title(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_title(network, ChangeSet(props))
|
||||
|
||||
############################################################
|
||||
# status 10.[STATUS]
|
||||
############################################################
|
||||
|
||||
@router.get("/getstatusschema")
|
||||
async def fastapi_get_status_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_status_schema(network)
|
||||
|
||||
@router.get("/getstatus/")
|
||||
async def fastapi_get_status(network: str, link: str) -> dict[str, Any]:
|
||||
return get_status(network, link)
|
||||
|
||||
@router.post("/setstatus/", response_model=None)
|
||||
async def fastapi_set_status_properties(
|
||||
network: str, link: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"link": link} | props
|
||||
return set_status(network, ChangeSet(ps))
|
||||
|
||||
############################################################
|
||||
# General Deletion
|
||||
############################################################
|
||||
|
||||
@router.post("/deletenode/", response_model=None)
|
||||
async def fastapi_delete_node(network: str, node: str) -> ChangeSet:
|
||||
ps = {"id": node}
|
||||
if is_junction(network, node):
|
||||
return delete_junction(network, ChangeSet(ps))
|
||||
elif is_reservoir(network, node):
|
||||
return delete_reservoir(network, ChangeSet(ps))
|
||||
elif is_tank(network, node):
|
||||
return delete_tank(network, ChangeSet(ps))
|
||||
return ChangeSet() # Should probably raise error or return empty
|
||||
|
||||
@router.post("/deletelink/", response_model=None)
|
||||
async def fastapi_delete_link(network: str, link: str) -> ChangeSet:
|
||||
ps = {"id": link}
|
||||
if is_pipe(network, link):
|
||||
return delete_pipe(network, ChangeSet(ps))
|
||||
elif is_pump(network, link):
|
||||
return delete_pump(network, ChangeSet(ps))
|
||||
elif is_valve(network, link):
|
||||
return delete_valve(network, ChangeSet(ps))
|
||||
return ChangeSet()
|
||||
80
app/api/v1/endpoints/network/geometry.py
Normal file
80
app/api/v1/endpoints/network/geometry.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
from app.auth.dependencies import get_current_user as verify_token
|
||||
from app.infra.cache.redis_client import redis_client, encode_datetime, decode_datetime
|
||||
import msgpack
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################################################
|
||||
# coord 24.[COORDINATES]
|
||||
############################################################
|
||||
|
||||
# @router.get("/getcoordschema/")
|
||||
# async def fastapi_get_coord_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
# return get_coord_schema(network)
|
||||
|
||||
# @router.get("/getcoord/")
|
||||
# async def fastapi_get_coord(network: str, node: str) -> dict[str, Any]:
|
||||
# return get_coord(network, node)
|
||||
|
||||
# # example: set_coord(p, ChangeSet({'node': 'j1', 'x': 1.0, 'y': 2.0}))
|
||||
# @router.post("/setcoord/", response_model=None)
|
||||
# async def fastapi_set_coord(network: str, req: Request) -> ChangeSet:
|
||||
# props = await req.json()
|
||||
# return set_coord(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getnodecoord/")
|
||||
async def fastapi_get_node_coord(network: str, node: str) -> dict[str, float] | None:
|
||||
return get_node_coord(network, node)
|
||||
|
||||
# Additional geometry queries found in main.py logic (implicit or explicit)
|
||||
@router.get("/getnetworkinextent/")
|
||||
async def fastapi_get_network_in_extent(
|
||||
network: str, x1: float, y1: float, x2: float, y2: float
|
||||
) -> dict[str, Any]:
|
||||
return get_network_in_extent(network, x1, y1, x2, y2)
|
||||
|
||||
@router.get("/getnetworkgeometries/", dependencies=[Depends(verify_token)])
|
||||
async def fastapi_get_network_geometries(network: str) -> dict[str, Any] | None:
|
||||
cache_key = f"getnetworkgeometries_{network}"
|
||||
data = redis_client.get(cache_key)
|
||||
if data:
|
||||
loaded_dict = msgpack.unpackb(data, object_hook=decode_datetime)
|
||||
return loaded_dict
|
||||
|
||||
coords = get_network_node_coords(network)
|
||||
nodes = []
|
||||
for node_id, coord in coords.items():
|
||||
nodes.append(f"{node_id}:{coord['type']}:{coord['x']}:{coord['y']}")
|
||||
links = get_network_link_nodes(network)
|
||||
scadas = get_all_scada_info(network)
|
||||
|
||||
results = {"nodes": nodes, "links": links, "scadas": scadas}
|
||||
redis_client.set(cache_key, msgpack.packb(results, default=encode_datetime))
|
||||
return results
|
||||
|
||||
@router.get("/getmajornodecoords/")
|
||||
async def fastapi_get_majornode_coords(
|
||||
network: str, diameter: int
|
||||
) -> dict[str, dict[str, float]]:
|
||||
return get_major_node_coords(network, diameter)
|
||||
|
||||
@router.get("/getmajorpipenodes/")
|
||||
async def fastapi_get_major_pipe_nodes(network: str, diameter: int) -> list[str] | None:
|
||||
return get_major_pipe_nodes(network, diameter)
|
||||
|
||||
@router.get("/getnetworklinknodes/")
|
||||
async def fastapi_get_network_link_nodes(network: str) -> list[str] | None:
|
||||
return get_network_link_nodes(network)
|
||||
|
||||
# @router.get("/getallcoords/")
|
||||
# async def fastapi_get_all_coords(network: str) -> list[Any]:
|
||||
# return get_all_coords(network)
|
||||
|
||||
# @router.get("/projectcoordinates/")
|
||||
# async def fastapi_project_coordinates(
|
||||
# network: str, from_epsg: int, to_epsg: int
|
||||
# ) -> ChangeSet:
|
||||
# return project_coordinates(network, from_epsg, to_epsg)
|
||||
111
app/api/v1/endpoints/network/junctions.py
Normal file
111
app/api/v1/endpoints/network/junctions.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getjunctionschema")
|
||||
async def fast_get_junction_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_junction_schema(network)
|
||||
|
||||
@router.post("/addjunction/", response_model=None)
|
||||
async def fastapi_add_junction(
|
||||
network: str, junction: str, x: float, y: float, z: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": junction, "x": x, "y": y, "elevation": z}
|
||||
return add_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletejunction/", response_model=None)
|
||||
async def fastapi_delete_junction(network: str, junction: str) -> ChangeSet:
|
||||
ps = {"id": junction}
|
||||
return delete_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getjunctionelevation/")
|
||||
async def fastapi_get_junction_elevation(network: str, junction: str) -> float:
|
||||
ps = get_junction(network, junction)
|
||||
return ps["elevation"]
|
||||
|
||||
@router.get("/getjunctionx/")
|
||||
async def fastapi_get_junction_x(network: str, junction: str) -> float:
|
||||
ps = get_junction(network, junction)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/getjunctiony/")
|
||||
async def fastapi_get_junction_y(network: str, junction: str) -> float:
|
||||
ps = get_junction(network, junction)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/getjunctioncoord/")
|
||||
async def fastapi_get_junction_coord(network: str, junction: str) -> dict[str, float]:
|
||||
ps = get_junction(network, junction)
|
||||
coord = {"x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.get("/getjunctiondemand/")
|
||||
async def fastapi_get_junction_demand(network: str, junction: str) -> float:
|
||||
ps = get_junction(network, junction)
|
||||
return ps["demand"]
|
||||
|
||||
@router.get("/getjunctionpattern/")
|
||||
async def fastapi_get_junction_pattern(network: str, junction: str) -> str:
|
||||
ps = get_junction(network, junction)
|
||||
return ps["pattern"]
|
||||
|
||||
@router.post("/setjunctionelevation/", response_model=None)
|
||||
async def fastapi_set_junction_elevation(
|
||||
network: str, junction: str, elevation: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": junction, "elevation": elevation}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctionx/", response_model=None)
|
||||
async def fastapi_set_junction_x(network: str, junction: str, x: float) -> ChangeSet:
|
||||
ps = {"id": junction, "x": x}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctiony/", response_model=None)
|
||||
async def fastapi_set_junction_y(network: str, junction: str, y: float) -> ChangeSet:
|
||||
ps = {"id": junction, "y": y}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctioncoord/", response_model=None)
|
||||
async def fastapi_set_junction_coord(
|
||||
network: str, junction: str, x: float, y: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": junction, "x": x, "y": y}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctiondemand/", response_model=None)
|
||||
async def fastapi_set_junction_demand(
|
||||
network: str, junction: str, demand: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": junction, "demand": demand}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setjunctionpattern/", response_model=None)
|
||||
async def fastapi_set_junction_pattern(
|
||||
network: str, junction: str, pattern: str
|
||||
) -> ChangeSet:
|
||||
ps = {"id": junction, "pattern": pattern}
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getjunctionproperties/")
|
||||
async def fastapi_get_junction_properties(
|
||||
network: str, junction: str
|
||||
) -> dict[str, Any]:
|
||||
return get_junction(network, junction)
|
||||
|
||||
@router.get("/getalljunctionproperties/")
|
||||
async def fastapi_get_all_junction_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client # Redis logic removed for clean split, can be re-added if needed or imported
|
||||
results = get_all_junctions(network)
|
||||
return results
|
||||
|
||||
@router.post("/setjunctionproperties/", response_model=None)
|
||||
async def fastapi_set_junction_properties(
|
||||
network: str, junction: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": junction} | props
|
||||
return set_junction(network, ChangeSet(ps))
|
||||
133
app/api/v1/endpoints/network/pipes.py
Normal file
133
app/api/v1/endpoints/network/pipes.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpipeschema")
|
||||
async def fastapi_get_pipe_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_pipe_schema(network)
|
||||
|
||||
@router.post("/addpipe/", response_model=None)
|
||||
async def fastapi_add_pipe(
|
||||
network: str,
|
||||
pipe: str,
|
||||
node1: str,
|
||||
node2: str,
|
||||
length: float = 0,
|
||||
diameter: float = 0,
|
||||
roughness: float = 0,
|
||||
minor_loss: float = 0,
|
||||
status: str = PIPE_STATUS_OPEN,
|
||||
) -> ChangeSet:
|
||||
ps = {
|
||||
"id": pipe,
|
||||
"node1": node1,
|
||||
"node2": node2,
|
||||
"length": length,
|
||||
"diameter": diameter,
|
||||
"roughness": roughness,
|
||||
"minor_loss": minor_loss,
|
||||
"status": status,
|
||||
}
|
||||
return add_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepipe/", response_model=None)
|
||||
async def fastapi_delete_pipe(network: str, pipe: str) -> ChangeSet:
|
||||
ps = {"id": pipe}
|
||||
return delete_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpipenode1/")
|
||||
async def fastapi_get_pipe_node1(network: str, pipe: str) -> str | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getpipenode2/")
|
||||
async def fastapi_get_pipe_node2(network: str, pipe: str) -> str | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["node2"]
|
||||
|
||||
@router.get("/getpipelength/")
|
||||
async def fastapi_get_pipe_length(network: str, pipe: str) -> float | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["length"]
|
||||
|
||||
@router.get("/getpipediameter/")
|
||||
async def fastapi_get_pipe_diameter(network: str, pipe: str) -> float | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/getpiperoughness/")
|
||||
async def fastapi_get_pipe_roughness(network: str, pipe: str) -> float | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["roughness"]
|
||||
|
||||
@router.get("/getpipeminorloss/")
|
||||
async def fastapi_get_pipe_minor_loss(network: str, pipe: str) -> float | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["minor_loss"]
|
||||
|
||||
@router.get("/getpipestatus/")
|
||||
async def fastapi_get_pipe_status(network: str, pipe: str) -> str | None:
|
||||
ps = get_pipe(network, pipe)
|
||||
return ps["status"]
|
||||
|
||||
@router.post("/setpipenode1/", response_model=None)
|
||||
async def fastapi_set_pipe_node1(network: str, pipe: str, node1: str) -> ChangeSet:
|
||||
ps = {"id": pipe, "node1": node1}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipenode2/", response_model=None)
|
||||
async def fastapi_set_pipe_node2(network: str, pipe: str, node2: str) -> ChangeSet:
|
||||
ps = {"id": pipe, "node2": node2}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipelength/", response_model=None)
|
||||
async def fastapi_set_pipe_length(network: str, pipe: str, length: float) -> ChangeSet:
|
||||
ps = {"id": pipe, "length": length}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipediameter/", response_model=None)
|
||||
async def fastapi_set_pipe_diameter(
|
||||
network: str, pipe: str, diameter: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": pipe, "diameter": diameter}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpiperoughness/", response_model=None)
|
||||
async def fastapi_set_pipe_roughness(
|
||||
network: str, pipe: str, roughness: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": pipe, "roughness": roughness}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipeminorloss/", response_model=None)
|
||||
async def fastapi_set_pipe_minor_loss(
|
||||
network: str, pipe: str, minor_loss: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": pipe, "minor_loss": minor_loss}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpipestatus/", response_model=None)
|
||||
async def fastapi_set_pipe_status(network: str, pipe: str, status: str) -> ChangeSet:
|
||||
ps = {"id": pipe, "status": status}
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpipeproperties/")
|
||||
async def fastapi_get_pipe_properties(network: str, pipe: str) -> dict[str, Any]:
|
||||
return get_pipe(network, pipe)
|
||||
|
||||
@router.get("/getallpipeproperties/")
|
||||
async def fastapi_get_all_pipe_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_pipes(network)
|
||||
return results
|
||||
|
||||
@router.post("/setpipeproperties/", response_model=None)
|
||||
async def fastapi_set_pipe_properties(
|
||||
network: str, pipe: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": pipe} | props
|
||||
return set_pipe(network, ChangeSet(ps))
|
||||
60
app/api/v1/endpoints/network/pumps.py
Normal file
60
app/api/v1/endpoints/network/pumps.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpumpschema")
|
||||
async def fastapi_get_pump_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_pump_schema(network)
|
||||
|
||||
@router.post("/addpump/", response_model=None)
|
||||
async def fastapi_add_pump(
|
||||
network: str, pump: str, node1: str, node2: str, power: float = 0.0
|
||||
) -> ChangeSet:
|
||||
ps = {"id": pump, "node1": node1, "node2": node2, "power": power}
|
||||
return add_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletepump/", response_model=None)
|
||||
async def fastapi_delete_pump(network: str, pump: str) -> ChangeSet:
|
||||
ps = {"id": pump}
|
||||
return delete_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpumpnode1/")
|
||||
async def fastapi_get_pump_node1(network: str, pump: str) -> str | None:
|
||||
ps = get_pump(network, pump)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getpumpnode2/")
|
||||
async def fastapi_get_pump_node2(network: str, pump: str) -> str | None:
|
||||
ps = get_pump(network, pump)
|
||||
return ps["node2"]
|
||||
|
||||
@router.post("/setpumpnode1/", response_model=None)
|
||||
async def fastapi_set_pump_node1(network: str, pump: str, node1: str) -> ChangeSet:
|
||||
ps = {"id": pump, "node1": node1}
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setpumpnode2/", response_model=None)
|
||||
async def fastapi_set_pump_node2(network: str, pump: str, node2: str) -> ChangeSet:
|
||||
ps = {"id": pump, "node2": node2}
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getpumpproperties/")
|
||||
async def fastapi_get_pump_properties(network: str, pump: str) -> dict[str, Any]:
|
||||
return get_pump(network, pump)
|
||||
|
||||
@router.get("/getallpumpproperties/")
|
||||
async def fastapi_get_all_pump_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_pumps(network)
|
||||
return results
|
||||
|
||||
@router.post("/setpumpproperties/", response_model=None)
|
||||
async def fastapi_set_pump_properties(
|
||||
network: str, pump: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": pump} | props
|
||||
return set_pump(network, ChangeSet(ps))
|
||||
245
app/api/v1/endpoints/network/regions.py
Normal file
245
app/api/v1/endpoints/network/regions.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################################################
|
||||
# region 32
|
||||
############################################################
|
||||
|
||||
@router.get("/calculateregion/")
|
||||
async def fastapi_calculate_region(network: str, time_index: int) -> dict[str, Any]:
|
||||
return calculate_region(network, time_index)
|
||||
|
||||
@router.get("/getregionschema/")
|
||||
async def fastapi_get_region_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_region_schema(network)
|
||||
|
||||
@router.get("/getregion/")
|
||||
async def fastapi_get_region(network: str, id: str) -> dict[str, Any]:
|
||||
return get_region(network, id)
|
||||
|
||||
@router.post("/setregion/", response_model=None)
|
||||
async def fastapi_set_region(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_region(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addregion/", response_model=None)
|
||||
async def fastapi_add_region(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_region(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deleteregion/", response_model=None)
|
||||
async def fastapi_delete_region(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_region(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallregions/")
|
||||
async def fastapi_get_all_regions(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_regions(network)
|
||||
|
||||
@router.post("/generateregion/", response_model=None)
|
||||
async def fastapi_generate_region(
|
||||
network: str, inflate_delta: float
|
||||
) -> ChangeSet:
|
||||
return generate_region(network, inflate_delta)
|
||||
|
||||
|
||||
############################################################
|
||||
# district_metering_area 33
|
||||
############################################################
|
||||
|
||||
@router.get("/calculatedistrictmeteringarea/")
|
||||
async def fastapi_calculate_district_metering_area(
|
||||
network: str, req: Request
|
||||
) -> list[list[str]]:
|
||||
props = await req.json()
|
||||
nodes = props["nodes"]
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area(
|
||||
network, nodes, part_count, part_type
|
||||
)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareaforregion/")
|
||||
async def fastapi_calculate_district_metering_area_for_region(
|
||||
network: str, req: Request
|
||||
) -> list[list[str]]:
|
||||
props = await req.json()
|
||||
region = props["region"]
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area_for_region(
|
||||
network, region, part_count, part_type
|
||||
)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareafornetwork/")
|
||||
async def fastapi_calculate_district_metering_area_for_network(
|
||||
network: str, req: Request
|
||||
) -> list[list[str]]:
|
||||
props = await req.json()
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area_for_network(network, part_count, part_type)
|
||||
|
||||
@router.get("/getdistrictmeteringareaschema/")
|
||||
async def fastapi_get_district_metering_area_schema(
|
||||
network: str,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
return get_district_metering_area_schema(network)
|
||||
|
||||
@router.get("/getdistrictmeteringarea/")
|
||||
async def fastapi_get_district_metering_area(network: str, id: str) -> dict[str, Any]:
|
||||
return get_district_metering_area(network, id)
|
||||
|
||||
@router.post("/setdistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_set_district_metering_area(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/adddistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_add_district_metering_area(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
# boundary should be [(x,y), (x,y)]
|
||||
boundary = props.get("boundary", [])
|
||||
newBoundary = []
|
||||
for pt in boundary:
|
||||
if len(pt) >= 2:
|
||||
newBoundary.append((pt[0], pt[1]))
|
||||
props["boundary"] = newBoundary
|
||||
return add_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletedistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_delete_district_metering_area(
|
||||
network: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_district_metering_area(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getalldistrictmeteringareaids/")
|
||||
async def fastapi_get_all_district_metering_area_ids(network: str) -> list[str]:
|
||||
return get_all_district_metering_area_ids(network)
|
||||
|
||||
@router.get("/getalldistrictmeteringareas/")
|
||||
async def getalldistrictmeteringareas(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_district_metering_areas(network)
|
||||
|
||||
@router.post("/generatedistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_generate_district_metering_area(
|
||||
network: str, part_count: int, part_type: int, inflate_delta: float
|
||||
) -> ChangeSet:
|
||||
return generate_district_metering_area(
|
||||
network, part_count, part_type, inflate_delta
|
||||
)
|
||||
|
||||
@router.post("/generatesubdistrictmeteringarea/", response_model=None)
|
||||
async def fastapi_generate_sub_district_metering_area(
|
||||
network: str, dma: str, part_count: int, part_type: int, inflate_delta: float
|
||||
) -> ChangeSet:
|
||||
return generate_sub_district_metering_area(
|
||||
network, dma, part_count, part_type, inflate_delta
|
||||
)
|
||||
|
||||
|
||||
############################################################
|
||||
# service_area 34
|
||||
############################################################
|
||||
|
||||
@router.get("/calculateservicearea/")
|
||||
async def fastapi_calculate_service_area(
|
||||
network: str, time_index: int
|
||||
) -> dict[str, Any]:
|
||||
return calculate_service_area(network, time_index)
|
||||
|
||||
@router.get("/getserviceareaschema/")
|
||||
async def fastapi_get_service_area_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_service_area_schema(network)
|
||||
|
||||
@router.get("/getservicearea/")
|
||||
async def fastapi_get_service_area(network: str, id: str) -> dict[str, Any]:
|
||||
return get_service_area(network, id)
|
||||
|
||||
@router.post("/setservicearea/", response_model=None)
|
||||
async def fastapi_set_service_area(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addservicearea/", response_model=None)
|
||||
async def fastapi_add_service_area(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deleteservicearea/", response_model=None)
|
||||
async def fastapi_delete_service_area(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_service_area(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallserviceareas/")
|
||||
async def fastapi_get_all_service_areas(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_service_areas(network)
|
||||
|
||||
@router.post("/generateservicearea/", response_model=None)
|
||||
async def fastapi_generate_service_area(
|
||||
network: str, inflate_delta: float
|
||||
) -> ChangeSet:
|
||||
return generate_service_area(network, inflate_delta)
|
||||
|
||||
|
||||
############################################################
|
||||
# virtual_district 35
|
||||
############################################################
|
||||
|
||||
@router.get("/calculatevirtualdistrict/")
|
||||
async def fastapi_calculate_virtual_district(
|
||||
network: str, centers: list[str]
|
||||
) -> dict[str, list[Any]]:
|
||||
return calculate_virtual_district(network, centers)
|
||||
|
||||
@router.get("/getvirtualdistrictschema/")
|
||||
async def fastapi_get_virtual_district_schema(
|
||||
network: str,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
return get_virtual_district_schema(network)
|
||||
|
||||
@router.get("/getvirtualdistrict/")
|
||||
async def fastapi_get_virtual_district(network: str, id: str) -> dict[str, Any]:
|
||||
return get_virtual_district(network, id)
|
||||
|
||||
@router.post("/setvirtualdistrict/", response_model=None)
|
||||
async def fastapi_set_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addvirtualdistrict/", response_model=None)
|
||||
async def fastapi_add_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletevirtualdistrict/", response_model=None)
|
||||
async def fastapi_delete_virtual_district(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_virtual_district(network, ChangeSet(props))
|
||||
|
||||
@router.get("/getallvirtualdistrict/")
|
||||
async def fastapi_get_all_virtual_district(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_virtual_districts(network)
|
||||
|
||||
@router.post("/generatevirtualdistrict/", response_model=None)
|
||||
async def fastapi_generate_virtual_district(
|
||||
network: str, inflate_delta: float, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return generate_virtual_district(network, props["centers"], inflate_delta)
|
||||
|
||||
@router.get("/calculatedistrictmeteringareafornodes/")
|
||||
async def fastapi_calculate_district_metering_area_for_nodes(
|
||||
network: str, req: Request
|
||||
) -> list[list[str]]:
|
||||
props = await req.json()
|
||||
nodes = props["nodes"]
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area_for_nodes(
|
||||
network, nodes, part_count, part_type
|
||||
)
|
||||
105
app/api/v1/endpoints/network/reservoirs.py
Normal file
105
app/api/v1/endpoints/network/reservoirs.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getreservoirschema")
|
||||
async def fast_get_reservoir_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_reservoir_schema(network)
|
||||
|
||||
@router.post("/addreservoir/", response_model=None)
|
||||
async def fastapi_add_reservoir(
|
||||
network: str, reservoir: str, x: float, y: float, head: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": reservoir, "x": x, "y": y, "head": head}
|
||||
return add_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletereservoir/", response_model=None)
|
||||
async def fastapi_delete_reservoir(network: str, reservoir: str) -> ChangeSet:
|
||||
ps = {"id": reservoir}
|
||||
return delete_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getreservoirhead/")
|
||||
async def fastapi_get_reservoir_head(network: str, reservoir: str) -> float | None:
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["head"]
|
||||
|
||||
@router.get("/getreservoirpattern/")
|
||||
async def fastapi_get_reservoir_pattern(network: str, reservoir: str) -> str | None:
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["pattern"]
|
||||
|
||||
@router.get("/getreservoirx/")
|
||||
async def fastapi_get_reservoir_x(
|
||||
network: str, reservoir: str
|
||||
) -> dict[str, float] | None:
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/getreservoiry/")
|
||||
async def fastapi_get_reservoir_y(
|
||||
network: str, reservoir: str
|
||||
) -> dict[str, float] | None:
|
||||
ps = get_reservoir(network, reservoir)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/getreservoircoord/")
|
||||
async def fastapi_get_reservoir_coord(
|
||||
network: str, reservoir: str
|
||||
) -> dict[str, float] | None:
|
||||
ps = get_reservoir(network, reservoir)
|
||||
coord = {"id": reservoir, "x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.post("/setreservoirhead/", response_model=None)
|
||||
async def fastapi_set_reservoir_head(
|
||||
network: str, reservoir: str, head: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": reservoir, "head": head}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoirpattern/", response_model=None)
|
||||
async def fastapi_set_reservoir_pattern(
|
||||
network: str, reservoir: str, pattern: str
|
||||
) -> ChangeSet:
|
||||
ps = {"id": reservoir, "pattern": pattern}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoirx/", response_model=None)
|
||||
async def fastapi_set_reservoir_x(network: str, reservoir: str, x: float) -> ChangeSet:
|
||||
ps = {"id": reservoir, "x": x}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoiry/", response_model=None)
|
||||
async def fastapi_set_reservoir_y(network: str, reservoir: str, y: float) -> ChangeSet:
|
||||
ps = {"id": reservoir, "y": y}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setreservoircoord/", response_model=None)
|
||||
async def fastapi_set_reservoir_coord(
|
||||
network: str, reservoir: str, x: float, y: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": reservoir, "x": x, "y": y}
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getreservoirproperties/")
|
||||
async def fastapi_get_reservoir_properties(
|
||||
network: str, reservoir: str
|
||||
) -> dict[str, Any]:
|
||||
return get_reservoir(network, reservoir)
|
||||
|
||||
@router.get("/getallreservoirproperties/")
|
||||
async def fastapi_get_all_reservoir_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_reservoirs(network)
|
||||
return results
|
||||
|
||||
@router.post("/setreservoirproperties/", response_model=None)
|
||||
async def fastapi_set_reservoir_properties(
|
||||
network: str, reservoir: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": reservoir} | props
|
||||
return set_reservoir(network, ChangeSet(ps))
|
||||
27
app/api/v1/endpoints/network/tags.py
Normal file
27
app/api/v1/endpoints/network/tags.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################################################
|
||||
# tag 8.[TAGS]
|
||||
############################################################
|
||||
|
||||
@router.get("/gettagschema/")
|
||||
async def fastapi_get_tag_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_tag_schema(network)
|
||||
|
||||
@router.get("/gettag/")
|
||||
async def fastapi_get_tag(network: str, t_type: str, id: str) -> dict[str, Any]:
|
||||
return get_tag(network, t_type, id)
|
||||
|
||||
@router.get("/gettags/")
|
||||
async def fastapi_get_tags(network: str) -> list[dict[str, Any]]:
|
||||
tags = get_tags(network)
|
||||
return tags
|
||||
|
||||
@router.post("/settag/", response_model=None)
|
||||
async def fastapi_set_tag(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_tag(network, ChangeSet(props))
|
||||
188
app/api/v1/endpoints/network/tanks.py
Normal file
188
app/api/v1/endpoints/network/tanks.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/gettankschema")
|
||||
async def fast_get_tank_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_tank_schema(network)
|
||||
|
||||
@router.post("/addtank/", response_model=None)
|
||||
async def fastapi_add_tank(
|
||||
network: str,
|
||||
tank: str,
|
||||
x: float,
|
||||
y: float,
|
||||
elevation: float,
|
||||
init_level: float = 0,
|
||||
min_level: float = 0,
|
||||
max_level: float = 0,
|
||||
diameter: float = 0,
|
||||
min_vol: float = 0,
|
||||
) -> ChangeSet:
|
||||
ps = {
|
||||
"id": tank,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"elevation": elevation,
|
||||
"init_level": init_level,
|
||||
"min_level": min_level,
|
||||
"max_level": max_level,
|
||||
"diameter": diameter,
|
||||
"min_vol": min_vol,
|
||||
}
|
||||
return add_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletetank/", response_model=None)
|
||||
async def fastapi_delete_tank(network: str, tank: str) -> ChangeSet:
|
||||
ps = {"id": tank}
|
||||
return delete_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/gettankelevation/")
|
||||
async def fastapi_get_tank_elevation(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["elevation"]
|
||||
|
||||
@router.get("/gettankinitlevel/")
|
||||
async def fastapi_get_tank_init_level(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["init_level"]
|
||||
|
||||
@router.get("/gettankminlevel/")
|
||||
async def fastapi_get_tank_min_level(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["min_level"]
|
||||
|
||||
@router.get("/gettankmaxlevel/")
|
||||
async def fastapi_get_tank_max_level(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["max_level"]
|
||||
|
||||
@router.get("/gettankdiameter/")
|
||||
async def fastapi_get_tank_diameter(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/gettankminvol/")
|
||||
async def fastapi_get_tank_min_vol(network: str, tank: str) -> float | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["min_vol"]
|
||||
|
||||
@router.get("/gettankvolcurve/")
|
||||
async def fastapi_get_tank_vol_curve(network: str, tank: str) -> str | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["vol_curve"]
|
||||
|
||||
@router.get("/gettankoverflow/")
|
||||
async def fastapi_get_tank_overflow(network: str, tank: str) -> str | None:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["overflow"]
|
||||
|
||||
@router.get("/gettankx/")
|
||||
async def fastapi_get_tank_x(network: str, tank: str) -> float:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["x"]
|
||||
|
||||
@router.get("/gettanky/")
|
||||
async def fastapi_get_tank_y(network: str, tank: str) -> float:
|
||||
ps = get_tank(network, tank)
|
||||
return ps["y"]
|
||||
|
||||
@router.get("/gettankcoord/")
|
||||
async def fastapi_get_tank_coord(network: str, tank: str) -> dict[str, float]:
|
||||
ps = get_tank(network, tank)
|
||||
coord = {"x": ps["x"], "y": ps["y"]}
|
||||
return coord
|
||||
|
||||
@router.post("/settankelevation/", response_model=None)
|
||||
async def fastapi_set_tank_elevation(
|
||||
network: str, tank: str, elevation: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "elevation": elevation}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankinitlevel/", response_model=None)
|
||||
async def fastapi_set_tank_init_level(
|
||||
network: str, tank: str, init_level: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "init_level": init_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankminlevel/", response_model=None)
|
||||
async def fastapi_set_tank_min_level(
|
||||
network: str, tank: str, min_level: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "min_level": min_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankmaxlevel/", response_model=None)
|
||||
async def fastapi_set_tank_max_level(
|
||||
network: str, tank: str, max_level: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "max_level": max_level}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("settankdiameter//", response_model=None)
|
||||
async def fastapi_set_tank_diameter(
|
||||
network: str, tank: str, diameter: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "diameter": diameter}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankminvol/", response_model=None)
|
||||
async def fastapi_set_tank_min_vol(
|
||||
network: str, tank: str, min_vol: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "min_vol": min_vol}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankvolcurve/", response_model=None)
|
||||
async def fastapi_set_tank_vol_curve(
|
||||
network: str, tank: str, vol_curve: str
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "vol_curve": vol_curve}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankoverflow/", response_model=None)
|
||||
async def fastapi_set_tank_overflow(
|
||||
network: str, tank: str, overflow: str
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "overflow": overflow}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankx/", response_model=None)
|
||||
async def fastapi_set_tank_x(network: str, tank: str, x: float) -> ChangeSet:
|
||||
ps = {"id": tank, "x": x}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settanky/", response_model=None)
|
||||
async def fastapi_set_tank_y(network: str, tank: str, y: float) -> ChangeSet:
|
||||
ps = {"id": tank, "y": y}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/settankcoord/", response_model=None)
|
||||
async def fastapi_set_tank_coord(
|
||||
network: str, tank: str, x: float, y: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": tank, "x": x, "y": y}
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/gettankproperties/")
|
||||
async def fastapi_get_tank_properties(network: str, tank: str) -> dict[str, Any]:
|
||||
return get_tank(network, tank)
|
||||
|
||||
@router.get("/getalltankproperties/")
|
||||
async def fastapi_get_all_tank_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_tanks(network)
|
||||
return results
|
||||
|
||||
@router.post("/settankproperties/", response_model=None)
|
||||
async def fastapi_set_tank_properties(
|
||||
network: str, tank: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": tank} | props
|
||||
return set_tank(network, ChangeSet(ps))
|
||||
115
app/api/v1/endpoints/network/valves.py
Normal file
115
app/api/v1/endpoints/network/valves.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getvalveschema")
|
||||
async def fastapi_get_valve_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_valve_schema(network)
|
||||
|
||||
@router.post("/addvalve/", response_model=None)
|
||||
async def fastapi_add_valve(
|
||||
network: str,
|
||||
valve: str,
|
||||
node1: str,
|
||||
node2: str,
|
||||
diameter: float = 0,
|
||||
v_type: str = VALVES_TYPE_PRV,
|
||||
setting: float = 0,
|
||||
minor_loss: float = 0,
|
||||
) -> ChangeSet:
|
||||
ps = {
|
||||
"id": valve,
|
||||
"node1": node1,
|
||||
"node2": node2,
|
||||
"diameter": diameter,
|
||||
"v_type": v_type,
|
||||
"setting": setting,
|
||||
"minor_loss": minor_loss,
|
||||
}
|
||||
|
||||
return add_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/deletevalve/", response_model=None)
|
||||
async def fastapi_delete_valve(network: str, valve: str) -> ChangeSet:
|
||||
ps = {"id": valve}
|
||||
return delete_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getvalvenode1/")
|
||||
async def fastapi_get_valve_node1(network: str, valve: str) -> str | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["node1"]
|
||||
|
||||
@router.get("/getvalvenode2/")
|
||||
async def fastapi_get_valve_node2(network: str, valve: str) -> str | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["node2"]
|
||||
|
||||
@router.get("/getvalvediameter/")
|
||||
async def fastapi_get_valve_diameter(network: str, valve: str) -> float | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["diameter"]
|
||||
|
||||
@router.get("/getvalvetype/")
|
||||
async def fastapi_get_valve_type(network: str, valve: str) -> str | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["type"]
|
||||
|
||||
@router.get("/getvalvesetting/")
|
||||
async def fastapi_get_valve_setting(network: str, valve: str) -> float | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["setting"]
|
||||
|
||||
@router.get("/getvalveminorloss/")
|
||||
async def fastapi_get_valve_minor_loss(network: str, valve: str) -> float | None:
|
||||
ps = get_valve(network, valve)
|
||||
return ps["minor_loss"]
|
||||
|
||||
@router.post("/setvalvenode1/", response_model=None)
|
||||
async def fastapi_set_valve_node1(network: str, valve: str, node1: str) -> ChangeSet:
|
||||
ps = {"id": valve, "node1": node1}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvenode2/", response_model=None)
|
||||
async def fastapi_set_valve_node2(network: str, valve: str, node2: str) -> ChangeSet:
|
||||
ps = {"id": valve, "node2": node2}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvenodediameter/", response_model=None)
|
||||
async def fastapi_set_valve_diameter(
|
||||
network: str, valve: str, diameter: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": valve, "diameter": diameter}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvetype/", response_model=None)
|
||||
async def fastapi_set_valve_type(network: str, valve: str, type: str) -> ChangeSet:
|
||||
ps = {"id": valve, "type": type}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.post("/setvalvesetting/", response_model=None)
|
||||
async def fastapi_set_valve_setting(
|
||||
network: str, valve: str, setting: float
|
||||
) -> ChangeSet:
|
||||
ps = {"id": valve, "setting": setting}
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
|
||||
@router.get("/getvalveproperties/")
|
||||
async def fastapi_get_valve_properties(network: str, valve: str) -> dict[str, Any]:
|
||||
return get_valve(network, valve)
|
||||
|
||||
@router.get("/getallvalveproperties/")
|
||||
async def fastapi_get_all_valve_properties(network: str) -> list[dict[str, Any]]:
|
||||
# 缓存查询结果提高性能
|
||||
# global redis_client
|
||||
results = get_all_valves(network)
|
||||
return results
|
||||
|
||||
@router.post("/setvalveproperties/", response_model=None)
|
||||
async def fastapi_set_valve_properties(
|
||||
network: str, valve: str, req: Request
|
||||
) -> ChangeSet:
|
||||
props = await req.json()
|
||||
ps = {"id": valve} | props
|
||||
return set_valve(network, ChangeSet(ps))
|
||||
@@ -0,0 +1,226 @@
|
||||
import json
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from typing import Any, Dict
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api import ChangeSet
|
||||
from app.infra.db.postgresql.database import get_database_instance as get_pg_db
|
||||
from app.infra.db.timescaledb.database import get_database_instance as get_ts_db
|
||||
from app.services.tjnetwork import (
|
||||
list_project,
|
||||
have_project,
|
||||
create_project,
|
||||
delete_project,
|
||||
is_project_open,
|
||||
open_project,
|
||||
close_project,
|
||||
copy_project,
|
||||
import_inp,
|
||||
export_inp,
|
||||
read_inp,
|
||||
dump_inp,
|
||||
get_all_vertices,
|
||||
get_all_scada_elements,
|
||||
get_all_district_metering_areas,
|
||||
get_all_service_areas,
|
||||
get_all_virtual_districts,
|
||||
get_extension_data,
|
||||
convert_inp_v3_to_v2,
|
||||
)
|
||||
|
||||
# For inp file upload/download
|
||||
import os
|
||||
from fastapi import Response, status
|
||||
from fastapi.responses import FileResponse
|
||||
inpDir = "data/" # Assuming data directory exists or is defined somewhere.
|
||||
# In main.py it was likely global. For safety, let's use a relative path or get from config.
|
||||
# But let's stick to what main.py probably used or a default.
|
||||
|
||||
router = APIRouter()
|
||||
lockedPrjs: Dict[str, str] = {}
|
||||
|
||||
@router.get("/listprojects/")
|
||||
async def list_projects_endpoint() -> list[str]:
|
||||
return list_project()
|
||||
|
||||
@router.get("/haveproject/")
|
||||
async def have_project_endpoint(network: str):
|
||||
return have_project(network)
|
||||
|
||||
@router.post("/createproject/")
|
||||
async def create_project_endpoint(network: str):
|
||||
create_project(network)
|
||||
return network
|
||||
|
||||
@router.post("/deleteproject/")
|
||||
async def delete_project_endpoint(network: str):
|
||||
delete_project(network)
|
||||
return True
|
||||
|
||||
@router.get("/isprojectopen/")
|
||||
async def is_project_open_endpoint(network: str):
|
||||
return is_project_open(network)
|
||||
|
||||
@router.post("/openproject/")
|
||||
async def open_project_endpoint(network: str):
|
||||
open_project(network)
|
||||
|
||||
# 尝试连接指定数据库
|
||||
try:
|
||||
# 初始化 PostgreSQL 连接池
|
||||
pg_instance = await get_pg_db(network)
|
||||
async with pg_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
# 初始化 TimescaleDB 连接池
|
||||
ts_instance = await get_ts_db(network)
|
||||
async with ts_instance.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute("SELECT 1")
|
||||
|
||||
except Exception as e:
|
||||
# 记录错误但不阻断项目打开,或者根据需求决定是否阻断
|
||||
# 这里选择打印错误,因为 open_project 原本只负责原生部分
|
||||
print(f"Failed to connect to databases for {network}: {str(e)}")
|
||||
# 如果数据库连接是必须的,可以抛出异常:
|
||||
# raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}")
|
||||
|
||||
return network
|
||||
|
||||
@router.post("/closeproject/")
|
||||
async def close_project_endpoint(network: str):
|
||||
close_project(network)
|
||||
return True
|
||||
|
||||
@router.post("/copyproject/")
|
||||
async def copy_project_endpoint(source: str, target: str):
|
||||
copy_project(source, target)
|
||||
return True
|
||||
|
||||
@router.post("/importinp/")
|
||||
async def import_inp_endpoint(network: str, req: Request):
|
||||
jo_root = await req.json()
|
||||
inp_text = jo_root["inp"]
|
||||
ps = {"inp": inp_text}
|
||||
ret = import_inp(network, ChangeSet(ps))
|
||||
print(ret)
|
||||
return ret
|
||||
|
||||
@router.get("/exportinp/", response_model=None)
|
||||
async def export_inp_endpoint(network: str, version: str) -> ChangeSet:
|
||||
cs = export_inp(network, version)
|
||||
op = cs.operations[0]
|
||||
open_project(network)
|
||||
op["vertex"] = json.dumps(get_all_vertices(network))
|
||||
op["scada"] = json.dumps(get_all_scada_elements(network))
|
||||
op["dma"] = json.dumps(get_all_district_metering_areas(network))
|
||||
op["sa"] = json.dumps(get_all_service_areas(network))
|
||||
op["vd"] = json.dumps(get_all_virtual_districts(network))
|
||||
op["legend"] = get_extension_data(network, "legend")
|
||||
|
||||
db = get_extension_data(network, "scada_db")
|
||||
print(db)
|
||||
scada_db = ""
|
||||
if db:
|
||||
scada_db = db
|
||||
print(scada_db)
|
||||
op["scada_db"] = scada_db
|
||||
|
||||
close_project(network)
|
||||
|
||||
return cs
|
||||
|
||||
@router.post("/readinp/")
|
||||
async def read_inp_endpoint(network: str, inp: str) -> bool:
|
||||
read_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/dumpinp/")
|
||||
async def dump_inp_endpoint(network: str, inp: str) -> bool:
|
||||
dump_inp(network, inp)
|
||||
return True
|
||||
|
||||
@router.get("/isprojectlocked/")
|
||||
async def is_project_locked_endpoint(network: str, req: Request):
|
||||
return network in lockedPrjs.keys()
|
||||
|
||||
@router.get("/isprojectlockedbyme/")
|
||||
async def is_project_locked_by_me_endpoint(network: str, req: Request):
|
||||
client_host = req.client.host
|
||||
return lockedPrjs.get(network) == client_host
|
||||
|
||||
# 0 successfully locked
|
||||
# 1 already locked by you
|
||||
# 2 locked by others
|
||||
@router.post("/lockproject/")
|
||||
async def lock_project_endpoint(network: str, req: Request):
|
||||
client_host = req.client.host
|
||||
if not network in lockedPrjs.keys():
|
||||
lockedPrjs[network] = client_host
|
||||
return 0
|
||||
else:
|
||||
if lockedPrjs.get(network) == client_host:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
@router.post("/unlockproject/")
|
||||
def unlock_project_endpoint(network: str, req: Request):
|
||||
client_host = req.client.host
|
||||
if lockedPrjs.get(network) == client_host:
|
||||
print("delete key")
|
||||
del lockedPrjs[network]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# inp file operations
|
||||
@router.post("/uploadinp/", status_code=status.HTTP_200_OK)
|
||||
async def fastapi_upload_inp(afile: bytes, name: str):
|
||||
if not os.path.exists(inpDir):
|
||||
os.makedirs(inpDir, exist_ok=True)
|
||||
|
||||
filePath = inpDir + str(name)
|
||||
with open(filePath, "wb") as f:
|
||||
f.write(afile)
|
||||
return True
|
||||
|
||||
@router.get("/downloadinp/", status_code=status.HTTP_200_OK)
|
||||
async def fastapi_download_inp(name: str, response: Response):
|
||||
filePath = inpDir + name
|
||||
if os.path.exists(filePath):
|
||||
return FileResponse(
|
||||
filePath, media_type="application/octet-stream", filename="inp.inp"
|
||||
)
|
||||
else:
|
||||
response.status_code = status.HTTP_400_BAD_REQUEST
|
||||
return True
|
||||
|
||||
# DingZQ, 2024-12-28, convert v3 to v2
|
||||
@router.get("/convertv3tov2/", response_model=None)
|
||||
async def fastapi_convert_v3_to_v2(req: Request) -> ChangeSet:
|
||||
network = "v3Tov2"
|
||||
jo_root = await req.json()
|
||||
inp = jo_root["inp"]
|
||||
cs = convert_inp_v3_to_v2(inp)
|
||||
op = cs.operations[0]
|
||||
open_project(network)
|
||||
op["vertex"] = json.dumps(get_all_vertices(network))
|
||||
op["scada"] = json.dumps(get_all_scada_elements(network))
|
||||
op["dma"] = json.dumps(get_all_district_metering_areas(network))
|
||||
op["sa"] = json.dumps(get_all_service_areas(network))
|
||||
op["vd"] = json.dumps(get_all_virtual_districts(network))
|
||||
op["legend"] = get_extension_data(network, "legend")
|
||||
|
||||
db = get_extension_data(network, "scada_db")
|
||||
print(db)
|
||||
scada_db = ""
|
||||
if db:
|
||||
scada_db = db
|
||||
print(scada_db)
|
||||
op["scada_db"] = scada_db
|
||||
|
||||
close_project(network)
|
||||
|
||||
return cs
|
||||
|
||||
44
app/api/v1/endpoints/risk.py
Normal file
44
app/api/v1/endpoints/risk.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Any, List, Dict
|
||||
from fastapi import APIRouter
|
||||
from app.services.tjnetwork import (
|
||||
get_pipe_risk_probability_now,
|
||||
get_pipe_risk_probability,
|
||||
get_pipes_risk_probability,
|
||||
get_network_pipe_risk_probability_now,
|
||||
get_pipe_risk_probability_geometries,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getpiperiskprobabilitynow/")
|
||||
async def fastapi_get_pipe_risk_probability_now(
|
||||
network: str, pipe_id: str
|
||||
) -> dict[str, Any]:
|
||||
return get_pipe_risk_probability_now(network, pipe_id)
|
||||
|
||||
|
||||
@router.get("/getpiperiskprobability/")
|
||||
async def fastapi_get_pipe_risk_probability(
|
||||
network: str, pipe_id: str
|
||||
) -> dict[str, Any]:
|
||||
return get_pipe_risk_probability(network, pipe_id)
|
||||
|
||||
|
||||
@router.get("/getpipesriskprobability/")
|
||||
async def fastapi_get_pipes_risk_probability(
|
||||
network: str, pipe_ids: str
|
||||
) -> list[dict[str, Any]]:
|
||||
pipeids = pipe_ids.split(",")
|
||||
return get_pipes_risk_probability(network, pipeids)
|
||||
|
||||
|
||||
@router.get("/getnetworkpiperiskprobabilitynow/")
|
||||
async def fastapi_get_network_pipe_risk_probability_now(
|
||||
network: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
return get_network_pipe_risk_probability_now(network)
|
||||
|
||||
|
||||
@router.get("/getpiperiskprobabilitygeometries/")
|
||||
async def fastapi_get_pipe_risk_probability_geometries(network: str) -> dict[str, Any]:
|
||||
return get_pipe_risk_probability_geometries(network)
|
||||
@@ -0,0 +1,169 @@
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Request
|
||||
from app.native.api import ChangeSet
|
||||
from app.services.tjnetwork import (
|
||||
get_scada_info,
|
||||
get_all_scada_info,
|
||||
get_scada_device_schema,
|
||||
get_scada_device,
|
||||
set_scada_device,
|
||||
add_scada_device,
|
||||
delete_scada_device,
|
||||
clean_scada_device,
|
||||
get_all_scada_device_ids,
|
||||
get_all_scada_devices,
|
||||
get_scada_device_data_schema,
|
||||
get_scada_device_data,
|
||||
set_scada_device_data,
|
||||
add_scada_device_data,
|
||||
delete_scada_device_data,
|
||||
clean_scada_device_data,
|
||||
get_scada_element_schema,
|
||||
get_scada_element,
|
||||
set_scada_element,
|
||||
add_scada_element,
|
||||
delete_scada_element,
|
||||
clean_scada_element,
|
||||
get_all_scada_elements,
|
||||
get_scada_element_schema,
|
||||
get_scada_info_schema,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getscadaproperties/")
|
||||
async def fast_get_scada_properties(network: str, scada: str) -> dict[str, Any]:
|
||||
return get_scada_info(network, scada)
|
||||
|
||||
@router.get("/getallscadaproperties/")
|
||||
async def fast_get_all_scada_properties(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_scada_info(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_device 29
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadadeviceschema/")
|
||||
async def fastapi_get_scada_device_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_scada_device_schema(network)
|
||||
|
||||
@router.get("/getscadadevice/")
|
||||
async def fastapi_get_scada_device(network: str, id: str) -> dict[str, Any]:
|
||||
return get_scada_device(network, id)
|
||||
|
||||
@router.post("/setscadadevice/", response_model=None)
|
||||
async def fastapi_set_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadadevice/", response_model=None)
|
||||
async def fastapi_add_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadadevice/", response_model=None)
|
||||
async def fastapi_delete_scada_device(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_scada_device(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadadevice/", response_model=None)
|
||||
async def fastapi_clean_scada_device(network: str) -> ChangeSet:
|
||||
return clean_scada_device(network)
|
||||
|
||||
@router.get("/getallscadadeviceids/")
|
||||
async def fastapi_get_all_scada_device_ids(network: str) -> list[str]:
|
||||
return get_all_scada_device_ids(network)
|
||||
|
||||
@router.get("/getallscadadevices/")
|
||||
async def fastapi_get_all_scada_devices(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_scada_devices(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_device_data 30
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadadevicedataschema/")
|
||||
async def fastapi_get_scada_device_data_schema(
|
||||
network: str,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
return get_scada_device_data_schema(network)
|
||||
|
||||
@router.get("/getscadadevicedata/")
|
||||
async def fastapi_get_scada_device_data(network: str, device_id: str) -> dict[str, Any]:
|
||||
return get_scada_device_data(network, device_id)
|
||||
|
||||
@router.post("/setscadadevicedata/", response_model=None)
|
||||
async def fastapi_set_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadadevicedata/", response_model=None)
|
||||
async def fastapi_add_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadadevicedata/", response_model=None)
|
||||
async def fastapi_delete_scada_device_data(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_scada_device_data(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadadevicedata/", response_model=None)
|
||||
async def fastapi_clean_scada_device_data(network: str) -> ChangeSet:
|
||||
return clean_scada_device_data(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_element 31
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadaelementschema/")
|
||||
async def fastapi_get_scada_element_schema(
|
||||
network: str,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
return get_scada_element_schema(network)
|
||||
|
||||
@router.get("/getscadaelements/")
|
||||
async def fastapi_get_scada_elements(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_scada_elements(network)
|
||||
|
||||
@router.get("/getscadaelement/")
|
||||
async def fastapi_get_scada_element(network: str, id: str) -> dict[str, Any]:
|
||||
return get_scada_element(network, id)
|
||||
|
||||
@router.post("/setscadaelement/", response_model=None)
|
||||
async def fastapi_set_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return set_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/addscadaelement/", response_model=None)
|
||||
async def fastapi_add_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return add_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/deletescadaelement/", response_model=None)
|
||||
async def fastapi_delete_scada_element(network: str, req: Request) -> ChangeSet:
|
||||
props = await req.json()
|
||||
return delete_scada_element(network, ChangeSet(props))
|
||||
|
||||
@router.post("/cleanscadaelement/", response_model=None)
|
||||
async def fastapi_clean_scada_element(network: str) -> ChangeSet:
|
||||
return clean_scada_element(network)
|
||||
|
||||
|
||||
############################################################
|
||||
# scada_info 38
|
||||
############################################################
|
||||
|
||||
@router.get("/getscadainfoschema/")
|
||||
async def fastapi_get_scada_info_schema(network: str) -> dict[str, dict[str, Any]]:
|
||||
return get_scada_info_schema(network)
|
||||
|
||||
@router.get("/getscadainfo/")
|
||||
async def fastapi_get_scada_info(network: str, id: str) -> dict[str, Any]:
|
||||
return get_scada_info(network, id)
|
||||
|
||||
@router.get("/getallscadainfo/")
|
||||
async def fastapi_get_all_scada_info(network: str) -> list[dict[str, Any]]:
|
||||
return get_all_scada_info(network)
|
||||
|
||||
17
app/api/v1/endpoints/schemes.py
Normal file
17
app/api/v1/endpoints/schemes.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import Any, List, Dict
|
||||
from app.services.tjnetwork import get_scheme_schema, get_scheme, get_all_schemes
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getschemeschema/")
|
||||
async def fastapi_get_scheme_schema(network: str) -> dict[str, dict[Any, Any]]:
|
||||
return get_scheme_schema(network)
|
||||
|
||||
@router.get("/getscheme/")
|
||||
async def fastapi_get_scheme(network: str, schema_name: str) -> dict[Any, Any]:
|
||||
return get_scheme(network, schema_name)
|
||||
|
||||
@router.get("/getallschemes/")
|
||||
async def fastapi_get_all_schemes(network: str) -> list[dict[Any, Any]]:
|
||||
return get_all_schemes(network)
|
||||
@@ -0,0 +1,670 @@
|
||||
from typing import Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import pandas as pd
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
import app.infra.db.influxdb.api as influxdb_api
|
||||
import app.services.simulation as simulation
|
||||
import app.services.globals as globals
|
||||
from app.infra.cache.redis_client import redis_client
|
||||
from app.services.tjnetwork import (
|
||||
run_project,
|
||||
run_project_return_dict,
|
||||
run_inp,
|
||||
dump_output,
|
||||
)
|
||||
from app.algorithms.simulations import (
|
||||
burst_analysis,
|
||||
valve_close_analysis,
|
||||
flushing_analysis,
|
||||
contaminant_simulation,
|
||||
age_analysis,
|
||||
# scheduling_analysis,
|
||||
pressure_regulation,
|
||||
)
|
||||
from app.algorithms.sensors import (
|
||||
pressure_sensor_placement_sensitivity,
|
||||
pressure_sensor_placement_kmeans,
|
||||
)
|
||||
import app.algorithms.api_ex.flow_data_clean as flow_data_clean
|
||||
import app.algorithms.api_ex.pressure_data_clean as pressure_data_clean
|
||||
from app.services.network_import import network_update
|
||||
from app.services.simulation_ops import (
|
||||
project_management,
|
||||
scheduling_simulation,
|
||||
daily_scheduling_simulation,
|
||||
)
|
||||
from app.services.valve_isolation import analyze_valve_isolation
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunSimulationManuallyByDate(BaseModel):
|
||||
name: str
|
||||
simulation_date: str
|
||||
start_time: str
|
||||
duration: int
|
||||
|
||||
|
||||
class BurstAnalysis(BaseModel):
|
||||
name: str
|
||||
modify_pattern_start_time: str
|
||||
burst_ID: List[str] | str | None = None
|
||||
burst_size: List[float] | float | int | None = None
|
||||
modify_total_duration: int = 900
|
||||
modify_fixed_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_variable_pump_pattern: Optional[dict[str, list]] = None
|
||||
modify_valve_opening: Optional[dict[str, float]] = None
|
||||
scheme_name: Optional[str] = None
|
||||
|
||||
|
||||
class SchedulingAnalysis(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_id: str
|
||||
water_plant_output_id: str
|
||||
time_delta: Optional[int] = 300
|
||||
|
||||
|
||||
class PressureRegulation(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_init_level: Optional[dict] = None
|
||||
duration: Optional[int] = 900
|
||||
scheme_name: Optional[str] = None
|
||||
|
||||
|
||||
class ProjectManagement(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
tank_init_level: Optional[dict] = None
|
||||
region_demand: Optional[dict] = None
|
||||
|
||||
|
||||
class DailySchedulingAnalysis(BaseModel):
|
||||
network: str
|
||||
start_time: str
|
||||
pump_control: dict
|
||||
reservoir_id: str
|
||||
tank_id: str
|
||||
water_plant_output_id: str
|
||||
time_delta: Optional[int] = 300
|
||||
|
||||
|
||||
class PumpFailureState(BaseModel):
|
||||
time: str
|
||||
pump_status: dict
|
||||
|
||||
|
||||
class PressureSensorPlacement(BaseModel):
|
||||
name: str
|
||||
scheme_name: str
|
||||
sensor_number: int
|
||||
min_diameter: int = 0
|
||||
username: str
|
||||
|
||||
|
||||
def run_simulation_manually_by_date(
|
||||
network_name: str, base_date: datetime, start_time: str, duration: int
|
||||
) -> None:
|
||||
time_parts = list(map(int, start_time.split(":")))
|
||||
if len(time_parts) == 2:
|
||||
start_hour, start_minute = time_parts
|
||||
start_second = 0
|
||||
elif len(time_parts) == 3:
|
||||
start_hour, start_minute, start_second = time_parts
|
||||
else:
|
||||
raise ValueError("Invalid start_time format. Use HH:MM or HH:MM:SS")
|
||||
|
||||
start_datetime = base_date.replace(
|
||||
hour=start_hour, minute=start_minute, second=start_second
|
||||
)
|
||||
end_datetime = start_datetime + timedelta(minutes=duration)
|
||||
current_time = start_datetime
|
||||
while current_time < end_datetime:
|
||||
iso_time = current_time.strftime("%Y-%m-%dT%H:%M:%S") + "+08:00"
|
||||
simulation.run_simulation(
|
||||
name=network_name,
|
||||
simulation_type="realtime",
|
||||
modify_pattern_start_time=iso_time,
|
||||
)
|
||||
current_time += timedelta(minutes=15)
|
||||
|
||||
|
||||
# 必须用这个PlainTextResponse,不然每个key都有引号
|
||||
@router.get("/runproject/", response_class=PlainTextResponse)
|
||||
async def run_project_endpoint(network: str) -> str:
|
||||
lock_key = "exclusive_api_lock"
|
||||
timeout = 120 # 锁自动过期时间(秒)
|
||||
|
||||
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
|
||||
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
|
||||
|
||||
if not acquired:
|
||||
raise HTTPException(status_code=409, detail="is in simulation")
|
||||
else:
|
||||
try:
|
||||
return run_project(network)
|
||||
finally:
|
||||
# 手动释放锁(可选,依赖过期时间自动释放更安全)
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
|
||||
# DingZQ, 2025-02-04, 返回dict[str, Any]
|
||||
# output 和 report
|
||||
# output 是 json
|
||||
# report 是 text
|
||||
@router.get("/runprojectreturndict/")
|
||||
async def run_project_return_dict_endpoint(network: str) -> dict[str, Any]:
|
||||
lock_key = "exclusive_api_lock"
|
||||
timeout = 120 # 锁自动过期时间(秒)
|
||||
|
||||
# 尝试获取锁(NX=True: 不存在时设置,EX=timeout: 过期时间)
|
||||
acquired = redis_client.set(lock_key, "locked", nx=True, ex=timeout)
|
||||
|
||||
if not acquired:
|
||||
raise HTTPException(status_code=409, detail="is in simulation")
|
||||
else:
|
||||
try:
|
||||
return run_project_return_dict(network)
|
||||
finally:
|
||||
# 手动释放锁(可选,依赖过期时间自动释放更安全)
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
|
||||
# put in inp folder, name without extension
|
||||
@router.get("/runinp/")
|
||||
async def run_inp_endpoint(network: str) -> str:
|
||||
return run_inp(network)
|
||||
|
||||
|
||||
# path is absolute path
|
||||
@router.get("/dumpoutput/")
|
||||
async def dump_output_endpoint(output: str) -> str:
|
||||
return dump_output(output)
|
||||
|
||||
|
||||
# Analysis Endpoints
|
||||
@router.get("/burstanalysis/")
|
||||
async def burst_analysis_endpoint(
|
||||
network: str, pipe_id: str, start_time: str, end_time: str, burst_flow: float
|
||||
):
|
||||
return burst_analysis(network, pipe_id, start_time, end_time, burst_flow)
|
||||
|
||||
|
||||
@router.get("/burst_analysis/")
|
||||
async def fastapi_burst_analysis(
|
||||
network: str = Query(...),
|
||||
modify_pattern_start_time: str = Query(...),
|
||||
burst_ID: list[str] = Query(...),
|
||||
burst_size: list[float] = Query(...),
|
||||
modify_total_duration: int = Query(...),
|
||||
scheme_name: str = Query(...),
|
||||
) -> str:
|
||||
burst_analysis(
|
||||
name=network,
|
||||
modify_pattern_start_time=modify_pattern_start_time,
|
||||
burst_ID=burst_ID,
|
||||
burst_size=burst_size,
|
||||
modify_total_duration=modify_total_duration,
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
return "success"
|
||||
|
||||
|
||||
@router.get("/valvecloseanalysis/")
|
||||
async def valve_close_analysis_endpoint(
|
||||
network: str, valve_id: str, start_time: str, end_time: str
|
||||
):
|
||||
return valve_close_analysis(network, valve_id, start_time, end_time)
|
||||
|
||||
|
||||
@router.get("/valve_close_analysis/", response_class=PlainTextResponse)
|
||||
async def fastapi_valve_close_analysis(
|
||||
network: str,
|
||||
start_time: str,
|
||||
valves: List[str] = Query(...),
|
||||
duration: int | None = None,
|
||||
) -> str:
|
||||
result = valve_close_analysis(
|
||||
name=network,
|
||||
modify_pattern_start_time=start_time,
|
||||
modify_total_duration=duration or 900,
|
||||
modify_valve_opening={valve_id: 0.0 for valve_id in valves},
|
||||
)
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/valve_isolation_analysis/")
|
||||
async def valve_isolation_endpoint(
|
||||
network: str,
|
||||
accident_element: List[str] = Query(...),
|
||||
disabled_valves: List[str] = Query(None),
|
||||
):
|
||||
result = {
|
||||
"accident_element": "P461309",
|
||||
"accident_elements": ["P461309"],
|
||||
"affected_nodes": [
|
||||
"J316629_A",
|
||||
"J317037_B",
|
||||
"J317060_B",
|
||||
"J408189_B",
|
||||
"J499996",
|
||||
"J524940",
|
||||
"J535933",
|
||||
"J58841",
|
||||
],
|
||||
"isolatable": True,
|
||||
"must_close_valves": ["210521658", "V12974", "V12986", "V12993"],
|
||||
"optional_valves": [],
|
||||
}
|
||||
result = analyze_valve_isolation(network, accident_element, disabled_valves)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/flushinganalysis/")
|
||||
async def flushing_analysis_endpoint(
|
||||
network: str, pipe_id: str, start_time: str, duration: float, flow: float
|
||||
):
|
||||
return flushing_analysis(network, pipe_id, start_time, duration, flow)
|
||||
|
||||
|
||||
@router.get("/flushing_analysis/", response_class=PlainTextResponse)
|
||||
async def fastapi_flushing_analysis(
|
||||
network: str,
|
||||
start_time: str,
|
||||
valves: List[str] = Query(...),
|
||||
valves_k: List[float] = Query(...),
|
||||
drainage_node_ID: str = Query(...),
|
||||
flush_flow: float = 0,
|
||||
duration: int | None = None,
|
||||
scheme_name: str | None = None,
|
||||
) -> str:
|
||||
valve_opening = {
|
||||
valve_id: float(valves_k[idx]) for idx, valve_id in enumerate(valves)
|
||||
}
|
||||
result = flushing_analysis(
|
||||
name=network,
|
||||
modify_pattern_start_time=start_time,
|
||||
modify_total_duration=duration or 900,
|
||||
modify_valve_opening=valve_opening,
|
||||
drainage_node_ID=drainage_node_ID,
|
||||
flushing_flow=flush_flow,
|
||||
scheme_name=scheme_name,
|
||||
)
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/contaminant_simulation/", response_class=PlainTextResponse)
|
||||
async def fastapi_contaminant_simulation(
|
||||
network: str,
|
||||
start_time: str,
|
||||
source: str,
|
||||
concentration: float,
|
||||
duration: int,
|
||||
scheme_name: str | None = None,
|
||||
pattern: str | None = None,
|
||||
) -> str:
|
||||
result = contaminant_simulation(
|
||||
name=network,
|
||||
modify_pattern_start_time=start_time,
|
||||
scheme_name=scheme_name,
|
||||
modify_total_duration=duration,
|
||||
source=source,
|
||||
concentration=concentration,
|
||||
source_pattern=pattern,
|
||||
)
|
||||
return result or "success"
|
||||
|
||||
|
||||
@router.get("/ageanalysis/")
|
||||
async def age_analysis_endpoint(network: str):
|
||||
return age_analysis(network)
|
||||
|
||||
|
||||
@router.get("/age_analysis/", response_class=PlainTextResponse)
|
||||
async def fastapi_age_analysis(
|
||||
network: str, start_time: str, end_time: str, duration: int
|
||||
) -> str:
|
||||
result = age_analysis(network, start_time, duration)
|
||||
return result or "success"
|
||||
|
||||
|
||||
# @router.get("/schedulinganalysis/")
|
||||
# async def scheduling_analysis_endpoint(network: str):
|
||||
# return scheduling_analysis(network)
|
||||
|
||||
|
||||
@router.get("/pressureregulation/")
|
||||
async def pressure_regulation_endpoint(
|
||||
network: str, target_node: str, target_pressure: float
|
||||
):
|
||||
return pressure_regulation(network, target_node, target_pressure)
|
||||
|
||||
|
||||
@router.post("/pressure_regulation/")
|
||||
async def fastapi_pressure_regulation(data: PressureRegulation) -> str:
|
||||
item = data.dict()
|
||||
simulation.query_corresponding_element_id_and_query_id(item["network"])
|
||||
fixed_pumps = set(globals.fixed_pumps_id.keys())
|
||||
variable_pumps = set(globals.variable_pumps_id.keys())
|
||||
fixed_pump_pattern: dict[str, list] = {}
|
||||
variable_pump_pattern: dict[str, list] = {}
|
||||
for pump_id, values in item["pump_control"].items():
|
||||
if pump_id in variable_pumps:
|
||||
variable_pump_pattern[pump_id] = values
|
||||
else:
|
||||
fixed_pump_pattern[pump_id] = values
|
||||
pressure_regulation(
|
||||
name=item["network"],
|
||||
modify_pattern_start_time=item["start_time"],
|
||||
modify_total_duration=item["duration"] or 900,
|
||||
modify_tank_initial_level=item["tank_init_level"],
|
||||
modify_fixed_pump_pattern=fixed_pump_pattern or None,
|
||||
modify_variable_pump_pattern=variable_pump_pattern or None,
|
||||
scheme_name=item["scheme_name"],
|
||||
)
|
||||
return "success"
|
||||
|
||||
|
||||
@router.get("/projectmanagement/")
|
||||
async def project_management_endpoint(network: str):
|
||||
return project_management(network)
|
||||
|
||||
|
||||
@router.post("/project_management/")
|
||||
async def fastapi_project_management(data: ProjectManagement) -> str:
|
||||
item = data.dict()
|
||||
return project_management(
|
||||
prj_name=item["network"],
|
||||
start_datetime=item["start_time"],
|
||||
pump_control=item["pump_control"],
|
||||
tank_initial_level_control=item["tank_init_level"],
|
||||
region_demand_control=item["region_demand"],
|
||||
)
|
||||
|
||||
|
||||
# @router.get("/dailyschedulinganalysis/")
|
||||
# async def daily_scheduling_analysis_endpoint(network: str):
|
||||
# return daily_scheduling_analysis(network)
|
||||
|
||||
|
||||
@router.post("/scheduling_analysis/")
|
||||
async def fastapi_scheduling_analysis(data: SchedulingAnalysis) -> str:
|
||||
item = data.dict()
|
||||
return scheduling_simulation(
|
||||
item["network"],
|
||||
item["start_time"],
|
||||
item["pump_control"],
|
||||
item["tank_id"],
|
||||
item["water_plant_output_id"],
|
||||
item["time_delta"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/daily_scheduling_analysis/")
|
||||
async def fastapi_daily_scheduling_analysis(data: DailySchedulingAnalysis) -> str:
|
||||
item = data.dict()
|
||||
return daily_scheduling_simulation(
|
||||
item["network"],
|
||||
item["start_time"],
|
||||
item["pump_control"],
|
||||
item["reservoir_id"],
|
||||
item["tank_id"],
|
||||
item["water_plant_output_id"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/network_project/")
|
||||
async def fastapi_network_project(file: UploadFile = File()) -> str:
|
||||
temp_file_dir = "./inp/"
|
||||
if not os.path.exists(temp_file_dir):
|
||||
os.mkdir(temp_file_dir)
|
||||
temp_file_name = f'network_project_{datetime.now().strftime("%Y%m%d")}'
|
||||
temp_file_path = f"{temp_file_dir}{temp_file_name}.inp"
|
||||
with open(temp_file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
return run_inp(temp_file_name)
|
||||
|
||||
|
||||
@router.get("/networkupdate/")
|
||||
async def network_update_endpoint(network: str):
|
||||
return network_update(network)
|
||||
|
||||
|
||||
@router.post("/network_update/")
|
||||
async def fastapi_network_update(file: UploadFile = File()) -> str:
|
||||
default_folder = "./"
|
||||
temp_file_name = f'network_update_{datetime.now().strftime("%Y%m%d")}'
|
||||
temp_file_path = os.path.join(default_folder, temp_file_name)
|
||||
try:
|
||||
with open(temp_file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
network_update(temp_file_path)
|
||||
return json.dumps({"message": "管网更新成功"})
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"数据库操作失败: {exc}")
|
||||
|
||||
|
||||
# @router.get("/pumpfailure/")
|
||||
# async def pump_failure_endpoint(network: str, pump_id: str, time: str):
|
||||
# return pump_failure(network, pump_id, time)
|
||||
|
||||
|
||||
@router.post("/pump_failure/")
|
||||
async def fastapi_pump_failure(data: PumpFailureState) -> str:
|
||||
item = data.dict()
|
||||
with open("./pump_failure_message.txt", "a", encoding="utf-8-sig") as f1:
|
||||
f1.write("[{}] {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), item))
|
||||
with open("./pump_failure_status.txt", "r", encoding="utf-8-sig") as f2:
|
||||
lines = f2.readlines()
|
||||
first_stage_pump_status_dict = json.loads(json.dumps(eval(lines[0])))
|
||||
second_stage_pump_status_dict = json.loads(json.dumps(eval(lines[-1])))
|
||||
pump_status_dict = {
|
||||
"first": first_stage_pump_status_dict,
|
||||
"second": second_stage_pump_status_dict,
|
||||
}
|
||||
status_info = item.copy()
|
||||
for pump_type in status_info["pump_status"].keys():
|
||||
if pump_type in pump_status_dict.keys():
|
||||
if all(
|
||||
pump_id in pump_status_dict[pump_type].keys()
|
||||
for pump_id in status_info["pump_status"][pump_type].keys()
|
||||
):
|
||||
for pump_id in status_info["pump_status"][pump_type].keys():
|
||||
pump_status_dict[pump_type][pump_id] = int(
|
||||
status_info["pump_status"][pump_type][pump_id]
|
||||
)
|
||||
else:
|
||||
return json.dumps("ERROR: Wrong Pump ID")
|
||||
else:
|
||||
return json.dumps("ERROR: Wrong Pump Type")
|
||||
with open("./pump_failure_status.txt", "w", encoding="utf-8-sig") as f2_:
|
||||
f2_.write(
|
||||
"{}\n{}".format(pump_status_dict["first"], pump_status_dict["second"])
|
||||
)
|
||||
return json.dumps("SUCCESS")
|
||||
|
||||
|
||||
@router.get("/pressuresensorplacementsensitivity/")
|
||||
async def pressure_sensor_placement_sensitivity_endpoint(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
):
|
||||
return pressure_sensor_placement_sensitivity(
|
||||
name, scheme_name, sensor_number, min_diameter, username
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pressure_sensor_placement_sensitivity/")
|
||||
async def fastapi_pressure_sensor_placement_sensitivity(
|
||||
data: PressureSensorPlacement,
|
||||
) -> None:
|
||||
item = data.dict()
|
||||
pressure_sensor_placement_sensitivity(
|
||||
name=item["name"],
|
||||
scheme_name=item["scheme_name"],
|
||||
sensor_number=item["sensor_number"],
|
||||
min_diameter=item["min_diameter"],
|
||||
username=item["username"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/pressuresensorplacementkmeans/")
|
||||
async def pressure_sensor_placement_kmeans_endpoint(
|
||||
name: str, scheme_name: str, sensor_number: int, min_diameter: int, username: str
|
||||
):
|
||||
return pressure_sensor_placement_kmeans(
|
||||
name, scheme_name, sensor_number, min_diameter, username
|
||||
)
|
||||
|
||||
|
||||
@router.post("/pressure_sensor_placement_kmeans/")
|
||||
async def fastapi_pressure_sensor_placement_kmeans(
|
||||
data: PressureSensorPlacement,
|
||||
) -> None:
|
||||
item = data.dict()
|
||||
pressure_sensor_placement_kmeans(
|
||||
name=item["name"],
|
||||
scheme_name=item["scheme_name"],
|
||||
sensor_number=item["sensor_number"],
|
||||
min_diameter=item["min_diameter"],
|
||||
username=item["username"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sensorplacementscheme/create")
|
||||
async def fastapi_pressure_sensor_placement(
|
||||
network: str = Query(...),
|
||||
scheme_name: str = Query(...),
|
||||
sensor_type: str = Query(...),
|
||||
method: str = Query(...),
|
||||
sensor_count: int = Query(...),
|
||||
min_diameter: int = Query(0),
|
||||
user_name: str = Query(...),
|
||||
) -> str:
|
||||
if method not in ["sensitivity", "kmeans"]:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid method. Must be 'sensitivity' or 'kmeans'"
|
||||
)
|
||||
if method == "sensitivity":
|
||||
pressure_sensor_placement_sensitivity(
|
||||
name=network,
|
||||
scheme_name=scheme_name,
|
||||
sensor_number=sensor_count,
|
||||
min_diameter=min_diameter,
|
||||
username=user_name,
|
||||
)
|
||||
elif method == "kmeans":
|
||||
pressure_sensor_placement_kmeans(
|
||||
name=network,
|
||||
scheme_name=scheme_name,
|
||||
sensor_number=sensor_count,
|
||||
min_diameter=min_diameter,
|
||||
username=user_name,
|
||||
)
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/scadadevicedatacleaning/")
|
||||
async def fastapi_scada_device_data_cleaning(
|
||||
network: str = Query(...),
|
||||
ids_list: List[str] = Query(...),
|
||||
start_time: str = Query(...),
|
||||
end_time: str = Query(...),
|
||||
user_name: str = Query(...),
|
||||
) -> str:
|
||||
item = {
|
||||
"network": network,
|
||||
"ids": ids_list,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"user_name": user_name,
|
||||
}
|
||||
query_ids_list = item["ids"][0].split(",")
|
||||
scada_data = influxdb_api.query_SCADA_data_by_device_ID_and_timerange(
|
||||
query_ids_list=query_ids_list,
|
||||
start_time=item["start_time"],
|
||||
end_time=item["end_time"],
|
||||
)
|
||||
scada_device_info = influxdb_api.query_pg_scada_info(item["network"])
|
||||
scada_device_info_dict = {info["id"]: info for info in scada_device_info}
|
||||
type_groups: dict[str, list[str]] = {}
|
||||
for device_id in query_ids_list:
|
||||
device_info = scada_device_info_dict.get(device_id, {})
|
||||
device_type = device_info.get("type", "unknown")
|
||||
type_groups.setdefault(device_type, []).append(device_id)
|
||||
for device_type, device_ids in type_groups.items():
|
||||
if device_type not in ["pressure", "pipe_flow"]:
|
||||
continue
|
||||
type_scada_data = {
|
||||
device_id: scada_data[device_id]
|
||||
for device_id in device_ids
|
||||
if device_id in scada_data
|
||||
}
|
||||
if not type_scada_data:
|
||||
continue
|
||||
time_list = [record["time"] for record in next(iter(type_scada_data.values()))]
|
||||
df = pd.DataFrame({"time": time_list})
|
||||
for device_id in device_ids:
|
||||
if device_id in type_scada_data:
|
||||
values = [record["value"] for record in type_scada_data[device_id]]
|
||||
df[device_id] = values
|
||||
if device_type == "pressure":
|
||||
cleaned_value_df = pressure_data_clean.clean_pressure_data_df_km(df)
|
||||
elif device_type == "pipe_flow":
|
||||
cleaned_value_df = flow_data_clean.clean_flow_data_df_kf(df)
|
||||
cleaned_value_df = pd.DataFrame(cleaned_value_df)
|
||||
cleaned_df = pd.concat([df["time"], cleaned_value_df], axis=1)
|
||||
influxdb_api.import_multicolumn_data_from_dict(
|
||||
data_dict=cleaned_df.to_dict("list"),
|
||||
raw=False,
|
||||
)
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/runsimulationmanuallybydate/")
|
||||
async def fastapi_run_simulation_manually_by_date(
|
||||
data: RunSimulationManuallyByDate,
|
||||
) -> dict[str, str]:
|
||||
item = data.dict()
|
||||
try:
|
||||
simulation.query_corresponding_element_id_and_query_id(item["name"])
|
||||
simulation.query_corresponding_pattern_id_and_query_id(item["name"])
|
||||
region_result = simulation.query_non_realtime_region(item["name"])
|
||||
globals.source_outflow_region_id = simulation.get_source_outflow_region_id(
|
||||
item["name"], region_result
|
||||
)
|
||||
globals.realtime_region_pipe_flow_and_demand_id = (
|
||||
simulation.query_realtime_region_pipe_flow_and_demand_id(
|
||||
item["name"], region_result
|
||||
)
|
||||
)
|
||||
globals.pipe_flow_region_patterns = simulation.query_pipe_flow_region_patterns(
|
||||
item["name"]
|
||||
)
|
||||
globals.non_realtime_region_patterns = (
|
||||
simulation.query_non_realtime_region_patterns(item["name"], region_result)
|
||||
)
|
||||
(
|
||||
globals.source_outflow_region_patterns,
|
||||
globals.realtime_region_pipe_flow_and_demand_patterns,
|
||||
) = simulation.get_realtime_region_patterns(
|
||||
item["name"],
|
||||
globals.source_outflow_region_id,
|
||||
globals.realtime_region_pipe_flow_and_demand_id,
|
||||
)
|
||||
base_date = datetime.strptime(item["simulation_date"], "%Y-%m-%d")
|
||||
run_simulation_manually_by_date(
|
||||
item["name"], base_date, item["start_time"], item["duration"]
|
||||
)
|
||||
return {"status": "success"}
|
||||
except Exception as exc:
|
||||
return {"status": "error", "message": str(exc)}
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from app.native.api import ChangeSet
|
||||
from app.services.tjnetwork import (
|
||||
get_current_operation,
|
||||
execute_undo,
|
||||
execute_redo,
|
||||
list_snapshot,
|
||||
have_snapshot,
|
||||
have_snapshot_for_operation,
|
||||
have_snapshot_for_current_operation,
|
||||
take_snapshot_for_operation,
|
||||
take_snapshot_for_current_operation,
|
||||
take_snapshot,
|
||||
pick_snapshot,
|
||||
pick_operation,
|
||||
sync_with_server,
|
||||
execute_batch_commands,
|
||||
execute_batch_command,
|
||||
get_restore_operation,
|
||||
set_restore_operation,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/getcurrentoperationid/")
|
||||
async def get_current_operation_id_endpoint(network: str) -> int:
|
||||
return get_current_operation(network)
|
||||
|
||||
@router.post("/undo/")
|
||||
async def undo_endpoint(network: str):
|
||||
return execute_undo(network)
|
||||
|
||||
@router.post("/redo/")
|
||||
async def redo_endpoint(network: str):
|
||||
return execute_redo(network)
|
||||
|
||||
@router.get("/getsnapshots/")
|
||||
async def list_snapshot_endpoint(network: str) -> list[tuple[int, str]]:
|
||||
return list_snapshot(network)
|
||||
|
||||
@router.get("/havesnapshot/")
|
||||
async def have_snapshot_endpoint(network: str, tag: str) -> bool:
|
||||
return have_snapshot(network, tag)
|
||||
|
||||
@router.get("/havesnapshotforoperation/")
|
||||
async def have_snapshot_for_operation_endpoint(network: str, operation: int) -> bool:
|
||||
return have_snapshot_for_operation(network, operation)
|
||||
|
||||
@router.get("/havesnapshotforcurrentoperation/")
|
||||
async def have_snapshot_for_current_operation_endpoint(network: str) -> bool:
|
||||
return have_snapshot_for_current_operation(network)
|
||||
|
||||
@router.post("/takesnapshotforoperation/")
|
||||
async def take_snapshot_for_operation_endpoint(
|
||||
network: str, operation: int, tag: str
|
||||
) -> None:
|
||||
return take_snapshot_for_operation(network, operation, tag)
|
||||
|
||||
@router.post("/takesnapshotforcurrentoperation")
|
||||
async def take_snapshot_for_current_operation_endpoint(network: str, tag: str) -> None:
|
||||
return take_snapshot_for_current_operation(network, tag)
|
||||
|
||||
# 兼容旧拼写: takenapshotforcurrentoperation
|
||||
@router.post("/takenapshotforcurrentoperation")
|
||||
async def take_snapshot_for_current_operation_legacy_endpoint(
|
||||
network: str, tag: str
|
||||
) -> None:
|
||||
return take_snapshot_for_current_operation(network, tag)
|
||||
|
||||
@router.post("/takesnapshot/")
|
||||
async def take_snapshot_endpoint(network: str, tag: str) -> None:
|
||||
return take_snapshot(network, tag)
|
||||
|
||||
@router.post("/picksnapshot/", response_model=None)
|
||||
async def pick_snapshot_endpoint(network: str, tag: str, discard: bool = False) -> ChangeSet:
|
||||
return pick_snapshot(network, tag, discard)
|
||||
|
||||
@router.post("/pickoperation/", response_model=None)
|
||||
async def pick_operation_endpoint(
|
||||
network: str, operation: int, discard: bool = False
|
||||
) -> ChangeSet:
|
||||
return pick_operation(network, operation, discard)
|
||||
|
||||
@router.get("/syncwithserver/", response_model=None)
|
||||
async def sync_with_server_endpoint(network: str, operation: int) -> ChangeSet:
|
||||
return sync_with_server(network, operation)
|
||||
|
||||
@router.post("/batch/", response_model=None)
|
||||
async def execute_batch_commands_endpoint(network: str, req: Request) -> ChangeSet:
|
||||
jo_root = await req.json()
|
||||
cs: ChangeSet = ChangeSet()
|
||||
cs.operations = jo_root["operations"]
|
||||
rcs = execute_batch_commands(network, cs)
|
||||
return rcs
|
||||
|
||||
@router.post("/compressedbatch/", response_model=None)
|
||||
async def execute_compressed_batch_commands_endpoint(
|
||||
network: str, req: Request
|
||||
) -> ChangeSet:
|
||||
jo_root = await req.json()
|
||||
cs: ChangeSet = ChangeSet()
|
||||
cs.operations = jo_root["operations"]
|
||||
return execute_batch_command(network, cs)
|
||||
|
||||
@router.get("/getrestoreoperation/")
|
||||
async def get_restore_operation_endpoint(network: str) -> int:
|
||||
return get_restore_operation(network)
|
||||
|
||||
@router.post("/setrestoreoperation/")
|
||||
async def set_restore_operation_endpoint(network: str, operation: int) -> None:
|
||||
return set_restore_operation(network, operation)
|
||||
|
||||
180
app/api/v1/endpoints/user_management.py
Normal file
180
app/api/v1/endpoints/user_management.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
用户管理 API 接口
|
||||
|
||||
演示权限控制的使用
|
||||
"""
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from app.domain.schemas.user import UserResponse, UserUpdate, UserCreate
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserInDB
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.auth.dependencies import get_user_repository, get_current_active_user
|
||||
from app.auth.permissions import get_current_admin, require_role, check_resource_owner
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/", response_model=List[UserResponse])
|
||||
async def list_users(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: UserInDB = Depends(require_role(UserRole.ADMIN)),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> List[UserResponse]:
|
||||
"""
|
||||
获取用户列表(仅管理员)
|
||||
"""
|
||||
users = await user_repo.get_all_users(skip=skip, limit=limit)
|
||||
return [UserResponse.model_validate(user) for user in users]
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_active_user),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
获取用户详情
|
||||
|
||||
管理员可查看所有用户,普通用户只能查看自己
|
||||
"""
|
||||
# 检查权限
|
||||
if not check_resource_owner(user_id, current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to view this user"
|
||||
)
|
||||
|
||||
user = await user_repo.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: int,
|
||||
user_update: UserUpdate,
|
||||
current_user: UserInDB = Depends(get_current_active_user),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
管理员可更新所有用户,普通用户只能更新自己(且不能修改角色)
|
||||
"""
|
||||
# 检查用户是否存在
|
||||
target_user = await user_repo.get_user_by_id(user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
# 权限检查
|
||||
is_owner = current_user.id == user_id
|
||||
is_admin = UserRole(current_user.role).has_permission(UserRole.ADMIN)
|
||||
|
||||
if not is_owner and not is_admin:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to update this user"
|
||||
)
|
||||
|
||||
# 非管理员不能修改角色和激活状态
|
||||
if not is_admin:
|
||||
if user_update.role is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can change user roles"
|
||||
)
|
||||
if user_update.is_active is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only admins can change user active status"
|
||||
)
|
||||
|
||||
# 更新用户
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user"
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> dict:
|
||||
"""
|
||||
删除用户(仅管理员)
|
||||
"""
|
||||
# 不能删除自己
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You cannot delete your own account"
|
||||
)
|
||||
|
||||
success = await user_repo.delete_user(user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return {"message": "User deleted successfully"}
|
||||
|
||||
@router.post("/{user_id}/activate")
|
||||
async def activate_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
激活用户(仅管理员)
|
||||
"""
|
||||
user_update = UserUpdate(is_active=True)
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
|
||||
@router.post("/{user_id}/deactivate")
|
||||
async def deactivate_user(
|
||||
user_id: int,
|
||||
current_user: UserInDB = Depends(get_current_admin),
|
||||
user_repo: UserRepository = Depends(get_user_repository)
|
||||
) -> UserResponse:
|
||||
"""
|
||||
停用用户(仅管理员)
|
||||
"""
|
||||
# 不能停用自己
|
||||
if current_user.id == user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="You cannot deactivate your own account"
|
||||
)
|
||||
|
||||
user_update = UserUpdate(is_active=False)
|
||||
updated_user = await user_repo.update_user(user_id, user_update)
|
||||
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
return UserResponse.model_validate(updated_user)
|
||||
21
app/api/v1/endpoints/users.py
Normal file
21
app/api/v1/endpoints/users.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from typing import Any, List, Dict, Union
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
###########################################################
|
||||
# user 39
|
||||
###########################################################
|
||||
|
||||
@router.get("/getuserschema/")
|
||||
async def fastapi_get_user_schema(network: str) -> dict[str, dict[Any, Any]]:
|
||||
return get_user_schema(network)
|
||||
|
||||
@router.get("/getuser/")
|
||||
async def fastapi_get_user(network: str, user_name: str) -> dict[Any, Any]:
|
||||
return get_user(network, user_name)
|
||||
|
||||
@router.get("/getallusers/")
|
||||
async def fastapi_get_all_users(network: str) -> list[dict[Any, Any]]:
|
||||
return get_all_users(network)
|
||||
@@ -2,19 +2,91 @@ from fastapi import APIRouter
|
||||
from app.api.v1.endpoints import (
|
||||
auth,
|
||||
project,
|
||||
network_elements,
|
||||
simulation,
|
||||
scada,
|
||||
extension,
|
||||
snapshots
|
||||
snapshots,
|
||||
data_query,
|
||||
users,
|
||||
schemes,
|
||||
misc,
|
||||
risk,
|
||||
cache,
|
||||
user_management, # 新增:用户管理
|
||||
audit, # 新增:审计日志
|
||||
meta,
|
||||
)
|
||||
from app.api.v1.endpoints.network import (
|
||||
general,
|
||||
junctions,
|
||||
reservoirs,
|
||||
tanks,
|
||||
pipes,
|
||||
pumps,
|
||||
valves,
|
||||
tags,
|
||||
demands,
|
||||
geometry,
|
||||
regions,
|
||||
)
|
||||
from app.api.v1.endpoints.components import (
|
||||
curves,
|
||||
patterns,
|
||||
controls,
|
||||
options,
|
||||
quality,
|
||||
visuals,
|
||||
)
|
||||
|
||||
from app.infra.db.postgresql import router as postgresql_router
|
||||
from app.infra.db.timescaledb import router as timescaledb_router
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(project.router, prefix="/projects", tags=["projects"])
|
||||
api_router.include_router(network_elements.router, prefix="/elements", tags=["network-elements"])
|
||||
api_router.include_router(simulation.router, prefix="/simulation", tags=["simulation"])
|
||||
api_router.include_router(scada.router, prefix="/scada", tags=["scada"])
|
||||
api_router.include_router(extension.router, prefix="/extension", tags=["extension"])
|
||||
api_router.include_router(snapshots.router, prefix="/snapshots", tags=["snapshots"])
|
||||
# Core Services
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
api_router.include_router(user_management.router, prefix="/users", tags=["User Management"]) # 新增
|
||||
api_router.include_router(audit.router, prefix="/audit", tags=["Audit Logs"]) # 新增
|
||||
api_router.include_router(meta.router, tags=["Metadata"])
|
||||
api_router.include_router(project.router, tags=["Project"])
|
||||
|
||||
# Network Elements (Node/Link Types)
|
||||
api_router.include_router(general.router, tags=["Network General"])
|
||||
api_router.include_router(junctions.router, tags=["Junctions"])
|
||||
api_router.include_router(reservoirs.router, tags=["Reservoirs"])
|
||||
api_router.include_router(tanks.router, tags=["Tanks"])
|
||||
api_router.include_router(pipes.router, tags=["Pipes"])
|
||||
api_router.include_router(pumps.router, tags=["Pumps"])
|
||||
api_router.include_router(valves.router, tags=["Valves"])
|
||||
|
||||
# Network Features
|
||||
api_router.include_router(tags.router, tags=["Tags"])
|
||||
api_router.include_router(demands.router, tags=["Demands"])
|
||||
api_router.include_router(geometry.router, tags=["Geometry & Coordinates"])
|
||||
api_router.include_router(regions.router, tags=["Regions & DMAs"])
|
||||
|
||||
# Components & Controls
|
||||
api_router.include_router(curves.router, tags=["Curves"])
|
||||
api_router.include_router(patterns.router, tags=["Patterns"])
|
||||
api_router.include_router(controls.router, tags=["Controls & Rules"])
|
||||
api_router.include_router(options.router, tags=["Options"])
|
||||
api_router.include_router(quality.router, tags=["Quality"])
|
||||
api_router.include_router(visuals.router, tags=["Visuals"])
|
||||
|
||||
# Simulation & Data
|
||||
api_router.include_router(simulation.router, tags=["Simulation Control"])
|
||||
api_router.include_router(data_query.router, tags=["Data Query & InfluxDB"])
|
||||
api_router.include_router(scada.router, tags=["SCADA"])
|
||||
api_router.include_router(snapshots.router, tags=["Snapshots"])
|
||||
api_router.include_router(users.router, tags=["Users"])
|
||||
api_router.include_router(schemes.router, tags=["Schemes"])
|
||||
api_router.include_router(misc.router, tags=["Misc"])
|
||||
api_router.include_router(risk.router, tags=["Risk"])
|
||||
api_router.include_router(cache.router, tags=["Cache"])
|
||||
|
||||
# Database Routers
|
||||
api_router.include_router(timescaledb_router, tags=["TimescaleDB"])
|
||||
api_router.include_router(postgresql_router, tags=["PostgreSQL"])
|
||||
|
||||
# Extension
|
||||
api_router.include_router(extension.router, tags=["Extension"])
|
||||
|
||||
@@ -1,21 +1,100 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from typing import Annotated, Optional
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from app.core.config import settings
|
||||
from jose import jwt, JWTError
|
||||
from app.core.config import settings
|
||||
from app.domain.schemas.user import UserInDB, TokenPayload
|
||||
from app.infra.repositories.user_repository import UserRepository
|
||||
from app.infra.db.postgresql.database import Database
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
|
||||
# 数据库依赖
|
||||
async def get_db(request: Request) -> Database:
|
||||
"""
|
||||
获取数据库实例
|
||||
|
||||
从 FastAPI app.state 中获取在启动时初始化的数据库连接
|
||||
"""
|
||||
if not hasattr(request.app.state, "db"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database not initialized",
|
||||
)
|
||||
return request.app.state.db
|
||||
|
||||
|
||||
async def get_user_repository(db: Database = Depends(get_db)) -> UserRepository:
|
||||
"""获取用户仓储实例"""
|
||||
return UserRepository(db)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
user_repo: UserRepository = Depends(get_user_repository),
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前登录用户
|
||||
|
||||
从 JWT Token 中解析用户信息,并从数据库验证
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("type", "access")
|
||||
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
|
||||
if token_type != "access":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type. Access token required.",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
return username
|
||||
|
||||
# 从数据库获取用户
|
||||
user = await user_repo.get_user_by_username(username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前活跃用户(必须是激活状态)
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_superuser(
|
||||
current_user: UserInDB = Depends(get_current_user),
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前超级管理员用户
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough privileges. Superuser access required.",
|
||||
)
|
||||
return current_user
|
||||
|
||||
63
app/auth/keycloak_dependencies.py
Normal file
63
app/auth/keycloak_dependencies.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
oauth2_optional = OAuth2PasswordBearer(
|
||||
tokenUrl=f"{settings.API_V1_STR}/auth/login", auto_error=False
|
||||
)
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_keycloak_sub(
|
||||
token: str | None = Depends(oauth2_optional),
|
||||
) -> UUID:
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if settings.KEYCLOAK_PUBLIC_KEY:
|
||||
key = settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
algorithms = [settings.KEYCLOAK_ALGORITHM]
|
||||
else:
|
||||
key = settings.SECRET_KEY
|
||||
algorithms = [settings.ALGORITHM]
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=algorithms,
|
||||
audience=settings.KEYCLOAK_AUDIENCE or None,
|
||||
)
|
||||
except JWTError as exc:
|
||||
# logger.warning("Keycloak token validation failed: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
return UUID(sub)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid subject claim",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
60
app/auth/metadata_dependencies.py
Normal file
60
app/auth/metadata_dependencies.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_current_metadata_user(
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
):
|
||||
try:
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving current user",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_metadata_admin(
|
||||
user=Depends(get_current_metadata_user),
|
||||
):
|
||||
if user.is_superuser or user.role == "admin":
|
||||
return user
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _AuthBypassUser:
|
||||
id: UUID = UUID(int=0)
|
||||
role: str = "admin"
|
||||
is_superuser: bool = True
|
||||
is_active: bool = True
|
||||
106
app/auth/permissions.py
Normal file
106
app/auth/permissions.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
权限控制依赖项和装饰器
|
||||
|
||||
基于角色的访问控制(RBAC)
|
||||
"""
|
||||
from typing import Callable
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserInDB
|
||||
from app.auth.dependencies import get_current_active_user
|
||||
|
||||
def require_role(required_role: UserRole):
|
||||
"""
|
||||
要求特定角色或更高权限
|
||||
|
||||
用法:
|
||||
@router.get("/admin-only")
|
||||
async def admin_endpoint(user: UserInDB = Depends(require_role(UserRole.ADMIN))):
|
||||
...
|
||||
|
||||
Args:
|
||||
required_role: 需要的最低角色
|
||||
|
||||
Returns:
|
||||
依赖函数
|
||||
"""
|
||||
async def role_checker(
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
) -> UserInDB:
|
||||
user_role = UserRole(current_user.role)
|
||||
|
||||
if not user_role.has_permission(required_role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Insufficient permissions. Required role: {required_role.value}, "
|
||||
f"Your role: {user_role.value}"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return role_checker
|
||||
|
||||
# 预定义的权限检查依赖
|
||||
require_admin = require_role(UserRole.ADMIN)
|
||||
require_operator = require_role(UserRole.OPERATOR)
|
||||
require_user = require_role(UserRole.USER)
|
||||
|
||||
def get_current_admin(
|
||||
current_user: UserInDB = Depends(require_admin)
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前管理员用户
|
||||
|
||||
等同于 Depends(require_role(UserRole.ADMIN))
|
||||
"""
|
||||
return current_user
|
||||
|
||||
def get_current_operator(
|
||||
current_user: UserInDB = Depends(require_operator)
|
||||
) -> UserInDB:
|
||||
"""
|
||||
获取当前操作员用户(或更高权限)
|
||||
|
||||
等同于 Depends(require_role(UserRole.OPERATOR))
|
||||
"""
|
||||
return current_user
|
||||
|
||||
def check_resource_owner(user_id: int, current_user: UserInDB) -> bool:
|
||||
"""
|
||||
检查是否是资源拥有者或管理员
|
||||
|
||||
Args:
|
||||
user_id: 资源拥有者ID
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
是否有权限
|
||||
"""
|
||||
# 管理员可以访问所有资源
|
||||
if UserRole(current_user.role).has_permission(UserRole.ADMIN):
|
||||
return True
|
||||
|
||||
# 检查是否是资源拥有者
|
||||
return current_user.id == user_id
|
||||
|
||||
def require_owner_or_admin(user_id: int):
|
||||
"""
|
||||
要求是资源拥有者或管理员
|
||||
|
||||
Args:
|
||||
user_id: 资源拥有者ID
|
||||
|
||||
Returns:
|
||||
依赖函数
|
||||
"""
|
||||
async def owner_or_admin_checker(
|
||||
current_user: UserInDB = Depends(get_current_active_user)
|
||||
) -> UserInDB:
|
||||
if not check_resource_owner(user_id, current_user):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to access this resource"
|
||||
)
|
||||
return current_user
|
||||
|
||||
return owner_or_admin_checker
|
||||
213
app/auth/project_dependencies.py
Normal file
213
app/auth/project_dependencies.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
import logging
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from psycopg import AsyncConnection
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.keycloak_dependencies import get_current_keycloak_sub
|
||||
from app.core.config import settings
|
||||
from app.infra.db.dynamic_manager import project_connection_manager
|
||||
from app.infra.db.metadata.database import get_metadata_session
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
DB_ROLE_BIZ_DATA = "biz_data"
|
||||
DB_ROLE_IOT_DATA = "iot_data"
|
||||
DB_TYPE_POSTGRES = "postgresql"
|
||||
DB_TYPE_TIMESCALE = "timescaledb"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectContext:
|
||||
project_id: UUID
|
||||
user_id: UUID
|
||||
project_role: str
|
||||
|
||||
|
||||
async def get_metadata_repository(
|
||||
session: AsyncSession = Depends(get_metadata_session),
|
||||
) -> MetadataRepository:
|
||||
return MetadataRepository(session)
|
||||
|
||||
|
||||
async def get_project_context(
|
||||
x_project_id: str = Header(..., alias="X-Project-Id"),
|
||||
keycloak_sub: UUID = Depends(get_current_keycloak_sub),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> ProjectContext:
|
||||
try:
|
||||
project_uuid = UUID(x_project_id)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid project id"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
project = await metadata_repo.get_project_by_id(project_uuid)
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
|
||||
)
|
||||
if project.status != "active":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Project is not active"
|
||||
)
|
||||
|
||||
user = await metadata_repo.get_user_by_keycloak_id(keycloak_sub)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User not registered"
|
||||
)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
membership_role = await metadata_repo.get_membership_role(project_uuid, user.id)
|
||||
if not membership_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="No access to project"
|
||||
)
|
||||
except SQLAlchemyError as exc:
|
||||
logger.error(
|
||||
"Metadata DB error while resolving project context",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Metadata database error: {exc}",
|
||||
) from exc
|
||||
|
||||
return ProjectContext(
|
||||
project_id=project.id,
|
||||
user_id=user.id,
|
||||
project_role=membership_role,
|
||||
)
|
||||
|
||||
|
||||
async def get_project_pg_session(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
sessionmaker = await project_connection_manager.get_pg_sessionmaker(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with sessionmaker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_project_pg_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_BIZ_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project PostgreSQL routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project PostgreSQL routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_POSTGRES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project PostgreSQL type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_PG_POOL_SIZE
|
||||
pool = await project_connection_manager.get_pg_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_BIZ_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_project_timescale_connection(
|
||||
ctx: ProjectContext = Depends(get_project_context),
|
||||
metadata_repo: MetadataRepository = Depends(get_metadata_repository),
|
||||
) -> AsyncGenerator[AsyncConnection, None]:
|
||||
try:
|
||||
routing = await metadata_repo.get_project_db_routing(
|
||||
ctx.project_id, DB_ROLE_IOT_DATA
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(
|
||||
"Invalid project TimescaleDB routing DSN configuration",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Project TimescaleDB routing DSN is invalid: {exc}",
|
||||
) from exc
|
||||
if not routing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB not configured",
|
||||
)
|
||||
if routing.db_type != DB_TYPE_TIMESCALE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Project TimescaleDB type mismatch",
|
||||
)
|
||||
|
||||
pool_min_size = routing.pool_min_size or settings.PROJECT_TS_POOL_MIN_SIZE
|
||||
pool_max_size = routing.pool_max_size or settings.PROJECT_TS_POOL_MAX_SIZE
|
||||
pool = await project_connection_manager.get_timescale_pool(
|
||||
ctx.project_id,
|
||||
DB_ROLE_IOT_DATA,
|
||||
routing.dsn,
|
||||
pool_min_size,
|
||||
pool_max_size,
|
||||
)
|
||||
async with pool.connection() as conn:
|
||||
yield conn
|
||||
@@ -1,3 +1,146 @@
|
||||
# Placeholder for audit logic
|
||||
async def log_audit_event(event_type: str, user_id: str, details: dict):
|
||||
pass
|
||||
"""
|
||||
审计日志模块
|
||||
|
||||
记录系统关键操作,用于安全审计和合规追踪
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditAction:
|
||||
"""审计操作类型常量"""
|
||||
|
||||
# 认证相关
|
||||
LOGIN = "LOGIN"
|
||||
LOGOUT = "LOGOUT"
|
||||
REGISTER = "REGISTER"
|
||||
PASSWORD_CHANGE = "PASSWORD_CHANGE"
|
||||
|
||||
# 数据操作
|
||||
CREATE = "CREATE"
|
||||
READ = "READ"
|
||||
UPDATE = "UPDATE"
|
||||
DELETE = "DELETE"
|
||||
|
||||
# 权限相关
|
||||
PERMISSION_CHANGE = "PERMISSION_CHANGE"
|
||||
ROLE_CHANGE = "ROLE_CHANGE"
|
||||
|
||||
# 系统操作
|
||||
CONFIG_CHANGE = "CONFIG_CHANGE"
|
||||
SYSTEM_START = "SYSTEM_START"
|
||||
SYSTEM_STOP = "SYSTEM_STOP"
|
||||
|
||||
|
||||
async def log_audit_event(
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
session=None,
|
||||
):
|
||||
"""
|
||||
记录审计日志
|
||||
|
||||
Args:
|
||||
action: 操作类型
|
||||
user_id: 用户ID
|
||||
project_id: 项目ID
|
||||
resource_type: 资源类型
|
||||
resource_id: 资源ID
|
||||
ip_address: IP地址
|
||||
request_method: 请求方法
|
||||
request_path: 请求路径
|
||||
request_data: 请求数据(敏感字段需脱敏)
|
||||
response_status: 响应状态码
|
||||
session: 元数据库会话(可选)
|
||||
"""
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.audit_repository import AuditRepository
|
||||
|
||||
if request_data:
|
||||
request_data = sanitize_sensitive_data(request_data)
|
||||
|
||||
if session is None:
|
||||
async with SessionLocal() as session:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
)
|
||||
else:
|
||||
audit_repo = AuditRepository(session)
|
||||
await audit_repo.create_log(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Audit log created: action=%s, user=%s, project=%s, resource=%s:%s",
|
||||
action,
|
||||
user_id,
|
||||
project_id,
|
||||
resource_type,
|
||||
resource_id,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_sensitive_data(data: dict) -> dict:
|
||||
"""
|
||||
脱敏敏感数据
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
|
||||
Returns:
|
||||
脱敏后的数据
|
||||
"""
|
||||
sensitive_fields = [
|
||||
"password",
|
||||
"passwd",
|
||||
"pwd",
|
||||
"secret",
|
||||
"token",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"credit_card",
|
||||
"ssn",
|
||||
"social_security",
|
||||
]
|
||||
|
||||
sanitized = data.copy()
|
||||
|
||||
for key in sanitized:
|
||||
if isinstance(sanitized[key], dict):
|
||||
sanitized[key] = sanitize_sensitive_data(sanitized[key])
|
||||
elif any(sensitive in key.lower() for sensitive in sensitive_fields):
|
||||
sanitized[key] = "***REDACTED***"
|
||||
|
||||
return sanitized
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote_plus
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "TJWater Server"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
SECRET_KEY: str = "your-secret-key-here" # Change in production
|
||||
|
||||
# JWT 配置
|
||||
SECRET_KEY: str = (
|
||||
"your-secret-key-here-change-in-production-use-openssl-rand-hex-32"
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# 数据加密密钥 (使用 Fernet)
|
||||
ENCRYPTION_KEY: str = "" # 必须从环境变量设置
|
||||
DATABASE_ENCRYPTION_KEY: str = "" # project_databases.dsn_encrypted 专用
|
||||
|
||||
# Database Config (PostgreSQL)
|
||||
DB_NAME: str = "tjwater"
|
||||
DB_HOST: str = "localhost"
|
||||
@@ -14,17 +26,57 @@ class Settings(BaseSettings):
|
||||
DB_USER: str = "postgres"
|
||||
DB_PASSWORD: str = "password"
|
||||
|
||||
# Database Config (TimescaleDB)
|
||||
TIMESCALEDB_DB_NAME: str = "tjwater"
|
||||
TIMESCALEDB_DB_HOST: str = "localhost"
|
||||
TIMESCALEDB_DB_PORT: str = "5433"
|
||||
TIMESCALEDB_DB_USER: str = "postgres"
|
||||
TIMESCALEDB_DB_PASSWORD: str = "password"
|
||||
# InfluxDB
|
||||
INFLUXDB_URL: str = "http://localhost:8086"
|
||||
INFLUXDB_TOKEN: str = "token"
|
||||
INFLUXDB_ORG: str = "org"
|
||||
INFLUXDB_BUCKET: str = "bucket"
|
||||
|
||||
# Metadata Database Config (PostgreSQL)
|
||||
METADATA_DB_NAME: str = "system_hub"
|
||||
METADATA_DB_HOST: str = "localhost"
|
||||
METADATA_DB_PORT: str = "5432"
|
||||
METADATA_DB_USER: str = "postgres"
|
||||
METADATA_DB_PASSWORD: str = "password"
|
||||
|
||||
METADATA_DB_POOL_SIZE: int = 5
|
||||
METADATA_DB_MAX_OVERFLOW: int = 10
|
||||
|
||||
PROJECT_PG_CACHE_SIZE: int = 50
|
||||
PROJECT_TS_CACHE_SIZE: int = 50
|
||||
PROJECT_PG_POOL_SIZE: int = 5
|
||||
PROJECT_PG_MAX_OVERFLOW: int = 10
|
||||
PROJECT_TS_POOL_MIN_SIZE: int = 1
|
||||
PROJECT_TS_POOL_MAX_SIZE: int = 10
|
||||
|
||||
# Keycloak JWT (optional override)
|
||||
KEYCLOAK_PUBLIC_KEY: str = ""
|
||||
KEYCLOAK_ALGORITHM: str = "RS256"
|
||||
KEYCLOAK_AUDIENCE: str = ""
|
||||
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
db_password = quote_plus(self.DB_PASSWORD)
|
||||
return f"postgresql://{self.DB_USER}:{db_password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
@property
|
||||
def METADATA_DATABASE_URI(self) -> str:
|
||||
metadata_password = quote_plus(self.METADATA_DB_PASSWORD)
|
||||
return (
|
||||
f"postgresql+psycopg://{self.METADATA_DB_USER}:{metadata_password}"
|
||||
f"@{self.METADATA_DB_HOST}:{self.METADATA_DB_PORT}/{self.METADATA_DB_NAME}"
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=Path(__file__).resolve().parents[2] / ".env",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,9 +1,126 @@
|
||||
# Placeholder for encryption logic
|
||||
from cryptography.fernet import Fernet
|
||||
from typing import Optional
|
||||
import base64
|
||||
import os
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class Encryptor:
|
||||
"""
|
||||
使用 Fernet (对称加密) 实现数据加密/解密
|
||||
适用于加密敏感配置、用户数据等
|
||||
"""
|
||||
|
||||
def __init__(self, key: Optional[bytes] = None):
|
||||
"""
|
||||
初始化加密器
|
||||
|
||||
Args:
|
||||
key: 加密密钥,如果为 None 则从环境变量读取
|
||||
"""
|
||||
if key is None:
|
||||
key_str = os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
key = key_str.encode()
|
||||
|
||||
self.fernet = Fernet(key)
|
||||
|
||||
def encrypt(self, data: str) -> str:
|
||||
return data # Implement actual encryption
|
||||
"""
|
||||
加密字符串
|
||||
|
||||
Args:
|
||||
data: 待加密的明文字符串
|
||||
|
||||
Returns:
|
||||
Base64 编码的加密字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
encrypted_bytes = self.fernet.encrypt(data.encode())
|
||||
return encrypted_bytes.decode()
|
||||
|
||||
def decrypt(self, data: str) -> str:
|
||||
return data # Implement actual decryption
|
||||
"""
|
||||
解密字符串
|
||||
|
||||
encryptor = Encryptor()
|
||||
Args:
|
||||
data: Base64 编码的加密字符串
|
||||
|
||||
Returns:
|
||||
解密后的明文字符串
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
decrypted_bytes = self.fernet.decrypt(data.encode())
|
||||
return decrypted_bytes.decode()
|
||||
|
||||
@staticmethod
|
||||
def generate_key() -> str:
|
||||
"""
|
||||
生成新的 Fernet 加密密钥
|
||||
|
||||
Returns:
|
||||
Base64 编码的密钥字符串
|
||||
"""
|
||||
key = Fernet.generate_key()
|
||||
return key.decode()
|
||||
|
||||
|
||||
# 全局加密器实例(懒加载)
|
||||
_encryptor: Optional[Encryptor] = None
|
||||
_database_encryptor: Optional[Encryptor] = None
|
||||
|
||||
|
||||
def is_encryption_configured() -> bool:
|
||||
return bool(os.getenv("ENCRYPTION_KEY") or settings.ENCRYPTION_KEY)
|
||||
|
||||
|
||||
def is_database_encryption_configured() -> bool:
|
||||
return bool(
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
|
||||
|
||||
def get_encryptor() -> Encryptor:
|
||||
"""获取全局加密器实例"""
|
||||
global _encryptor
|
||||
if _encryptor is None:
|
||||
_encryptor = Encryptor()
|
||||
return _encryptor
|
||||
|
||||
|
||||
def get_database_encryptor() -> Encryptor:
|
||||
"""获取 project DB DSN 专用加密器实例"""
|
||||
global _database_encryptor
|
||||
if _database_encryptor is None:
|
||||
key_str = (
|
||||
os.getenv("DATABASE_ENCRYPTION_KEY")
|
||||
or settings.DATABASE_ENCRYPTION_KEY
|
||||
or os.getenv("ENCRYPTION_KEY")
|
||||
or settings.ENCRYPTION_KEY
|
||||
)
|
||||
if not key_str:
|
||||
raise ValueError(
|
||||
"DATABASE_ENCRYPTION_KEY not found in environment variables or .env. "
|
||||
"Generate one using: Encryptor.generate_key()"
|
||||
)
|
||||
_database_encryptor = Encryptor(key=key_str.encode())
|
||||
return _database_encryptor
|
||||
|
||||
|
||||
# 向后兼容(延迟加载)
|
||||
def __getattr__(name):
|
||||
if name == "encryptor":
|
||||
return get_encryptor()
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
@@ -1,23 +1,91 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建 JWT Access Token
|
||||
|
||||
Args:
|
||||
subject: 用户标识(通常是用户名或用户ID)
|
||||
expires_delta: 过期时间增量
|
||||
|
||||
Returns:
|
||||
JWT token 字符串
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
expire = datetime.now() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
expire = datetime.now() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
to_encode = {
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "access",
|
||||
"iat": datetime.now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(subject: Union[str, Any]) -> str:
|
||||
"""
|
||||
创建 JWT Refresh Token(长期有效)
|
||||
|
||||
Args:
|
||||
subject: 用户标识
|
||||
|
||||
Returns:
|
||||
JWT refresh token 字符串
|
||||
"""
|
||||
expire = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"exp": expire,
|
||||
"sub": str(subject),
|
||||
"type": "refresh",
|
||||
"iat": datetime.now(),
|
||||
}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
验证密码
|
||||
|
||||
Args:
|
||||
plain_password: 明文密码
|
||||
hashed_password: 密码哈希
|
||||
|
||||
Returns:
|
||||
是否匹配
|
||||
"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""
|
||||
生成密码哈希
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
|
||||
Returns:
|
||||
bcrypt 哈希字符串
|
||||
"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
36
app/domain/models/role.py
Normal file
36
app/domain/models/role.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from enum import Enum
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""用户角色枚举"""
|
||||
ADMIN = "ADMIN" # 管理员 - 完全权限
|
||||
OPERATOR = "OPERATOR" # 操作员 - 可修改数据
|
||||
USER = "USER" # 普通用户 - 读写权限
|
||||
VIEWER = "VIEWER" # 观察者 - 仅查询权限
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def get_hierarchy(cls) -> dict:
|
||||
"""
|
||||
获取角色层级(数字越大权限越高)
|
||||
"""
|
||||
return {
|
||||
cls.VIEWER: 1,
|
||||
cls.USER: 2,
|
||||
cls.OPERATOR: 3,
|
||||
cls.ADMIN: 4,
|
||||
}
|
||||
|
||||
def has_permission(self, required_role: 'UserRole') -> bool:
|
||||
"""
|
||||
检查当前角色是否有足够权限
|
||||
|
||||
Args:
|
||||
required_role: 需要的最低角色
|
||||
|
||||
Returns:
|
||||
True if has permission
|
||||
"""
|
||||
hierarchy = self.get_hierarchy()
|
||||
return hierarchy[self] >= hierarchy[required_role]
|
||||
45
app/domain/schemas/audit.py
Normal file
45
app/domain/schemas/audit.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class AuditLogCreate(BaseModel):
|
||||
"""创建审计日志"""
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: str
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
request_method: Optional[str] = None
|
||||
request_path: Optional[str] = None
|
||||
request_data: Optional[dict] = None
|
||||
response_status: Optional[int] = None
|
||||
|
||||
class AuditLogResponse(BaseModel):
|
||||
"""审计日志响应"""
|
||||
id: UUID
|
||||
user_id: Optional[UUID]
|
||||
project_id: Optional[UUID]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
request_method: Optional[str]
|
||||
request_path: Optional[str]
|
||||
request_data: Optional[dict]
|
||||
response_status: Optional[int]
|
||||
timestamp: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class AuditLogQuery(BaseModel):
|
||||
"""审计日志查询参数"""
|
||||
user_id: Optional[UUID] = None
|
||||
project_id: Optional[UUID] = None
|
||||
action: Optional[str] = None
|
||||
resource_type: Optional[str] = None
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
skip: int = Field(default=0, ge=0)
|
||||
limit: int = Field(default=100, ge=1, le=1000)
|
||||
33
app/domain/schemas/metadata.py
Normal file
33
app/domain/schemas/metadata.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeoServerConfigResponse(BaseModel):
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
class ProjectMetaResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
geoserver: Optional[GeoServerConfigResponse]
|
||||
|
||||
|
||||
class ProjectSummaryResponse(BaseModel):
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
68
app/domain/schemas/user.py
Normal file
68
app/domain/schemas/user.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, EmailStr, Field, ConfigDict
|
||||
from app.domain.models.role import UserRole
|
||||
|
||||
# ============================================
|
||||
# Request Schemas (输入)
|
||||
# ============================================
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
"""用户注册"""
|
||||
username: str = Field(..., min_length=3, max_length=50,
|
||||
description="用户名,3-50个字符")
|
||||
email: EmailStr = Field(..., description="邮箱地址")
|
||||
password: str = Field(..., min_length=6, max_length=100,
|
||||
description="密码,至少6个字符")
|
||||
role: UserRole = Field(default=UserRole.USER, description="用户角色")
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
"""用户登录"""
|
||||
username: str = Field(..., description="用户名或邮箱")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
"""用户信息更新"""
|
||||
email: Optional[EmailStr] = None
|
||||
password: Optional[str] = Field(None, min_length=6, max_length=100)
|
||||
role: Optional[UserRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
# ============================================
|
||||
# Response Schemas (输出)
|
||||
# ============================================
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""用户信息响应(不含密码)"""
|
||||
id: int
|
||||
username: str
|
||||
email: str
|
||||
role: UserRole
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
class UserInDB(UserResponse):
|
||||
"""数据库中的用户(含密码哈希)"""
|
||||
hashed_password: str
|
||||
|
||||
# ============================================
|
||||
# Token Schemas
|
||||
# ============================================
|
||||
|
||||
class Token(BaseModel):
|
||||
"""JWT Token 响应"""
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(..., description="过期时间(秒)")
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT Token Payload"""
|
||||
sub: str = Field(..., description="用户ID或用户名")
|
||||
exp: Optional[int] = None
|
||||
iat: Optional[int] = None
|
||||
type: str = Field(default="access", description="token类型: access 或 refresh")
|
||||
224
app/infra/audit/middleware.py
Normal file
224
app/infra/audit/middleware.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
审计日志中间件
|
||||
|
||||
自动记录关键HTTP请求到审计日志
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from uuid import UUID
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.core.audit import log_audit_event, AuditAction
|
||||
import logging
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.infra.db.metadata.database import SessionLocal
|
||||
from app.infra.repositories.metadata_repository import MetadataRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
审计中间件
|
||||
|
||||
自动记录以下操作:
|
||||
- 所有 POST/PUT/DELETE 请求
|
||||
- 登录/登出
|
||||
- 关键资源访问
|
||||
"""
|
||||
|
||||
# 需要审计的路径前缀
|
||||
AUDIT_PATHS = [
|
||||
# "/api/v1/auth/",
|
||||
# "/api/v1/users/",
|
||||
# "/api/v1/projects/",
|
||||
# "/api/v1/networks/",
|
||||
]
|
||||
|
||||
# [新增] 需要审计的 API Tags (在 Router 或 api 函数中定义 tags=["Audit"])
|
||||
AUDIT_TAGS = [
|
||||
"Audit",
|
||||
"Users",
|
||||
"Project",
|
||||
"Network General",
|
||||
"Junctions",
|
||||
"Pipes",
|
||||
"Reservoirs",
|
||||
"Tanks",
|
||||
"Pumps",
|
||||
"Valves",
|
||||
]
|
||||
|
||||
# 需要审计的HTTP方法
|
||||
AUDIT_METHODS = ["POST", "PUT", "DELETE", "PATCH"]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 提取开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 预判是否需要读取Body (针对写操作)
|
||||
# 注意:我们暂时移除早期的 return,因为需要等待路由匹配后才能检查 Tag
|
||||
should_capture_body = request.method in ["POST", "PUT", "PATCH"]
|
||||
|
||||
request_data = None
|
||||
if should_capture_body:
|
||||
try:
|
||||
# 注意:读取 body 后需要重新设置,避免影响后续处理
|
||||
body = await request.body()
|
||||
if body:
|
||||
request_data = json.loads(body.decode())
|
||||
|
||||
# 重新构造请求以供后续使用
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": body}
|
||||
|
||||
request._receive = receive
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read request body for audit: {e}")
|
||||
|
||||
# 2. 执行请求 (FastAPI在此过程中进行路由匹配)
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 决定是否审计
|
||||
# 检查方法
|
||||
is_audit_method = request.method in self.AUDIT_METHODS
|
||||
# 检查路径
|
||||
is_audit_path = any(
|
||||
request.url.path.startswith(path) for path in self.AUDIT_PATHS
|
||||
)
|
||||
# [新增] 检查 Tags (从 request.scope 中获取匹配的路由信息)
|
||||
is_audit_tag = False
|
||||
route = request.scope.get("route")
|
||||
if route and hasattr(route, "tags"):
|
||||
is_audit_tag = any(tag in self.AUDIT_TAGS for tag in route.tags)
|
||||
|
||||
should_audit = is_audit_method or is_audit_path or is_audit_tag
|
||||
|
||||
if not should_audit:
|
||||
# 即便不审计,也要处理响应头中的时间(保持原有逻辑一致性)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
return response
|
||||
|
||||
# 4. 提取审计所需信息
|
||||
user_id = await self._resolve_user_id(request)
|
||||
project_id = self._resolve_project_id(request)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
|
||||
# 确定操作类型
|
||||
action = self._determine_action(request)
|
||||
resource_type, resource_id = self._extract_resource_info(request)
|
||||
|
||||
# 记录审计日志
|
||||
try:
|
||||
await log_audit_event(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request.method,
|
||||
request_path=str(request.url.path),
|
||||
request_data=request_data,
|
||||
response_status=response.status_code,
|
||||
)
|
||||
except Exception as e:
|
||||
# 审计失败不应影响响应
|
||||
logger.error(f"Failed to log audit event: {e}", exc_info=True)
|
||||
|
||||
# 添加处理时间到响应头
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
def _resolve_project_id(self, request: Request) -> UUID | None:
|
||||
project_header = request.headers.get("X-Project-Id")
|
||||
if not project_header:
|
||||
return None
|
||||
try:
|
||||
return UUID(project_header)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def _resolve_user_id(self, request: Request) -> UUID | None:
|
||||
auth_header = request.headers.get("authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
return None
|
||||
try:
|
||||
key = (
|
||||
settings.KEYCLOAK_PUBLIC_KEY.replace("\\n", "\n")
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else settings.SECRET_KEY
|
||||
)
|
||||
algorithms = (
|
||||
[settings.KEYCLOAK_ALGORITHM]
|
||||
if settings.KEYCLOAK_PUBLIC_KEY
|
||||
else [settings.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, key, algorithms=algorithms)
|
||||
sub = payload.get("sub")
|
||||
if not sub:
|
||||
return None
|
||||
keycloak_id = UUID(sub)
|
||||
except (JWTError, ValueError):
|
||||
return None
|
||||
|
||||
async with SessionLocal() as session:
|
||||
repo = MetadataRepository(session)
|
||||
user = await repo.get_user_by_keycloak_id(keycloak_id)
|
||||
if user and user.is_active:
|
||||
return user.id
|
||||
return None
|
||||
|
||||
def _determine_action(self, request: Request) -> str:
|
||||
"""根据请求路径和方法确定操作类型"""
|
||||
path = request.url.path.lower()
|
||||
method = request.method
|
||||
|
||||
# 认证相关
|
||||
if "login" in path:
|
||||
return AuditAction.LOGIN
|
||||
elif "logout" in path:
|
||||
return AuditAction.LOGOUT
|
||||
elif "register" in path:
|
||||
return AuditAction.REGISTER
|
||||
|
||||
# CRUD 操作
|
||||
if method == "POST":
|
||||
return AuditAction.CREATE
|
||||
elif method == "PUT" or method == "PATCH":
|
||||
return AuditAction.UPDATE
|
||||
elif method == "DELETE":
|
||||
return AuditAction.DELETE
|
||||
elif method == "GET":
|
||||
return AuditAction.READ
|
||||
|
||||
return f"{method}_REQUEST"
|
||||
|
||||
def _extract_resource_info(self, request: Request) -> tuple:
|
||||
"""从请求路径提取资源类型和ID"""
|
||||
path_parts = request.url.path.strip("/").split("/")
|
||||
|
||||
resource_type = None
|
||||
resource_id = None
|
||||
|
||||
# 尝试从路径中提取资源信息
|
||||
# 例如: /api/v1/users/123 -> resource_type=user, resource_id=123
|
||||
if len(path_parts) >= 4:
|
||||
resource_type = path_parts[3].rstrip("s") # 移除复数s
|
||||
|
||||
if len(path_parts) >= 5 and path_parts[4].isdigit():
|
||||
resource_id = path_parts[4]
|
||||
|
||||
return resource_type, resource_id
|
||||
19
app/infra/cache/redis_client.py
vendored
Normal file
19
app/infra/cache/redis_client.py
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
import redis
|
||||
import msgpack
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
# Initialize Redis connection
|
||||
redis_client = redis.Redis(host="127.0.0.1", port=6379, db=0)
|
||||
|
||||
def encode_datetime(obj: Any) -> Any:
|
||||
"""Serialize datetime objects to dictionary format."""
|
||||
if isinstance(obj, datetime):
|
||||
return {"__datetime__": True, "as_str": obj.strftime("%Y%m%dT%H:%M:%S.%f")}
|
||||
return obj
|
||||
|
||||
def decode_datetime(obj: Any) -> Any:
|
||||
"""Deserialize dictionary format to datetime objects."""
|
||||
if "__datetime__" in obj:
|
||||
return datetime.strptime(obj["as_str"], "%Y%m%dT%H:%M:%S.%f")
|
||||
return obj
|
||||
211
app/infra/db/dynamic_manager.py
Normal file
211
app/infra/db/dynamic_manager.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.rows import dict_row
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PgEngineEntry:
|
||||
engine: AsyncEngine
|
||||
sessionmaker: async_sessionmaker[AsyncSession]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheKey:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
|
||||
|
||||
class ProjectConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self._pg_cache: Dict[CacheKey, PgEngineEntry] = OrderedDict()
|
||||
self._ts_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_raw_cache: Dict[CacheKey, AsyncConnectionPool] = OrderedDict()
|
||||
self._pg_lock = asyncio.Lock()
|
||||
self._ts_lock = asyncio.Lock()
|
||||
self._pg_raw_lock = asyncio.Lock()
|
||||
|
||||
def _normalize_pg_url(self, url: str) -> str:
|
||||
parsed = make_url(url)
|
||||
if parsed.drivername == "postgresql":
|
||||
parsed = parsed.set(drivername="postgresql+psycopg")
|
||||
return str(parsed)
|
||||
|
||||
async def get_pg_sessionmaker(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> async_sessionmaker[AsyncSession]:
|
||||
async with self._pg_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
entry = self._pg_cache.get(key)
|
||||
if entry:
|
||||
self._pg_cache.move_to_end(key)
|
||||
return entry.sessionmaker
|
||||
|
||||
normalized_url = self._normalize_pg_url(connection_url)
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
engine = create_async_engine(
|
||||
normalized_url,
|
||||
pool_size=pool_min_size,
|
||||
max_overflow=max(0, pool_max_size - pool_min_size),
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
self._pg_cache[key] = PgEngineEntry(
|
||||
engine=engine,
|
||||
sessionmaker=sessionmaker,
|
||||
)
|
||||
await self._evict_pg_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL engine for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return sessionmaker
|
||||
|
||||
async def get_timescale_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._ts_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._ts_cache.get(key)
|
||||
if pool:
|
||||
self._ts_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._ts_cache[key] = pool
|
||||
await self._evict_ts_if_needed()
|
||||
logger.info(
|
||||
"Created TimescaleDB pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def get_pg_pool(
|
||||
self,
|
||||
project_id: UUID,
|
||||
db_role: str,
|
||||
connection_url: str,
|
||||
pool_min_size: int,
|
||||
pool_max_size: int,
|
||||
) -> AsyncConnectionPool:
|
||||
async with self._pg_raw_lock:
|
||||
key = CacheKey(project_id=project_id, db_role=db_role)
|
||||
pool = self._pg_raw_cache.get(key)
|
||||
if pool:
|
||||
self._pg_raw_cache.move_to_end(key)
|
||||
return pool
|
||||
|
||||
pool_min_size = max(1, pool_min_size)
|
||||
pool_max_size = max(pool_min_size, pool_max_size)
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=connection_url,
|
||||
min_size=pool_min_size,
|
||||
max_size=pool_max_size,
|
||||
open=False,
|
||||
kwargs={"row_factory": dict_row},
|
||||
)
|
||||
await pool.open()
|
||||
self._pg_raw_cache[key] = pool
|
||||
await self._evict_pg_raw_if_needed()
|
||||
logger.info(
|
||||
"Created PostgreSQL pool for project %s (%s)", project_id, db_role
|
||||
)
|
||||
return pool
|
||||
|
||||
async def _evict_pg_if_needed(self) -> None:
|
||||
while len(self._pg_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, entry = self._pg_cache.popitem(last=False)
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_ts_if_needed(self) -> None:
|
||||
while len(self._ts_cache) > settings.PROJECT_TS_CACHE_SIZE:
|
||||
key, pool = self._ts_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def _evict_pg_raw_if_needed(self) -> None:
|
||||
while len(self._pg_raw_cache) > settings.PROJECT_PG_CACHE_SIZE:
|
||||
key, pool = self._pg_raw_cache.popitem(last=False)
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Evicted PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
async with self._pg_lock:
|
||||
for key, entry in list(self._pg_cache.items()):
|
||||
await entry.engine.dispose()
|
||||
logger.info(
|
||||
"Closed PostgreSQL engine for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_cache.clear()
|
||||
|
||||
async with self._ts_lock:
|
||||
for key, pool in list(self._ts_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed TimescaleDB pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._ts_cache.clear()
|
||||
|
||||
async with self._pg_raw_lock:
|
||||
for key, pool in list(self._pg_raw_cache.items()):
|
||||
await pool.close()
|
||||
logger.info(
|
||||
"Closed PostgreSQL pool for project %s (%s)",
|
||||
key.project_id,
|
||||
key.db_role,
|
||||
)
|
||||
self._pg_raw_cache.clear()
|
||||
|
||||
|
||||
project_connection_manager = ProjectConnectionManager()
|
||||
@@ -13,12 +13,12 @@ from typing import List, Dict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from influxdb_client.client.write_api import SYNCHRONOUS, ASYNCHRONOUS
|
||||
from dateutil import parser
|
||||
import get_realValue
|
||||
import get_data
|
||||
# import get_realValue
|
||||
# import get_data
|
||||
import psycopg
|
||||
import time
|
||||
import app.services.simulation as simulation
|
||||
from tjnetwork import *
|
||||
from app.services.tjnetwork import *
|
||||
import schedule
|
||||
import threading
|
||||
import app.services.globals as globals
|
||||
@@ -404,8 +404,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
Point("link")
|
||||
.tag("date", None)
|
||||
.tag("ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("flow", 0.0)
|
||||
.field("leakage", 0.0)
|
||||
.field("velocity", 0.0)
|
||||
@@ -420,8 +420,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
Point("node")
|
||||
.tag("date", None)
|
||||
.tag("ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("head", 0.0)
|
||||
.field("pressure", 0.0)
|
||||
.field("actualdemand", 0.0)
|
||||
@@ -436,8 +436,8 @@ def create_and_initialize_buckets(org_name: str) -> None:
|
||||
.tag("date", None)
|
||||
.tag("description", None)
|
||||
.tag("device_ID", None)
|
||||
.tag("scheme_Type", None)
|
||||
.tag("scheme_Name", None)
|
||||
.tag("scheme_type", None)
|
||||
.tag("scheme_name", None)
|
||||
.field("monitored_value", 0.0)
|
||||
.field("datacleaning_value", 0.0)
|
||||
.field("scheme_simulation_value", 0.0)
|
||||
@@ -1811,8 +1811,8 @@ def query_SCADA_data_by_device_ID_and_time(
|
||||
def query_scheme_SCADA_data_by_device_ID_and_time(
|
||||
query_ids_list: List[str],
|
||||
query_time: str,
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
@@ -1843,7 +1843,7 @@ def query_scheme_SCADA_data_by_device_ID_and_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["device_ID"] == "{device_id}" and r["_field"] == "monitored_value" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
try:
|
||||
@@ -2585,7 +2585,7 @@ def query_all_scheme_record_by_time_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -2635,7 +2635,7 @@ def query_scheme_simulation_result_by_ID_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" and r["ID"] == "{ID}")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -2660,7 +2660,7 @@ def query_scheme_simulation_result_by_ID_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|
||||
|> filter(fn: (r) => r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "link" and r["ID"] == "{ID}")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3227,8 +3227,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
link_result_list: List[Dict[str, any]],
|
||||
scheme_start_time: str,
|
||||
num_periods: int = 1,
|
||||
scheme_Type: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_type: str = None,
|
||||
scheme_name: str = None,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
):
|
||||
"""
|
||||
@@ -3237,8 +3237,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
:param link_result_list: (List[Dict[str, any]]): 包含连接和结果数据的字典列表。
|
||||
:param scheme_start_time: (str): 方案模拟开始时间。
|
||||
:param num_periods: (int): 方案模拟的周期数
|
||||
:param scheme_Type: (str): 方案类型
|
||||
:param scheme_Name: (str): 方案名称
|
||||
:param scheme_type: (str): 方案类型
|
||||
:param scheme_name: (str): 方案名称
|
||||
:param bucket: (str): InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"。
|
||||
:return:
|
||||
"""
|
||||
@@ -3298,8 +3298,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
Point("node")
|
||||
.tag("date", date_str)
|
||||
.tag("ID", node_id)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("head", data.get("head", 0.0))
|
||||
.field("pressure", data.get("pressure", 0.0))
|
||||
.field("actualdemand", data.get("demand", 0.0))
|
||||
@@ -3322,8 +3322,8 @@ def store_scheme_simulation_result_to_influxdb(
|
||||
Point("link")
|
||||
.tag("date", date_str)
|
||||
.tag("ID", link_id)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("flow", data.get("flow", 0.0))
|
||||
.field("velocity", data.get("velocity", 0.0))
|
||||
.field("headloss", data.get("headloss", 0.0))
|
||||
@@ -3409,14 +3409,14 @@ def query_corresponding_query_id_and_element_id(name: str) -> None:
|
||||
|
||||
# 2025/03/11
|
||||
def fill_scheme_simulation_result_to_SCADA(
|
||||
scheme_Type: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_type: str = None,
|
||||
scheme_name: str = None,
|
||||
query_date: str = None,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
):
|
||||
"""
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: InfluxDB 的 bucket 名称,默认值为 "scheme_simulation_result"
|
||||
:return:
|
||||
@@ -3457,8 +3457,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
# 查找associated_element_id的对应值
|
||||
for key, value in globals.scheme_source_outflow_ids.items():
|
||||
scheme_source_outflow_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="link",
|
||||
@@ -3470,8 +3470,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_source_outflow")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3480,8 +3480,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_pipe_flow_ids.items():
|
||||
scheme_pipe_flow_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="link",
|
||||
@@ -3492,8 +3492,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_pipe_flow")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3502,8 +3502,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_pressure_ids.items():
|
||||
scheme_pressure_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3514,8 +3514,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_pressure")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3524,8 +3524,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_demand_ids.items():
|
||||
scheme_demand_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3536,8 +3536,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_demand")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3546,8 +3546,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
|
||||
for key, value in globals.scheme_quality_ids.items():
|
||||
scheme_quality_result = query_scheme_curve_by_ID_property(
|
||||
scheme_Type=scheme_Type,
|
||||
scheme_Name=scheme_Name,
|
||||
scheme_type=scheme_type,
|
||||
scheme_name=scheme_name,
|
||||
query_date=query_date,
|
||||
ID=value,
|
||||
type="node",
|
||||
@@ -3558,8 +3558,8 @@ def fill_scheme_simulation_result_to_SCADA(
|
||||
Point("scheme_quality")
|
||||
.tag("date", query_date)
|
||||
.tag("device_ID", key)
|
||||
.tag("scheme_Type", scheme_Type)
|
||||
.tag("scheme_Name", scheme_Name)
|
||||
.tag("scheme_type", scheme_type)
|
||||
.tag("scheme_name", scheme_name)
|
||||
.field("monitored_value", data["value"])
|
||||
.time(data["time"], write_precision="s")
|
||||
)
|
||||
@@ -3629,15 +3629,15 @@ def query_SCADA_data_curve(
|
||||
|
||||
# 2025/02/18
|
||||
def query_scheme_all_record_by_time(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_time: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> tuple:
|
||||
"""
|
||||
查询指定方案某一时刻的所有记录,包括‘node'和‘link’,分别以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'。
|
||||
:param bucket: 数据存储的 bucket 名称。
|
||||
:return: dict: tuple: (node_records, link_records)
|
||||
@@ -3660,7 +3660,7 @@ def query_scheme_all_record_by_time(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3710,8 +3710,8 @@ def query_scheme_all_record_by_time(
|
||||
|
||||
# 2025/03/04
|
||||
def query_scheme_all_record_by_time_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_time: str,
|
||||
type: str,
|
||||
property: str,
|
||||
@@ -3719,8 +3719,8 @@ def query_scheme_all_record_by_time_property(
|
||||
) -> list:
|
||||
"""
|
||||
查询指定方案某一时刻‘node'或‘link’某一属性值,以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_time: 输入的北京时间,格式为 '2024-11-24T17:30:00+08:00'。
|
||||
:param type: 查询的类型(决定 measurement)
|
||||
:param property: 查询的字段名称(field)
|
||||
@@ -3752,7 +3752,7 @@ def query_scheme_all_record_by_time_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -3767,8 +3767,8 @@ def query_scheme_all_record_by_time_property(
|
||||
|
||||
# 2025/02/19
|
||||
def query_scheme_curve_by_ID_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
ID: str,
|
||||
type: str,
|
||||
@@ -3776,9 +3776,9 @@ def query_scheme_curve_by_ID_property(
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> list:
|
||||
"""
|
||||
根据scheme_Type和scheme_Name,查询该模拟方案中,某一node或link的某一属性值的所有时间的结果
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
根据scheme_Type和scheme_name,查询该模拟方案中,某一node或link的某一属性值的所有时间的结果
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param ID: 元素的ID
|
||||
:param type: 元素的类型,node或link
|
||||
@@ -3817,7 +3817,7 @@ def query_scheme_curve_by_ID_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "{measurement}" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["ID"] == "{ID}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -3832,15 +3832,15 @@ def query_scheme_curve_by_ID_property(
|
||||
|
||||
# 2025/02/21
|
||||
def query_scheme_all_record(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> tuple:
|
||||
"""
|
||||
查询指定方案的所有记录,包括‘node'和‘link’,分别以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: 数据存储的 bucket 名称。
|
||||
:return: dict: tuple: (node_records, link_records)
|
||||
@@ -3867,7 +3867,7 @@ def query_scheme_all_record(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {utc_start_time.isoformat()}, stop: {utc_stop_time.isoformat()})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["_measurement"] == "node" or r["_measurement"] == "link")
|
||||
|> pivot(
|
||||
rowKey:["_time"],
|
||||
columnKey:["_field"],
|
||||
@@ -3917,8 +3917,8 @@ def query_scheme_all_record(
|
||||
|
||||
# 2025/03/04
|
||||
def query_scheme_all_record_property(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
type: str,
|
||||
property: str,
|
||||
@@ -3926,8 +3926,8 @@ def query_scheme_all_record_property(
|
||||
) -> list:
|
||||
"""
|
||||
查询指定方案的‘node'或‘link’的某一属性值,以指定格式返回。
|
||||
:param scheme_Type: 方案类型
|
||||
:param scheme_Name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param scheme_name: 方案名称
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param type: 查询的类型(决定 measurement)
|
||||
:param property: 查询的字段名称(field)
|
||||
@@ -3964,7 +3964,7 @@ def query_scheme_all_record_property(
|
||||
flux_query = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
|> filter(fn: (r) => r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}" and r["date"] == "{query_date}" and r["_measurement"] == "{measurement}" and r["_field"] == "{property}")
|
||||
"""
|
||||
# 执行查询
|
||||
tables = query_api.query(flux_query)
|
||||
@@ -4245,8 +4245,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
link_data[key][field] = record.get_value()
|
||||
link_data[key]["measurement"] = record.get_measurement()
|
||||
link_data[key]["date"] = record.values.get("date", None)
|
||||
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
# 构建 Flux 查询语句,查询指定时间范围内的数据
|
||||
flux_query_node = f"""
|
||||
from(bucket: "{bucket}")
|
||||
@@ -4267,8 +4267,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
node_data[key][field] = record.get_value()
|
||||
node_data[key]["measurement"] = record.get_measurement()
|
||||
node_data[key]["date"] = record.values.get("date", None)
|
||||
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
for key in set(link_data.keys()):
|
||||
row = {"time": key[0], "ID": key[1]}
|
||||
row.update(link_data.get(key, {}))
|
||||
@@ -4288,8 +4288,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"flow",
|
||||
"leakage",
|
||||
@@ -4311,8 +4311,8 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"head",
|
||||
"pressure",
|
||||
@@ -4330,15 +4330,15 @@ def export_scheme_simulation_result_to_csv_time(
|
||||
|
||||
# 2025/02/18
|
||||
def export_scheme_simulation_result_to_csv_scheme(
|
||||
scheme_Type: str,
|
||||
scheme_Name: str,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_date: str,
|
||||
bucket: str = "scheme_simulation_result",
|
||||
) -> None:
|
||||
"""
|
||||
导出influxdb中scheme_simulation_result这个bucket的数据到csv中
|
||||
:param scheme_Type: 查询的方案类型
|
||||
:param scheme_Name: 查询的方案名
|
||||
:param scheme_type: 查询的方案类型
|
||||
:param scheme_name: 查询的方案名
|
||||
:param query_date: 查询日期,格式为 'YYYY-MM-DD'
|
||||
:param bucket: 数据存储的 bucket 名称,默认值为 "SCADA_data"
|
||||
:return:
|
||||
@@ -4366,7 +4366,7 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
flux_query_link = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "link" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
link_tables = query_api.query(flux_query_link)
|
||||
@@ -4382,13 +4382,13 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
link_data[key][field] = record.get_value()
|
||||
link_data[key]["measurement"] = record.get_measurement()
|
||||
link_data[key]["date"] = record.values.get("date", None)
|
||||
link_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
link_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
link_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
link_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
# 构建 Flux 查询语句,查询指定时间范围内的数据
|
||||
flux_query_node = f"""
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: {start_time}, stop: {stop_time})
|
||||
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_Type"] == "{scheme_Type}" and r["scheme_Name"] == "{scheme_Name}")
|
||||
|> filter(fn: (r) => r["_measurement"] == "node" and r["scheme_type"] == "{scheme_type}" and r["scheme_name"] == "{scheme_name}")
|
||||
"""
|
||||
# 执行查询
|
||||
node_tables = query_api.query(flux_query_node)
|
||||
@@ -4404,8 +4404,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
node_data[key][field] = record.get_value()
|
||||
node_data[key]["measurement"] = record.get_measurement()
|
||||
node_data[key]["date"] = record.values.get("date", None)
|
||||
node_data[key]["scheme_Type"] = record.values.get("scheme_Type", None)
|
||||
node_data[key]["scheme_Name"] = record.values.get("scheme_Name", None)
|
||||
node_data[key]["scheme_type"] = record.values.get("scheme_type", None)
|
||||
node_data[key]["scheme_name"] = record.values.get("scheme_name", None)
|
||||
for key in set(link_data.keys()):
|
||||
row = {"time": key[0], "ID": key[1]}
|
||||
row.update(link_data.get(key, {}))
|
||||
@@ -4416,10 +4416,10 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
node_rows.append(row)
|
||||
# 动态生成 CSV 文件名
|
||||
csv_filename_link = (
|
||||
f"scheme_simulation_link_result_{scheme_Name}_of_{scheme_Type}.csv"
|
||||
f"scheme_simulation_link_result_{scheme_name}_of_{scheme_type}.csv"
|
||||
)
|
||||
csv_filename_node = (
|
||||
f"scheme_simulation_node_result_{scheme_Name}_of_{scheme_Type}.csv"
|
||||
f"scheme_simulation_node_result_{scheme_name}_of_{scheme_type}.csv"
|
||||
)
|
||||
# 写入到 CSV 文件
|
||||
with open(csv_filename_link, mode="w", newline="") as file:
|
||||
@@ -4429,8 +4429,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"flow",
|
||||
"leakage",
|
||||
@@ -4452,8 +4452,8 @@ def export_scheme_simulation_result_to_csv_scheme(
|
||||
"time",
|
||||
"measurement",
|
||||
"date",
|
||||
"scheme_Type",
|
||||
"scheme_Name",
|
||||
"scheme_type",
|
||||
"scheme_name",
|
||||
"ID",
|
||||
"head",
|
||||
"pressure",
|
||||
@@ -4878,15 +4878,15 @@ if __name__ == "__main__":
|
||||
# export_scheme_simulation_result_to_csv_time(start_date='2025-02-13', end_date='2025-02-15')
|
||||
|
||||
# 示例9:export_scheme_simulation_result_to_csv_scheme
|
||||
# export_scheme_simulation_result_to_csv_scheme(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
|
||||
# export_scheme_simulation_result_to_csv_scheme(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
|
||||
|
||||
# 示例10:query_scheme_all_record_by_time
|
||||
# node_records, link_records = query_scheme_all_record_by_time(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_time="2025-02-14T10:30:00+08:00")
|
||||
# node_records, link_records = query_scheme_all_record_by_time(scheme_type='burst_Analysis', scheme_name='scheme1', query_time="2025-02-14T10:30:00+08:00")
|
||||
# print("Node 数据:", node_records)
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
# 示例11:query_scheme_curve_by_ID_property
|
||||
# curve_result = query_scheme_curve_by_ID_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', ID='ZBBDTZDP000022',
|
||||
# curve_result = query_scheme_curve_by_ID_property(scheme_type='burst_Analysis', scheme_name='scheme1', ID='ZBBDTZDP000022',
|
||||
# type='node', property='head')
|
||||
# print(curve_result)
|
||||
|
||||
@@ -4896,7 +4896,7 @@ if __name__ == "__main__":
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
# 示例13:query_scheme_all_record
|
||||
# node_records, link_records = query_scheme_all_record(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10')
|
||||
# node_records, link_records = query_scheme_all_record(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10')
|
||||
# print("Node 数据:", node_records)
|
||||
# print("Link 数据:", link_records)
|
||||
|
||||
@@ -4909,16 +4909,16 @@ if __name__ == "__main__":
|
||||
# print(result_records)
|
||||
|
||||
# 示例16:query_scheme_all_record_by_time_property
|
||||
# result_records = query_scheme_all_record_by_time_property(scheme_Type='burst_Analysis', scheme_Name='scheme1',
|
||||
# result_records = query_scheme_all_record_by_time_property(scheme_type='burst_Analysis', scheme_name='scheme1',
|
||||
# query_time='2025-02-14T10:30:00+08:00', type='node', property='head')
|
||||
# print(result_records)
|
||||
|
||||
# 示例17:query_scheme_all_record_property
|
||||
# result_records = query_scheme_all_record_property(scheme_Type='burst_Analysis', scheme_Name='scheme1', query_date='2025-03-10', type='node', property='head')
|
||||
# result_records = query_scheme_all_record_property(scheme_type='burst_Analysis', scheme_name='scheme1', query_date='2025-03-10', type='node', property='head')
|
||||
# print(result_records)
|
||||
|
||||
# 示例18:fill_scheme_simulation_result_to_SCADA
|
||||
# fill_scheme_simulation_result_to_SCADA(scheme_Type='burst_Analysis', scheme_Name='burst0330', query_date='2025-03-30')
|
||||
# fill_scheme_simulation_result_to_SCADA(scheme_type='burst_Analysis', scheme_name='burst0330', query_date='2025-03-30')
|
||||
|
||||
# 示例19:query_SCADA_data_by_device_ID_and_timerange
|
||||
# result = query_SCADA_data_by_device_ID_and_timerange(query_ids_list=globals.pressure_non_realtime_ids, start_time='2025-04-16T00:00:00+08:00',
|
||||
@@ -4926,7 +4926,7 @@ if __name__ == "__main__":
|
||||
# print(result)
|
||||
|
||||
# 示例:manually_get_burst_flow
|
||||
# leakage = manually_get_burst_flow(scheme_Type='burst_Analysis', scheme_Name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
|
||||
# leakage = manually_get_burst_flow(scheme_type='burst_Analysis', scheme_name='burst_scheme', scheme_start_time='2025-03-10T12:00:00+08:00')
|
||||
# print(leakage)
|
||||
|
||||
# 示例:upload_cleaned_SCADA_data_to_influxdb
|
||||
|
||||
3
app/infra/db/metadata/__init__.py
Normal file
3
app/infra/db/metadata/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .database import get_metadata_session, close_metadata_engine
|
||||
|
||||
__all__ = ["get_metadata_session", "close_metadata_engine"]
|
||||
27
app/infra/db/metadata/database.py
Normal file
27
app/infra/db/metadata/database.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.METADATA_DATABASE_URI,
|
||||
pool_size=settings.METADATA_DB_POOL_SIZE,
|
||||
max_overflow=settings.METADATA_DB_MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_metadata_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with SessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_metadata_engine() -> None:
|
||||
await engine.dispose()
|
||||
logger.info("Metadata database engine disposed.")
|
||||
115
app/infra/db/metadata/models.py
Normal file
115
app/infra/db/metadata/models.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
keycloak_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
username: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
email: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Project(Base):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100))
|
||||
code: Mapped[str] = mapped_column(String(50), unique=True)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_workspace: Mapped[str] = mapped_column(String(100), unique=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class ProjectDatabase(Base):
|
||||
__tablename__ = "project_databases"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
db_role: Mapped[str] = mapped_column(String(20))
|
||||
db_type: Mapped[str] = mapped_column(String(20))
|
||||
dsn_encrypted: Mapped[str] = mapped_column(Text)
|
||||
pool_min_size: Mapped[int] = mapped_column(Integer, default=2)
|
||||
pool_max_size: Mapped[int] = mapped_column(Integer, default=10)
|
||||
|
||||
|
||||
class ProjectGeoServerConfig(Base):
|
||||
__tablename__ = "project_geoserver_configs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
project_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), unique=True, index=True
|
||||
)
|
||||
gs_base_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
gs_admin_user: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
gs_admin_password_encrypted: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
gs_datastore_name: Mapped[str] = mapped_column(String(100), default="ds_postgis")
|
||||
default_extent: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
srid: Mapped[int] = mapped_column(Integer, default=4326)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
|
||||
|
||||
class UserProjectMembership(Base):
|
||||
__tablename__ = "user_project_membership"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), index=True)
|
||||
project_role: Mapped[str] = mapped_column(String(20), default="viewer")
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
project_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), nullable=True, index=True
|
||||
)
|
||||
action: Mapped[str] = mapped_column(String(50))
|
||||
resource_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
resource_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
request_method: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
request_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
request_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
||||
response_status: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
timestamp: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=datetime.utcnow
|
||||
)
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = postgresql_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -26,7 +33,7 @@ class Database:
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: default")
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||
raise
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.logger import logger
|
||||
import postgresql_info
|
||||
import app.native.api.postgresql_info as postgresql_info
|
||||
import psycopg
|
||||
|
||||
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
from app.auth.project_dependencies import get_project_pg_connection
|
||||
|
||||
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
|
||||
@@ -4,15 +4,15 @@ from datetime import datetime, timedelta
|
||||
from psycopg import AsyncConnection
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from api_ex.Fdataclean import clean_flow_data_df_kf
|
||||
from api_ex.Pdataclean import clean_pressure_data_df_km
|
||||
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
from app.algorithms.api_ex.flow_data_clean import clean_flow_data_df_kf
|
||||
from app.algorithms.api_ex.pressure_data_clean import clean_pressure_data_df_km
|
||||
from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
from postgresql.internal_queries import InternalQueries
|
||||
from postgresql.scada_info import ScadaRepository as PostgreScadaRepository
|
||||
from timescaledb.schemas.realtime import RealtimeRepository
|
||||
from timescaledb.schemas.scheme import SchemeRepository
|
||||
from timescaledb.schemas.scada import ScadaRepository
|
||||
from app.infra.db.postgresql.internal_queries import InternalQueries
|
||||
from app.infra.db.postgresql.scada_info import ScadaRepository as PostgreScadaRepository
|
||||
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
|
||||
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
|
||||
from app.infra.db.timescaledb.schemas.scada import ScadaRepository
|
||||
|
||||
|
||||
class CompositeQueries:
|
||||
@@ -405,12 +405,8 @@ class CompositeQueries:
|
||||
pressure_df = df[pressure_ids]
|
||||
# 重置索引,将 time 变为普通列
|
||||
pressure_df = pressure_df.reset_index()
|
||||
# 移除 time 列,准备输入给清洗方法
|
||||
value_df = pressure_df.drop(columns=["time"])
|
||||
# 调用清洗方法
|
||||
cleaned_value_df = clean_pressure_data_df_km(value_df)
|
||||
# 添加 time 列到首列
|
||||
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
||||
cleaned_df = clean_pressure_data_df_km(pressure_df)
|
||||
# 将清洗后的数据写回数据库
|
||||
for device_id in pressure_ids:
|
||||
if device_id in cleaned_df.columns:
|
||||
@@ -432,12 +428,8 @@ class CompositeQueries:
|
||||
flow_df = df[flow_ids]
|
||||
# 重置索引,将 time 变为普通列
|
||||
flow_df = flow_df.reset_index()
|
||||
# 移除 time 列,准备输入给清洗方法
|
||||
value_df = flow_df.drop(columns=["time"])
|
||||
# 调用清洗方法
|
||||
cleaned_value_df = clean_flow_data_df_kf(value_df)
|
||||
# 添加 time 列到首列
|
||||
cleaned_df = pd.concat([flow_df["time"], cleaned_value_df], axis=1)
|
||||
cleaned_df = clean_flow_data_df_kf(flow_df)
|
||||
# 将清洗后的数据写回数据库
|
||||
for device_id in flow_ids:
|
||||
if device_id in cleaned_df.columns:
|
||||
@@ -583,9 +575,7 @@ class CompositeQueries:
|
||||
)
|
||||
|
||||
# 7. 使用PipelineHealthAnalyzer进行预测
|
||||
analyzer = PipelineHealthAnalyzer(
|
||||
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
|
||||
)
|
||||
analyzer = PipelineHealthAnalyzer()
|
||||
survival_functions = analyzer.predict_survival(data)
|
||||
# 8. 组合结果
|
||||
results = []
|
||||
|
||||
@@ -17,7 +17,14 @@ class Database:
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
target_db_name = db_name or self.db_name
|
||||
|
||||
# Get connection string, handling default case where target_db_name might be None
|
||||
if target_db_name:
|
||||
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
else:
|
||||
conn_string = timescaledb_info.get_pgconn_string()
|
||||
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
@@ -27,7 +34,7 @@ class Database:
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(
|
||||
f"TimescaleDB connection pool initialized for database: default"
|
||||
f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
||||
@@ -46,7 +53,9 @@ class Database:
|
||||
def get_pgconn_string(self, db_name=None):
|
||||
"""Get the TimescaleDB connection string."""
|
||||
target_db_name = db_name or self.db_name
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
if target_db_name:
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
return timescaledb_info.get_pgconn_string()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi.logger import logger
|
||||
from timescaledb.schemas.scheme import SchemeRepository
|
||||
from timescaledb.schemas.realtime import RealtimeRepository
|
||||
import timescaledb.timescaledb_info as timescaledb_info
|
||||
from datetime import datetime, timedelta
|
||||
from timescaledb.schemas.scada import ScadaRepository
|
||||
import psycopg
|
||||
import time
|
||||
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
|
||||
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
|
||||
import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
|
||||
from app.infra.db.timescaledb.schemas.scada import ScadaRepository
|
||||
|
||||
|
||||
class InternalStorage:
|
||||
|
||||
@@ -1,40 +1,32 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .schemas.realtime import RealtimeRepository
|
||||
from .schemas.scheme import SchemeRepository
|
||||
from .schemas.scada import ScadaRepository
|
||||
from .composite_queries import CompositeQueries
|
||||
from postgresql.database import get_database_instance as get_postgres_database_instance
|
||||
from app.auth.project_dependencies import (
|
||||
get_project_pg_connection,
|
||||
get_project_timescale_connection,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
# 动态项目 TimescaleDB 连接依赖
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_timescale_connection),
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# PostgreSQL 数据库连接依赖函数
|
||||
# 动态项目 PostgreSQL 连接依赖
|
||||
async def get_postgres_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
conn: AsyncConnection = Depends(get_project_pg_connection),
|
||||
):
|
||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_postgres_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
yield conn
|
||||
|
||||
|
||||
# --- Realtime Endpoints ---
|
||||
@@ -152,7 +144,7 @@ async def query_realtime_simulation_by_id_time(
|
||||
results = await RealtimeRepository.query_simulation_result_by_id_time(
|
||||
conn, id, type, query_time
|
||||
)
|
||||
return {"results": results}
|
||||
return {"result": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Any, Dict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import defaultdict
|
||||
from psycopg import AsyncConnection, Connection, sql
|
||||
import globals
|
||||
import app.services.globals as globals
|
||||
|
||||
# 定义UTC+8时区
|
||||
UTC_8 = timezone(timedelta(hours=8))
|
||||
|
||||
112
app/infra/repositories/audit_repository.py
Normal file
112
app/infra/repositories/audit_repository.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
) -> AuditLogResponse:
|
||||
log = models.AuditLog(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return AuditLogResponse.model_validate(log)
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> List[AuditLogResponse]:
|
||||
conditions = []
|
||||
if user_id is not None:
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = (
|
||||
select(models.AuditLog)
|
||||
.where(*conditions)
|
||||
.order_by(models.AuditLog.timestamp.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
AuditLogResponse.model_validate(log)
|
||||
for log in result.scalars().all()
|
||||
]
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> int:
|
||||
conditions = []
|
||||
if user_id is not None:
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.scalar() or 0)
|
||||
197
app/infra/repositories/metadata_repository.py
Normal file
197
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.encryption import (
|
||||
get_database_encryptor,
|
||||
get_encryptor,
|
||||
is_database_encryption_configured,
|
||||
is_encryption_configured,
|
||||
)
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
|
||||
def _normalize_postgres_dsn(dsn: str) -> str:
|
||||
if not dsn or "://" not in dsn:
|
||||
return dsn
|
||||
scheme, rest = dsn.split("://", 1)
|
||||
if scheme not in ("postgresql", "postgres", "postgresql+psycopg"):
|
||||
return dsn
|
||||
if "@" not in rest:
|
||||
return dsn
|
||||
userinfo, hostinfo = rest.rsplit("@", 1)
|
||||
if ":" not in userinfo:
|
||||
return dsn
|
||||
username, password = userinfo.split(":", 1)
|
||||
if "@" not in password:
|
||||
return dsn
|
||||
password = password.replace("@", "%40")
|
||||
return f"{scheme}://{username}:{password}@{hostinfo}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectDbRouting:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
db_type: str
|
||||
dsn: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectGeoServerInfo:
|
||||
project_id: UUID
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_admin_password: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectSummary:
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
|
||||
|
||||
class MetadataRepository:
|
||||
"""元数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_user_by_keycloak_id(self, keycloak_id: UUID) -> Optional[models.User]:
|
||||
result = await self.session.execute(
|
||||
select(models.User).where(models.User.keycloak_id == keycloak_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).where(models.Project.id == project_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_membership_role(
|
||||
self, project_id: UUID, user_id: UUID
|
||||
) -> Optional[str]:
|
||||
result = await self.session.execute(
|
||||
select(models.UserProjectMembership.project_role).where(
|
||||
models.UserProjectMembership.project_id == project_id,
|
||||
models.UserProjectMembership.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_db_routing(
|
||||
self, project_id: UUID, db_role: str
|
||||
) -> Optional[ProjectDbRouting]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectDatabase).where(
|
||||
models.ProjectDatabase.project_id == project_id,
|
||||
models.ProjectDatabase.db_role == db_role,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if not is_database_encryption_configured():
|
||||
raise ValueError("DATABASE_ENCRYPTION_KEY is not configured")
|
||||
encryptor = get_database_encryptor()
|
||||
try:
|
||||
dsn = encryptor.decrypt(record.dsn_encrypted)
|
||||
except InvalidToken:
|
||||
raise ValueError(
|
||||
"Failed to decrypt project DB DSN: DATABASE_ENCRYPTION_KEY mismatch "
|
||||
"or invalid dsn_encrypted value"
|
||||
)
|
||||
dsn = _normalize_postgres_dsn(dsn)
|
||||
return ProjectDbRouting(
|
||||
project_id=record.project_id,
|
||||
db_role=record.db_role,
|
||||
db_type=record.db_type,
|
||||
dsn=dsn,
|
||||
pool_min_size=record.pool_min_size,
|
||||
pool_max_size=record.pool_max_size,
|
||||
)
|
||||
|
||||
async def get_geoserver_config(
|
||||
self, project_id: UUID
|
||||
) -> Optional[ProjectGeoServerInfo]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectGeoServerConfig).where(
|
||||
models.ProjectGeoServerConfig.project_id == project_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
if record.gs_admin_password_encrypted:
|
||||
if is_encryption_configured():
|
||||
encryptor = get_encryptor()
|
||||
password = encryptor.decrypt(record.gs_admin_password_encrypted)
|
||||
else:
|
||||
password = record.gs_admin_password_encrypted
|
||||
else:
|
||||
password = None
|
||||
return ProjectGeoServerInfo(
|
||||
project_id=record.project_id,
|
||||
gs_base_url=record.gs_base_url,
|
||||
gs_admin_user=record.gs_admin_user,
|
||||
gs_admin_password=password,
|
||||
gs_datastore_name=record.gs_datastore_name,
|
||||
default_extent=record.default_extent,
|
||||
srid=record.srid,
|
||||
)
|
||||
|
||||
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
|
||||
stmt = (
|
||||
select(models.Project, models.UserProjectMembership.project_role)
|
||||
.join(
|
||||
models.UserProjectMembership,
|
||||
models.UserProjectMembership.project_id == models.Project.id,
|
||||
)
|
||||
.where(models.UserProjectMembership.user_id == user_id)
|
||||
.order_by(models.Project.name)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=role,
|
||||
)
|
||||
for project, role in result.all()
|
||||
]
|
||||
|
||||
async def list_all_projects(self) -> List[ProjectSummary]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).order_by(models.Project.name)
|
||||
)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role="owner",
|
||||
)
|
||||
for project in result.scalars().all()
|
||||
]
|
||||
235
app/infra/repositories/user_repository.py
Normal file
235
app/infra/repositories/user_repository.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.user import UserCreate, UserUpdate, UserInDB
|
||||
from app.domain.models.role import UserRole
|
||||
from app.core.security import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserRepository:
|
||||
"""用户数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
async def create_user(self, user: UserCreate) -> Optional[UserInDB]:
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
Args:
|
||||
user: 用户创建数据
|
||||
|
||||
Returns:
|
||||
创建的用户对象
|
||||
"""
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
query = """
|
||||
INSERT INTO users (username, email, hashed_password, role, is_active, is_superuser)
|
||||
VALUES (%(username)s, %(email)s, %(hashed_password)s, %(role)s, TRUE, FALSE)
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'username': user.username,
|
||||
'email': user.email,
|
||||
'hashed_password': hashed_password,
|
||||
'role': user.role.value
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, user_id: int) -> Optional[UserInDB]:
|
||||
"""根据ID获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = %(user_id)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
||||
"""根据用户名获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = %(username)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'username': username})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[UserInDB]:
|
||||
"""根据邮箱获取用户"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
WHERE email = %(email)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'email': email})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
|
||||
return None
|
||||
|
||||
async def get_all_users(self, skip: int = 0, limit: int = 100) -> List[UserInDB]:
|
||||
"""获取所有用户(分页)"""
|
||||
query = """
|
||||
SELECT id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'skip': skip, 'limit': limit})
|
||||
rows = await cur.fetchall()
|
||||
return [UserInDB(**row) for row in rows]
|
||||
|
||||
async def update_user(self, user_id: int, user_update: UserUpdate) -> Optional[UserInDB]:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
user_update: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的用户对象
|
||||
"""
|
||||
# 构建动态更新语句
|
||||
update_fields = []
|
||||
params = {'user_id': user_id}
|
||||
|
||||
if user_update.email is not None:
|
||||
update_fields.append("email = %(email)s")
|
||||
params['email'] = user_update.email
|
||||
|
||||
if user_update.password is not None:
|
||||
update_fields.append("hashed_password = %(hashed_password)s")
|
||||
params['hashed_password'] = get_password_hash(user_update.password)
|
||||
|
||||
if user_update.role is not None:
|
||||
update_fields.append("role = %(role)s")
|
||||
params['role'] = user_update.role.value
|
||||
|
||||
if user_update.is_active is not None:
|
||||
update_fields.append("is_active = %(is_active)s")
|
||||
params['is_active'] = user_update.is_active
|
||||
|
||||
if not update_fields:
|
||||
return await self.get_user_by_id(user_id)
|
||||
|
||||
query = f"""
|
||||
UPDATE users
|
||||
SET {', '.join(update_fields)}, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = %(user_id)s
|
||||
RETURNING id, username, email, hashed_password, role, is_active, is_superuser,
|
||||
created_at, updated_at
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return UserInDB(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
async def delete_user(self, user_id: int) -> bool:
|
||||
"""
|
||||
删除用户
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
query = "DELETE FROM users WHERE id = %(user_id)s"
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {'user_id': user_id})
|
||||
return cur.rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def user_exists(self, username: str = None, email: str = None) -> bool:
|
||||
"""
|
||||
检查用户是否存在
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
email: 邮箱
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
if email:
|
||||
conditions.append("email = %(email)s")
|
||||
params['email'] = email
|
||||
|
||||
if not conditions:
|
||||
return False
|
||||
|
||||
query = f"""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM users WHERE {' OR '.join(conditions)}
|
||||
)
|
||||
"""
|
||||
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['exists'] if result else False
|
||||
4276
app/main.py
4276
app/main.py
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,34 @@
|
||||
from app.services.network_import import network_update, submit_scada_info
|
||||
from app.services.scheme_management import (
|
||||
create_user,
|
||||
delete_user,
|
||||
scheme_name_exists,
|
||||
store_scheme_info,
|
||||
delete_scheme_info,
|
||||
query_scheme_list,
|
||||
upload_shp_to_pg,
|
||||
submit_risk_probability_result,
|
||||
)
|
||||
from app.services.valve_isolation import analyze_valve_isolation
|
||||
from app.services.simulation_ops import (
|
||||
project_management,
|
||||
scheduling_simulation,
|
||||
daily_scheduling_simulation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"network_update",
|
||||
"submit_scada_info",
|
||||
"create_user",
|
||||
"delete_user",
|
||||
"scheme_name_exists",
|
||||
"store_scheme_info",
|
||||
"delete_scheme_info",
|
||||
"query_scheme_list",
|
||||
"upload_shp_to_pg",
|
||||
"submit_risk_probability_result",
|
||||
"project_management",
|
||||
"scheduling_simulation",
|
||||
"daily_scheduling_simulation",
|
||||
"analyze_valve_isolation",
|
||||
]
|
||||
|
||||
@@ -30,11 +30,11 @@ class Output:
|
||||
|
||||
if platform.system() == "Windows":
|
||||
self._lib = ctypes.CDLL(
|
||||
os.path.join(os.getcwd(), "epanet", "epanet-output.dll")
|
||||
os.path.join(os.path.dirname(__file__), "windows", "epanet-output.dll")
|
||||
)
|
||||
else:
|
||||
self._lib = ctypes.CDLL(
|
||||
os.path.join(os.getcwd(), "epanet", "linux", "libepanet-output.so")
|
||||
os.path.join(os.path.dirname(__file__), "linux", "libepanet-output.so")
|
||||
)
|
||||
|
||||
self._handle = ctypes.c_void_p()
|
||||
@@ -314,9 +314,9 @@ def run_project_return_dict(name: str, readable_output: bool = False) -> dict[st
|
||||
|
||||
input = name + ".db"
|
||||
if platform.system() == "Windows":
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "runepanet.exe")
|
||||
exe = os.path.join(os.path.dirname(__file__), "windows", "runepanet.exe")
|
||||
else:
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "linux", "runepanet")
|
||||
exe = os.path.join(os.path.dirname(__file__), "linux", "runepanet")
|
||||
inp = os.path.join(os.path.join(dir, "db_inp"), input + ".inp")
|
||||
rpt = os.path.join(os.path.join(dir, "temp"), input + ".rpt")
|
||||
opt = os.path.join(os.path.join(dir, "temp"), input + ".opt")
|
||||
@@ -364,9 +364,9 @@ def run_project(name: str, readable_output: bool = False) -> str:
|
||||
|
||||
input = name + ".db"
|
||||
if platform.system() == "Windows":
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "runepanet.exe")
|
||||
exe = os.path.join(os.path.dirname(__file__), "windows", "runepanet.exe")
|
||||
else:
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "linux", "runepanet")
|
||||
exe = os.path.join(os.path.dirname(__file__), "linux", "runepanet")
|
||||
inp = os.path.join(os.path.join(dir, "db_inp"), input + ".inp")
|
||||
rpt = os.path.join(os.path.join(dir, "temp"), input + ".rpt")
|
||||
opt = os.path.join(os.path.join(dir, "temp"), input + ".opt")
|
||||
@@ -416,9 +416,9 @@ def run_inp(name: str) -> str:
|
||||
dir = os.path.abspath(os.getcwd())
|
||||
|
||||
if platform.system() == "Windows":
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "runepanet.exe")
|
||||
exe = os.path.join(os.path.dirname(__file__), "windows", "runepanet.exe")
|
||||
else:
|
||||
exe = os.path.join(os.path.join(dir, "epanet"), "linux", "runepanet")
|
||||
exe = os.path.join(os.path.dirname(__file__), "linux", "runepanet")
|
||||
inp = os.path.join(os.path.join(dir, "inp"), name + ".inp")
|
||||
rpt = os.path.join(os.path.join(dir, "temp"), name + ".rpt")
|
||||
opt = os.path.join(os.path.join(dir, "temp"), name + ".opt")
|
||||
|
||||
197
app/services/network_import.py
Normal file
197
app/services/network_import.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import csv
|
||||
import os
|
||||
|
||||
import chardet
|
||||
import psycopg
|
||||
from psycopg import sql
|
||||
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
from app.services.tjnetwork import read_inp
|
||||
|
||||
|
||||
############################################################
|
||||
# network_update 10
|
||||
############################################################
|
||||
|
||||
|
||||
def network_update(file_path: str) -> None:
|
||||
"""
|
||||
更新pg数据库中的inp文件
|
||||
:param file_path: inp文件
|
||||
:return:
|
||||
"""
|
||||
read_inp("szh", file_path)
|
||||
|
||||
csv_path = "./history_pattern_flow.csv"
|
||||
|
||||
# # 检查文件是否存在
|
||||
# if os.path.exists(csv_path):
|
||||
# print(f"history_patterns_flows文件存在,开始处理...")
|
||||
#
|
||||
# # 读取 CSV 文件
|
||||
# df = pd.read_csv(csv_path)
|
||||
#
|
||||
# # 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
# with psycopg.connect("dbname=bb host=127.0.0.1") as conn:
|
||||
# with conn.cursor() as cur:
|
||||
# for index, row in df.iterrows():
|
||||
# # 直接将数据插入,不进行唯一性检查
|
||||
# insert_sql = sql.SQL("""
|
||||
# INSERT INTO history_patterns_flows (id, factor, flow)
|
||||
# VALUES (%s, %s, %s);
|
||||
# """)
|
||||
# # 将数据插入数据库
|
||||
# cur.execute(insert_sql, (row['id'], row['factor'], row['flow']))
|
||||
# conn.commit()
|
||||
# print("数据成功导入到 'history_patterns_flows' 表格。")
|
||||
# else:
|
||||
# print(f"history_patterns_flows文件不存在。")
|
||||
# 检查文件是否存在
|
||||
if os.path.exists(csv_path):
|
||||
print(f"history_patterns_flows文件存在,开始处理...")
|
||||
|
||||
# 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
with psycopg.connect(f"dbname={project_info.name} host=127.0.0.1") as conn:
|
||||
with conn.cursor() as cur:
|
||||
with open(csv_path, newline="", encoding="utf-8-sig") as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for row in reader:
|
||||
# 直接将数据插入,不进行唯一性检查
|
||||
insert_sql = sql.SQL(
|
||||
"""
|
||||
INSERT INTO history_patterns_flows (id, factor, flow)
|
||||
VALUES (%s, %s, %s);
|
||||
"""
|
||||
)
|
||||
# 将数据插入数据库
|
||||
cur.execute(insert_sql, (row["id"], row["factor"], row["flow"]))
|
||||
conn.commit()
|
||||
print("数据成功导入到 'history_patterns_flows' 表格。")
|
||||
else:
|
||||
print(f"history_patterns_flows文件不存在。")
|
||||
|
||||
|
||||
def submit_scada_info(name: str, coord_id: str) -> None:
|
||||
"""
|
||||
将scada信息表导入pg数据库
|
||||
:param name: 项目名称(数据库名称)
|
||||
:param coord_id: 坐标系的id,如4326,根据原始坐标信息输入
|
||||
:return:
|
||||
"""
|
||||
scada_info_path = "./scada_info.csv"
|
||||
# 检查文件是否存在
|
||||
if os.path.exists(scada_info_path):
|
||||
print(f"scada_info文件存在,开始处理...")
|
||||
|
||||
# 自动检测文件编码
|
||||
with open(scada_info_path, "rb") as file:
|
||||
raw_data = file.read()
|
||||
detected = chardet.detect(raw_data)
|
||||
file_encoding = detected["encoding"]
|
||||
print(f"检测到的文件编码:{file_encoding}")
|
||||
try:
|
||||
# 动态替换数据库名称
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
|
||||
# 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 检查 scada_info 表是否为空
|
||||
cur.execute("SELECT COUNT(*) FROM scada_info;")
|
||||
count = cur.fetchone()[0]
|
||||
|
||||
if count > 0:
|
||||
print("scada_info表中已有数据,正在清空记录...")
|
||||
cur.execute("DELETE FROM scada_info;")
|
||||
print("表记录已清空。")
|
||||
|
||||
with open(
|
||||
scada_info_path, newline="", encoding=file_encoding
|
||||
) as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for row in reader:
|
||||
# 将CSV单元格值为空的字段转换为 None
|
||||
cleaned_row = {
|
||||
key: (value if value.strip() else None)
|
||||
for key, value in row.items()
|
||||
}
|
||||
|
||||
# 处理 associated_source_outflow_id 列动态变化
|
||||
associated_columns = [
|
||||
f"associated_source_outflow_id{i}" for i in range(1, 21)
|
||||
]
|
||||
associated_values = [
|
||||
(
|
||||
cleaned_row.get(col).strip()
|
||||
if cleaned_row.get(col)
|
||||
and cleaned_row.get(col).strip()
|
||||
else None
|
||||
)
|
||||
for col in associated_columns
|
||||
]
|
||||
|
||||
# 将 X_coor 和 Y_coor 转换为 geometry 类型
|
||||
x_coor = (
|
||||
float(cleaned_row["X_coor"])
|
||||
if cleaned_row["X_coor"]
|
||||
else None
|
||||
)
|
||||
y_coor = (
|
||||
float(cleaned_row["Y_coor"])
|
||||
if cleaned_row["Y_coor"]
|
||||
else None
|
||||
)
|
||||
coord = (
|
||||
f"SRID={coord_id};POINT({x_coor} {y_coor})"
|
||||
if x_coor and y_coor
|
||||
else None
|
||||
)
|
||||
|
||||
# 准备插入 SQL 语句
|
||||
insert_sql = sql.SQL(
|
||||
"""
|
||||
INSERT INTO scada_info (
|
||||
id, type, associated_element_id, associated_pattern,
|
||||
associated_pipe_flow_id, {associated_columns},
|
||||
API_query_id, transmission_mode, transmission_frequency,
|
||||
reliability, X_coor, Y_coor, coord
|
||||
)
|
||||
VALUES (
|
||||
%s, %s, %s, %s, %s, {associated_placeholders},
|
||||
%s, %s, %s, %s, %s, %s, %s
|
||||
);
|
||||
"""
|
||||
).format(
|
||||
associated_columns=sql.SQL(", ").join(
|
||||
sql.Identifier(col) for col in associated_columns
|
||||
),
|
||||
associated_placeholders=sql.SQL(", ").join(
|
||||
sql.Placeholder() for _ in associated_columns
|
||||
),
|
||||
)
|
||||
# 将数据插入数据库
|
||||
cur.execute(
|
||||
insert_sql,
|
||||
(
|
||||
cleaned_row["id"],
|
||||
cleaned_row["type"],
|
||||
cleaned_row["associated_element_id"],
|
||||
cleaned_row.get("associated_pattern"),
|
||||
cleaned_row.get("associated_pipe_flow_id"),
|
||||
*associated_values,
|
||||
cleaned_row.get("API_query_id"),
|
||||
cleaned_row["transmission_mode"],
|
||||
cleaned_row["transmission_frequency"],
|
||||
cleaned_row["reliability"],
|
||||
x_coor,
|
||||
y_coor,
|
||||
coord,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
print("数据成功导入到 'scada_info' 表格。")
|
||||
except Exception as e:
|
||||
print(f"导入时出错:{e}")
|
||||
else:
|
||||
print(f"scada_info文件不存在。")
|
||||
@@ -1 +1,4 @@
|
||||
name='szh'
|
||||
import os
|
||||
|
||||
# 从环境变量 NETWORK_NAME 读取
|
||||
name = os.getenv("NETWORK_NAME")
|
||||
|
||||
266
app/services/scheme_management.py
Normal file
266
app/services/scheme_management.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import ast
|
||||
import json
|
||||
|
||||
import geopandas as gpd
|
||||
import pandas as pd
|
||||
import psycopg
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def create_user(name: str, username: str, password: str):
|
||||
"""
|
||||
创建用户
|
||||
:param name: 数据库名称
|
||||
:param username: 用户名
|
||||
:param password: 密码
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# 动态替换数据库名称
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
# 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO users (username, password) VALUES (%s, %s)",
|
||||
(username, password),
|
||||
)
|
||||
# 提交事务
|
||||
conn.commit()
|
||||
print("新用户创建成功!")
|
||||
except Exception as e:
|
||||
print(f"创建用户出错:{e}")
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def delete_user(name: str, username: str):
|
||||
"""
|
||||
删除用户
|
||||
:param name: 数据库名称
|
||||
:param username: 用户名
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# 动态替换数据库名称
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
# 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("DELETE FROM users WHERE username = %s", (username,))
|
||||
conn.commit()
|
||||
print(f"用户 {username} 删除成功!")
|
||||
except Exception as e:
|
||||
print(f"删除用户出错:{e}")
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def scheme_name_exists(name: str, scheme_name: str) -> bool:
|
||||
"""
|
||||
判断传入的 scheme_name 是否已存在于 scheme_list 表中,用于输入框判断
|
||||
:param name: 数据库名称
|
||||
:param scheme_name: 需要判断的方案名称
|
||||
:return: 如果存在返回 True,否则返回 False
|
||||
"""
|
||||
try:
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) FROM scheme_list WHERE scheme_name = %s",
|
||||
(scheme_name,),
|
||||
)
|
||||
result = cur.fetchone()
|
||||
if result is not None and result[0] > 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"查询 scheme_name 时出错:{e}")
|
||||
return False
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def store_scheme_info(
|
||||
name: str,
|
||||
scheme_name: str,
|
||||
scheme_type: str,
|
||||
username: str,
|
||||
scheme_start_time: str,
|
||||
scheme_detail: dict,
|
||||
):
|
||||
"""
|
||||
将一条方案记录插入 scheme_list 表中
|
||||
:param name: 数据库名称
|
||||
:param scheme_name: 方案名称
|
||||
:param scheme_type: 方案类型
|
||||
:param username: 用户名(需在 users 表中已存在)
|
||||
:param scheme_start_time: 方案起始时间(字符串)
|
||||
:param scheme_detail: 方案详情(字典,会转换为 JSON)
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
sql = """
|
||||
INSERT INTO scheme_list (scheme_name, scheme_type, username, scheme_start_time, scheme_detail)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
# 将字典转换为 JSON 字符串
|
||||
scheme_detail_json = json.dumps(scheme_detail)
|
||||
cur.execute(
|
||||
sql,
|
||||
(
|
||||
scheme_name,
|
||||
scheme_type,
|
||||
username,
|
||||
scheme_start_time,
|
||||
scheme_detail_json,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
print("方案信息存储成功!")
|
||||
except Exception as e:
|
||||
print(f"存储方案信息时出错:{e}")
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def delete_scheme_info(name: str, scheme_name: str) -> None:
|
||||
"""
|
||||
从 scheme_list 表中删除指定的方案
|
||||
:param name: 数据库名称
|
||||
:param scheme_name: 要删除的方案名称
|
||||
"""
|
||||
try:
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 使用参数化查询删除方案记录
|
||||
cur.execute(
|
||||
"DELETE FROM scheme_list WHERE scheme_name = %s", (scheme_name,)
|
||||
)
|
||||
conn.commit()
|
||||
print(f"方案 {scheme_name} 删除成功!")
|
||||
except Exception as e:
|
||||
print(f"删除方案时出错:{e}")
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def query_scheme_list(name: str) -> list:
|
||||
"""
|
||||
查询pg数据库中的scheme_list,按照 create_time 降序排列,离现在时间最近的记录排在最前面
|
||||
:param name: 项目名称(数据库名称)
|
||||
:return: 返回查询结果的所有行
|
||||
"""
|
||||
try:
|
||||
# 动态替换数据库名称
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
# 连接到 PostgreSQL 数据库(这里是数据库 "bb")
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 按 create_time 降序排列
|
||||
cur.execute("SELECT * FROM scheme_list ORDER BY create_time DESC")
|
||||
rows = cur.fetchall()
|
||||
return rows
|
||||
|
||||
except Exception as e:
|
||||
print(f"查询错误:{e}")
|
||||
|
||||
|
||||
# 2025/03/23
|
||||
def upload_shp_to_pg(name: str, table_name: str, role: str, shp_file_path: str):
|
||||
"""
|
||||
将 Shapefile 文件上传到 PostgreSQL 数据库
|
||||
:param name: 项目名称(数据库名称)
|
||||
:param table_name: 创建表的名字
|
||||
:param role: 数据库角色名,位于c盘user中查看
|
||||
:param shp_file_path: shp文件的路径
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# 动态连接到指定的数据库
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
# 读取 Shapefile 文件
|
||||
gdf = gpd.read_file(shp_file_path)
|
||||
|
||||
# 检查投影坐标系(CRS),并确保是 EPSG:4326
|
||||
if gdf.crs.to_string() != "EPSG:4490":
|
||||
gdf = gdf.to_crs(epsg=4490)
|
||||
|
||||
# 使用 GeoDataFrame 的 .to_postgis 方法将数据写入 PostgreSQL
|
||||
# 需要在数据库中提前安装 PostGIS 扩展
|
||||
engine = create_engine(f"postgresql+psycopg2://{role}:@127.0.0.1/{name}")
|
||||
gdf.to_postgis(
|
||||
table_name, engine, if_exists="replace", index=True, index_label="id"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Shapefile 文件成功上传到 PostgreSQL 数据库 '{name}' 的表 '{table_name}'."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"上传 Shapefile 到 PostgreSQL 时出错:{e}")
|
||||
|
||||
|
||||
def submit_risk_probability_result(name: str, result_file_path: str) -> None:
|
||||
"""
|
||||
将管网风险评估结果导入pg数据库
|
||||
:param name: 项目名称(数据库名称)
|
||||
:param result_file_path: 结果文件路径
|
||||
:return:
|
||||
"""
|
||||
# 自动检测文件编码
|
||||
# with open({result_file_path}, 'rb') as file:
|
||||
# raw_data = file.read()
|
||||
# detected = chardet.detect(raw_data)
|
||||
# file_encoding = detected['encoding']
|
||||
# print(f"检测到的文件编码:{file_encoding}")
|
||||
|
||||
try:
|
||||
# 动态替换数据库名称
|
||||
conn_string = get_pgconn_string(db_name=name)
|
||||
|
||||
# 连接到 PostgreSQL 数据库
|
||||
with psycopg.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 检查 scada_info 表是否为空
|
||||
cur.execute("SELECT COUNT(*) FROM pipe_risk_probability;")
|
||||
count = cur.fetchone()[0]
|
||||
|
||||
if count > 0:
|
||||
print("pipe_risk_probability表中已有数据,正在清空记录...")
|
||||
cur.execute("DELETE FROM pipe_risk_probability;")
|
||||
print("表记录已清空。")
|
||||
|
||||
# 读取Excel并转换x/y列为列表
|
||||
df = pd.read_excel(result_file_path, sheet_name="Sheet1")
|
||||
df["x"] = df["x"].apply(ast.literal_eval)
|
||||
df["y"] = df["y"].apply(ast.literal_eval)
|
||||
|
||||
# 批量插入数据
|
||||
for index, row in df.iterrows():
|
||||
insert_query = """
|
||||
INSERT INTO pipe_risk_probability
|
||||
(pipeID, pipeage, risk_probability_now, x, y)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
cur.execute(
|
||||
insert_query,
|
||||
(
|
||||
row["pipeID"],
|
||||
row["pipeage"],
|
||||
row["risk_probability_now"],
|
||||
row["x"], # 直接传递列表
|
||||
row["y"], # 同上
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
print("风险评估结果导入成功")
|
||||
|
||||
except Exception as e:
|
||||
print(f"导入时出错:{e}")
|
||||
@@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from tjnetwork import *
|
||||
from app.services.tjnetwork import *
|
||||
from app.native.api.s36_wda_cal import *
|
||||
|
||||
# from get_real_status import *
|
||||
@@ -11,7 +11,7 @@ import pytz
|
||||
import requests
|
||||
import time
|
||||
import shutil
|
||||
from epanet.epanet import Output
|
||||
from app.services.epanet.epanet import Output
|
||||
from typing import Optional, Tuple
|
||||
import app.infra.db.influxdb.api as influxdb_api
|
||||
import typing
|
||||
@@ -21,8 +21,12 @@ import app.services.globals as globals
|
||||
import uuid
|
||||
import app.services.project_info as project_info
|
||||
from app.native.api.postgresql_info import get_pgconn_string
|
||||
from timescaledb.internal_queries import InternalQueries as TimescaleInternalQueries
|
||||
from timescaledb.internal_queries import InternalStorage as TimescaleInternalStorage
|
||||
from app.infra.db.timescaledb.internal_queries import (
|
||||
InternalQueries as TimescaleInternalQueries,
|
||||
)
|
||||
from app.infra.db.timescaledb.internal_queries import (
|
||||
InternalStorage as TimescaleInternalStorage,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
@@ -679,8 +683,8 @@ def run_simulation(
|
||||
modify_fixed_pump_pattern: dict[str, list] = None,
|
||||
modify_variable_pump_pattern: dict[str, list] = None,
|
||||
modify_valve_opening: dict[str, float] = None,
|
||||
scheme_Type: str = None,
|
||||
scheme_Name: str = None,
|
||||
scheme_type: str = None,
|
||||
scheme_name: str = None,
|
||||
) -> None:
|
||||
"""
|
||||
传入需要修改的参数,改变数据库中对应位置的值,然后计算,返回结果
|
||||
@@ -695,8 +699,8 @@ def run_simulation(
|
||||
:param modify_fixed_pump_pattern: dict中包含多个水泵模式,str为工频水泵的id,list为修改后的pattern
|
||||
:param modify_variable_pump_pattern: dict中包含多个水泵模式,str为变频水泵的id,list为修改后的pattern
|
||||
:param modify_valve_opening: dict中包含多个阀门开启度,str为阀门的id,float为修改后的阀门开启度
|
||||
:param scheme_Type: 模拟方案类型
|
||||
:param scheme_Name:模拟方案名称
|
||||
:param scheme_type: 模拟方案类型
|
||||
:param scheme_name:模拟方案名称
|
||||
:return:
|
||||
"""
|
||||
# 记录开始时间
|
||||
@@ -1186,12 +1190,12 @@ def run_simulation(
|
||||
if modify_valve_opening[valve_name] == 0:
|
||||
valve_status["status"] = "CLOSED"
|
||||
valve_status["setting"] = 0
|
||||
if modify_valve_opening[valve_name] < 1:
|
||||
elif modify_valve_opening[valve_name] < 1:
|
||||
valve_status["status"] = "OPEN"
|
||||
valve_status["setting"] = 0.1036 * pow(
|
||||
modify_valve_opening[valve_name], -3.105
|
||||
)
|
||||
if modify_valve_opening[valve_name] == 1:
|
||||
elif modify_valve_opening[valve_name] == 1:
|
||||
valve_status["status"] = "OPEN"
|
||||
valve_status["setting"] = 0
|
||||
cs = ChangeSet()
|
||||
@@ -1231,21 +1235,22 @@ def run_simulation(
|
||||
starttime = time.time()
|
||||
if simulation_type.upper() == "REALTIME":
|
||||
TimescaleInternalStorage.store_realtime_simulation(
|
||||
node_result, link_result, modify_pattern_start_time
|
||||
node_result, link_result, modify_pattern_start_time, db_name=name
|
||||
)
|
||||
elif simulation_type.upper() == "EXTENDED":
|
||||
TimescaleInternalStorage.store_scheme_simulation(
|
||||
scheme_Type,
|
||||
scheme_Name,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
node_result,
|
||||
link_result,
|
||||
modify_pattern_start_time,
|
||||
num_periods_result,
|
||||
db_name=name,
|
||||
)
|
||||
endtime = time.time()
|
||||
logging.info("store time: %f", endtime - starttime)
|
||||
# 暂不需要再次存储 SCADA 模拟信息
|
||||
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
|
||||
# TimescaleInternalQueries.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
|
||||
|
||||
# if simulation_type.upper() == "REALTIME":
|
||||
# influxdb_api.store_realtime_simulation_result_to_influxdb(
|
||||
@@ -1257,11 +1262,11 @@ def run_simulation(
|
||||
# link_result,
|
||||
# modify_pattern_start_time,
|
||||
# num_periods_result,
|
||||
# scheme_Type,
|
||||
# scheme_Name,
|
||||
# scheme_type,
|
||||
# scheme_name,
|
||||
# )
|
||||
# 暂不需要再次存储 SCADA 模拟信息
|
||||
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)
|
||||
# influxdb_api.fill_scheme_simulation_result_to_SCADA(scheme_type=scheme_type, scheme_name=scheme_name)
|
||||
|
||||
print("after store result")
|
||||
|
||||
@@ -1341,7 +1346,7 @@ if __name__ == "__main__":
|
||||
# run_simulation(name='bb', simulation_type="realtime", modify_pattern_start_time='2025-02-25T23:45:00+08:00')
|
||||
# 模拟示例2
|
||||
# run_simulation(name='bb', simulation_type="extended", modify_pattern_start_time='2025-03-10T12:00:00+08:00',
|
||||
# modify_total_duration=1800, scheme_Type="burst_Analysis", scheme_Name="scheme1")
|
||||
# modify_total_duration=1800, scheme_type="burst_Analysis", scheme_name="scheme1")
|
||||
|
||||
# 查询示例1:query_SCADA_ID_corresponding_info
|
||||
# result = query_SCADA_ID_corresponding_info(name='bb', SCADA_ID='P10755')
|
||||
|
||||
233
app/services/simulation_ops.py
Normal file
233
app/services/simulation_ops.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from math import pi
|
||||
|
||||
import pytz
|
||||
|
||||
from app.algorithms.api_ex.run_simulation import run_simulation_ex
|
||||
from app.native.api.project import copy_project
|
||||
from app.services.epanet.epanet import Output
|
||||
from app.services.tjnetwork import *
|
||||
|
||||
|
||||
############################################################
|
||||
# project management 07 ***暂时不使用,与业务需求无关***
|
||||
############################################################
|
||||
|
||||
|
||||
def project_management(
|
||||
prj_name,
|
||||
start_datetime,
|
||||
pump_control,
|
||||
tank_initial_level_control=None,
|
||||
region_demand_control=None,
|
||||
) -> str:
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"project_management_{prj_name}"
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(prj_name):
|
||||
# close_project(prj_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(prj_name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(prj_name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
result = run_simulation_ex(
|
||||
name=new_name,
|
||||
simulation_type="realtime",
|
||||
start_datetime=start_datetime,
|
||||
duration=86400,
|
||||
pump_control=pump_control,
|
||||
tank_initial_level_control=tank_initial_level_control,
|
||||
region_demand_control=region_demand_control,
|
||||
downloading_prohibition=True,
|
||||
)
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
return result
|
||||
|
||||
|
||||
############################################################
|
||||
# scheduling analysis 08 ***暂时不使用,与业务需求无关***
|
||||
############################################################
|
||||
|
||||
|
||||
def scheduling_simulation(
|
||||
prj_name, start_time, pump_control, tank_id, water_plant_output_id, time_delta=300
|
||||
) -> str:
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"scheduling_{prj_name}"
|
||||
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(prj_name):
|
||||
# close_project(prj_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(prj_name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(prj_name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
|
||||
run_simulation_ex(
|
||||
new_name, "realtime", start_time, duration=0, pump_control=pump_control
|
||||
)
|
||||
|
||||
if not is_project_open(new_name):
|
||||
open_project(new_name)
|
||||
|
||||
tank = get_tank(new_name, tank_id) # 水塔信息
|
||||
tank_floor_space = pi * pow(tank["diameter"] / 2, 2) # 水塔底面积(m^2)
|
||||
tank_init_level = tank["init_level"] # 水塔初始水位(m)
|
||||
tank_pipes_id = tank["links"] # pipes list
|
||||
|
||||
tank_pipe_flow_direction = (
|
||||
{}
|
||||
) # 管道流向修正系数, 水塔为下游节点时为1, 水塔为上游节点时为-1
|
||||
for pipe_id in tank_pipes_id:
|
||||
if get_pipe(new_name, pipe_id)["node2"] == tank_id: # 水塔为下游节点
|
||||
tank_pipe_flow_direction[pipe_id] = 1
|
||||
else:
|
||||
tank_pipe_flow_direction[pipe_id] = -1
|
||||
|
||||
output = Output("./temp/{}.db.out".format(new_name))
|
||||
|
||||
node_results = (
|
||||
output.node_results()
|
||||
) # [{'node': str, 'result': [{'pressure': float}]}]
|
||||
water_plant_output_pressure = 0
|
||||
for node_result in node_results:
|
||||
if node_result["node"] == water_plant_output_id: # 水厂出水压力(m)
|
||||
water_plant_output_pressure = node_result["result"][-1]["pressure"]
|
||||
water_plant_output_pressure /= 100 # 预计水厂出水压力(Mpa)
|
||||
|
||||
pipe_results = output.link_results() # [{'link': str, 'result': [{'flow': float}]}]
|
||||
tank_inflow = 0
|
||||
for pipe_result in pipe_results:
|
||||
for pipe_id in tank_pipes_id: # 遍历与水塔相连的管道
|
||||
if pipe_result["link"] == pipe_id: # 水塔入流流量(L/s)
|
||||
tank_inflow += (
|
||||
pipe_result["result"][-1]["flow"]
|
||||
* tank_pipe_flow_direction[pipe_id]
|
||||
)
|
||||
tank_inflow /= 1000 # 水塔入流流量(m^3/s)
|
||||
tank_level_delta = tank_inflow * time_delta / tank_floor_space # 水塔水位改变值(m)
|
||||
tank_level = tank_init_level + tank_level_delta # 预计水塔水位(m)
|
||||
|
||||
simulation_results = {
|
||||
"water_plant_output_pressure": water_plant_output_pressure,
|
||||
"tank_init_level": tank_init_level,
|
||||
"tank_level": tank_level,
|
||||
}
|
||||
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
|
||||
return json.dumps(simulation_results)
|
||||
|
||||
|
||||
def daily_scheduling_simulation(
|
||||
prj_name, start_time, pump_control, reservoir_id, tank_id, water_plant_output_id
|
||||
) -> str:
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Analysis."
|
||||
)
|
||||
new_name = f"daily_scheduling_{prj_name}"
|
||||
|
||||
if have_project(new_name):
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
# if is_project_open(prj_name):
|
||||
# close_project(prj_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Copying Database."
|
||||
)
|
||||
# CopyProjectEx()(prj_name, new_name,
|
||||
# ['operation', 'current_operation', 'restore_operation', 'batch_operation', 'operation_table'])
|
||||
copy_project(prj_name + "_template", new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Start Opening Database."
|
||||
)
|
||||
open_project(new_name)
|
||||
print(
|
||||
datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
|
||||
+ " -- Database Loading OK."
|
||||
)
|
||||
|
||||
run_simulation_ex(
|
||||
new_name, "realtime", start_time, duration=86400, pump_control=pump_control
|
||||
)
|
||||
|
||||
if not is_project_open(new_name):
|
||||
open_project(new_name)
|
||||
|
||||
output = Output("./temp/{}.db.out".format(new_name))
|
||||
|
||||
node_results = (
|
||||
output.node_results()
|
||||
) # [{'node': str, 'result': [{'pressure': float, 'head': float}]}]
|
||||
water_plant_output_pressure = []
|
||||
reservoir_level = []
|
||||
tank_level = []
|
||||
for node_result in node_results:
|
||||
if node_result["node"] == water_plant_output_id:
|
||||
for result in node_result["result"]:
|
||||
water_plant_output_pressure.append(
|
||||
result["pressure"] / 100
|
||||
) # 水厂出水压力(Mpa)
|
||||
elif node_result["node"] == reservoir_id:
|
||||
for result in node_result["result"]:
|
||||
reservoir_level.append(result["head"] - 250.35) # 清水池液位(m)
|
||||
elif node_result["node"] == tank_id:
|
||||
for result in node_result["result"]:
|
||||
tank_level.append(result["pressure"]) # 调节池液位(m)
|
||||
|
||||
simulation_results = {
|
||||
"water_plant_output_pressure": water_plant_output_pressure,
|
||||
"reservoir_level": reservoir_level,
|
||||
"tank_level": tank_level,
|
||||
}
|
||||
|
||||
if is_project_open(new_name):
|
||||
close_project(new_name)
|
||||
delete_project(new_name)
|
||||
|
||||
return json.dumps(simulation_results)
|
||||
11
app/services/valve_isolation.py
Normal file
11
app/services/valve_isolation.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from app.algorithms.valve_isolation import valve_isolation_analysis
|
||||
|
||||
|
||||
def analyze_valve_isolation(
|
||||
network: str,
|
||||
accident_element: str | list[str],
|
||||
disabled_valves: list[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
return valve_isolation_analysis(network, accident_element, disabled_valves)
|
||||
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user