重构现代化 FastAPI 后端项目框架

This commit is contained in:
2026-01-21 16:50:57 +08:00
parent 9e06e68a15
commit c56f2fd1db
352 changed files with 176 additions and 70 deletions

0
app/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,289 @@
# ...existing code...
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pykalman import KalmanFilter
import os
def clean_flow_data_kf(input_csv_path: str, show_plot: bool = False) -> str:
    """Smooth each column of a CSV time series with a 1-D Kalman filter and
    repair outliers.

    Outliers are detected per column from the Kalman residuals with the
    1.5*IQR rule and replaced by the Kalman-smoothed value.  The result is
    written next to the input file as ``<input_filename>_cleaned.xlsx`` and
    the absolute output path is returned.

    :param input_csv_path: path to the input CSV (header row, no index column).
    :param show_plot: if True, plot the first column with anomalies marked.
    :return: absolute path of the written Excel file.
    """
    data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
    # Kalman-smoothed version of every column.
    data_kf = pd.DataFrame(index=data.index, columns=data.columns)
    for col in data.columns:
        # Fill gaps so the filter always sees a finite observation sequence.
        observations = pd.Series(data[col].values).ffill().bfill()
        if observations.isna().any():
            # Column still has NaN (e.g. all-NaN column); fall back to the mean.
            observations = observations.fillna(observations.mean())
        obs = observations.values.astype(float)
        kf = KalmanFilter(
            transition_matrices=[1],
            observation_matrices=[1],
            initial_state_mean=float(obs[0]),
            initial_state_covariance=1,
            observation_covariance=1,
            transition_covariance=0.01,
        )
        # Fixed parameters (no EM step) for speed.
        state_means, _ = kf.smooth(obs)
        data_kf[col] = state_means.flatten()
    # Residual-based anomaly detection with the robust IQR rule.
    residuals = data - data_kf
    residual_thresholds = {}
    for col in data.columns:
        res_values = residuals[col].dropna().values  # drop NaN before percentiles
        q1 = np.percentile(res_values, 25)
        q3 = np.percentile(res_values, 75)
        iqr = q3 - q1
        residual_thresholds[col] = (q1 - 1.5 * iqr, q3 + 1.5 * iqr)
    cleaned_data = data.copy()
    anomalies_info = {}
    for col in data.columns:
        lower, upper = residual_thresholds[col]
        sensor_residuals = residuals[col]
        anomaly_mask = (sensor_residuals < lower) | (sensor_residuals > upper)
        anomaly_idx = data.index[anomaly_mask.fillna(False)]
        anomalies_info[col] = pd.DataFrame(
            {
                "Observed": data.loc[anomaly_idx, col],
                "Kalman_Predicted": data_kf.loc[anomaly_idx, col],
                "Residual": sensor_residuals.loc[anomaly_idx],
            }
        )
        # Bug fix: previously only anomaly rows were written into the
        # "<col>_cleaned" column, leaving every other row NaN in the output.
        # Materialise the full cleaned series (original values with anomalies
        # replaced by the Kalman prediction).
        cleaned_series = data[col].copy()
        cleaned_series.loc[anomaly_idx] = data_kf.loc[anomaly_idx, col]
        cleaned_data[f"{col}_cleaned"] = cleaned_series
    # Output name: <input base>_cleaned.xlsx in the input directory.
    input_dir = os.path.dirname(os.path.abspath(input_csv_path))
    input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
    output_filename = f"{input_base}_cleaned.xlsx"
    output_path = os.path.join(input_dir, output_filename)
    if os.path.exists(output_path):
        os.remove(output_path)  # overwrite any previous result
    cleaned_data.to_excel(output_path, index=False)
    if show_plot and len(data.columns) > 0:
        # Configure a CJK-capable font only when a plot is actually requested
        # (avoids mutating global rcParams as a side effect).
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False
        sensor_to_plot = data.columns[0]
        plt.figure(figsize=(12, 6))
        plt.plot(
            data.index,
            data[sensor_to_plot],
            label="监测值",
            marker="o",
            markersize=3,
            alpha=0.7,
        )
        plt.plot(
            data.index, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
        )
        anomaly_idx = anomalies_info[sensor_to_plot].index
        if len(anomaly_idx) > 0:
            plt.plot(
                anomaly_idx,
                data[sensor_to_plot].loc[anomaly_idx],
                "ro",
                markersize=8,
                label="监测值异常点",
            )
            plt.plot(
                anomaly_idx,
                data_kf[sensor_to_plot].loc[anomaly_idx],
                "go",
                markersize=8,
                label="Kalman修复值",
            )
        plt.xlabel("时间点(序号)")
        plt.ylabel("监测值")
        plt.title(f"{sensor_to_plot}观测值与Kalman滤波预测值异常点标记")
        plt.legend()
        plt.show()
    # Return the absolute path of the written file.
    return os.path.abspath(output_path)
def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
    """Clean a flow-rate DataFrame with a 1-D Kalman filter.

    Zeros are treated as suspect readings and repaired by linear
    interpolation before smoothing.  Outliers are then detected per column
    from the Kalman residuals with the 1.5*IQR rule and replaced in place by
    the smoothed value.

    :param data: input DataFrame, one time series per column.
    :param show_plot: if True, plot the first column before/after cleaning.
    :return: the cleaned DataFrame (same shape/columns as the input).
        Note: the previous ``-> dict`` annotation was wrong — a DataFrame is
        returned.
    """
    data = data.copy()
    # Zeros are considered invalid here; mark them and interpolate.
    data_filled = data.replace(0, np.nan)
    data_filled = data_filled.interpolate(method="linear", limit_direction="both")
    # Any remaining NaN (e.g. all-zero columns) via forward/backward fill.
    data_filled = data_filled.ffill().bfill()
    # Kalman-smoothed version of every column.
    data_kf = pd.DataFrame(index=data_filled.index, columns=data_filled.columns)
    for col in data_filled.columns:
        observations = pd.Series(data_filled[col].values).ffill().bfill()
        if observations.isna().any():
            observations = observations.fillna(observations.mean())
        obs = observations.values.astype(float)
        kf = KalmanFilter(
            transition_matrices=[1],
            observation_matrices=[1],
            initial_state_mean=float(obs[0]),
            initial_state_covariance=1,
            observation_covariance=10,
            transition_covariance=10,
        )
        state_means, _ = kf.smooth(obs)
        data_kf[col] = state_means.flatten()
    # Residual-based anomaly detection with the IQR rule.
    residuals = data_filled - data_kf
    residual_thresholds = {}
    for col in data_filled.columns:
        res_values = residuals[col].dropna().values
        q1 = np.percentile(res_values, 25)
        q3 = np.percentile(res_values, 75)
        iqr = q3 - q1
        residual_thresholds[col] = (q1 - 1.5 * iqr, q3 + 1.5 * iqr)
    cleaned_data = data_filled.copy()
    anomalies_info = {}
    for col in data_filled.columns:
        lower, upper = residual_thresholds[col]
        sensor_residuals = residuals[col]
        anomaly_mask = (sensor_residuals < lower) | (sensor_residuals > upper)
        anomaly_idx = data_filled.index[anomaly_mask.fillna(False)]
        anomalies_info[col] = pd.DataFrame(
            {
                "Observed": data_filled.loc[anomaly_idx, col],
                "Kalman_Predicted": data_kf.loc[anomaly_idx, col],
                "Residual": sensor_residuals.loc[anomaly_idx],
            }
        )
        # Replace anomalies in place with the Kalman prediction.
        cleaned_data.loc[anomaly_idx, col] = data_kf.loc[anomaly_idx, col]
    if show_plot and len(data.columns) > 0:
        # Configure a CJK-capable font only when a plot is actually requested.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False
        sensor_to_plot = data.columns[0]
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 1, 1)
        plt.plot(
            data.index,
            data[sensor_to_plot],
            label="原始监测值",
            marker="o",
            markersize=3,
            alpha=0.7,
        )
        # NOTE(review): after interpolate + ffill/bfill this mask is rarely
        # non-empty; kept from the original implementation for parity.
        abnormal_zero_idx = data.index[data_filled[sensor_to_plot].isna()]
        if len(abnormal_zero_idx) > 0:
            plt.plot(
                abnormal_zero_idx,
                data[sensor_to_plot].loc[abnormal_zero_idx],
                "mo",
                markersize=8,
                label="异常0值",
            )
        plt.plot(
            data.index, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
        )
        anomaly_idx = anomalies_info[sensor_to_plot].index
        if len(anomaly_idx) > 0:
            plt.plot(
                anomaly_idx,
                data_filled[sensor_to_plot].loc[anomaly_idx],
                "ro",
                markersize=8,
                label="IQR异常点",
            )
        plt.xlabel("时间点(序号)")
        plt.ylabel("流量值")
        plt.title(f"{sensor_to_plot}:原始数据与异常检测")
        plt.legend()
        plt.subplot(2, 1, 2)
        plt.plot(
            data.index,
            cleaned_data[sensor_to_plot],
            label="修复后监测值",
            marker="o",
            markersize=3,
            color="green",
        )
        plt.xlabel("时间点(序号)")
        plt.ylabel("流量值")
        plt.title(f"{sensor_to_plot}:修复后数据")
        plt.legend()
        plt.tight_layout()
        plt.show()
    return cleaned_data
# # 测试
# if __name__ == "__main__":
# # 默认:脚本目录下同名 CSV 文件
# script_dir = os.path.dirname(os.path.abspath(__file__))
# default_csv = os.path.join(script_dir, "pipe_flow_data_to_clean2.0.csv")
# out = clean_flow_data_kf(default_csv)
# print("清洗后的数据已保存到:", out)
# 测试 clean_flow_data_dict 函数
if __name__ == "__main__":
    import random

    # Smoke test for clean_flow_data_df_kf: load the SCADA flow data that
    # sits next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, "szh_flow_scada.csv")
    data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
    # Exclude the Time column and pick one sensor column at random.
    columns_to_exclude = ["Time"]
    available_columns = [col for col in data.columns if col not in columns_to_exclude]
    selected_columns = random.sample(available_columns, 1)
    # Bug fix: the cleaning function expects a DataFrame (it calls .replace /
    # .interpolate on it), so build one instead of passing a plain dict.
    data_df = pd.DataFrame({col: data[col].tolist() for col in selected_columns})
    print("选中的列:", selected_columns)
    print("原始数据长度:", len(data_df[selected_columns[0]]))
    # Run the cleaning pipeline with plots enabled.
    cleaned_df = clean_flow_data_df_kf(data_df, show_plot=True)
    # Persist the cleaned result as CSV.
    out_csv = os.path.join(script_dir, f"{selected_columns[0]}_clean.csv")
    pd.DataFrame(cleaned_df).to_csv(out_csv, index=False, encoding="utf-8-sig")
    print("已保存清洗结果到:", out_csv)
    print("清洗后的字典键:", list(cleaned_df.keys()))
    print("清洗后的数据长度:", len(cleaned_df[selected_columns[0]]))
    print("测试完成:函数运行正常")

View File

@@ -0,0 +1,238 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
import os
def clean_pressure_data_km(input_csv_path: str, show_plot: bool = False) -> str:
    """Detect pressure anomalies with KMeans clustering and repair them.

    Rows whose normalised reading lies further than mean + 3*std from their
    cluster centre are flagged; for each flagged row the single sensor that
    deviates most from the centre is repaired with a centred rolling mean.
    Results are written to ``<input_basename>_cleaned.xlsx`` (same directory)
    with the raw data in sheet 'raw_pressure_data' and the repaired data in
    sheet 'cleaned_pressusre_data'.

    :param input_csv_path: path to the input CSV.
    :param show_plot: if True, show before/after line plots of all sensors.
    :return: absolute path of the written Excel file.
    """
    input_csv_path = os.path.abspath(input_csv_path)
    data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
    # z-score normalisation so all sensors weigh equally in the clustering.
    data_norm = (data - data.mean()) / data.std()
    # Clustering and anomaly detection.
    k = 3
    kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
    clusters = kmeans.fit_predict(data_norm)
    centers = kmeans.cluster_centers_
    distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
    threshold = distances.mean() + 3 * distances.std()
    anomaly_pos = np.where(distances > threshold)[0]
    # For each anomalous row record the sensor furthest from the centre.
    anomaly_details = {}
    for pos in anomaly_pos:
        row_norm = data_norm.iloc[pos]
        cluster_idx = clusters[pos]
        center = centers[cluster_idx]
        diff = abs(row_norm - center)
        main_sensor = diff.idxmax()
        anomaly_details[data.index[pos]] = main_sensor
    # Repair: centred rolling mean (window size tunable).
    data_rolled = data.rolling(window=13, center=True, min_periods=1).mean()
    data_repaired = data.copy()
    for pos in anomaly_pos:
        label = data.index[pos]
        sensor = anomaly_details[label]
        data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
    if show_plot and len(data.columns) > 0:
        # Configure a CJK-capable font only when a plot is actually requested.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False
        n = len(data)
        time = np.arange(n)  # positional x axis
        plt.figure(figsize=(12, 8))
        for col in data.columns:
            plt.plot(time, data[col].values, marker="o", markersize=3, label=col)
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data.iloc[pos][sensor], "ro", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("压力监测值")
        plt.title("各传感器折线图(红色标记主要异常点)")
        plt.legend()
        plt.show()
        plt.figure(figsize=(12, 8))
        for col in data_repaired.columns:
            plt.plot(
                time, data_repaired[col].values, marker="o", markersize=3, label=col
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("修复后压力监测值")
        plt.title("修复后各传感器折线图(绿色标记修复值)")
        plt.legend()
        plt.show()
    # Save to Excel with two sheets.
    input_dir = os.path.dirname(os.path.abspath(input_csv_path))
    input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
    output_filename = f"{input_base}_cleaned.xlsx"
    output_path = os.path.join(input_dir, output_filename)
    if os.path.exists(output_path):
        os.remove(output_path)  # overwrite any previous result
    # NOTE(review): sheet name 'cleaned_pressusre_data' is misspelled but kept
    # byte-identical because downstream readers may depend on it.
    with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
        data.to_excel(writer, sheet_name="raw_pressure_data", index=False)
        data_repaired.to_excel(writer, sheet_name="cleaned_pressusre_data", index=False)
    return os.path.abspath(output_path)
def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
    """Clean a pressure DataFrame with KMeans anomaly detection.

    Zeros are treated as dropouts and repaired by linear interpolation; rows
    whose normalised reading lies far from their cluster centre are repaired
    by a centred rolling mean on the most deviant sensor.

    :param data: input DataFrame, one sensor per column.
    :param show_plot: if True, show before/after plots of all sensors.
    :return: the repaired DataFrame (same shape/columns as the input).
        Note: the previous ``-> dict`` annotation was wrong — a DataFrame is
        returned.
    """
    data = data.copy()
    # Fill NaN values first.
    data = data.ffill().bfill()
    # Replace zeros with NaN and repair them by linear interpolation.
    data_filled = data.replace(0, np.nan)
    data_filled = data_filled.interpolate(method="linear", limit_direction="both")
    # Columns that were entirely zero still need a fill.
    data_filled = data_filled.ffill().bfill()
    # z-score normalisation before clustering.
    data_norm = (data_filled - data_filled.mean()) / data_filled.std()
    # Constant columns produce NaN after normalisation (std == 0); impute with
    # zeros and keep empty features so column alignment is preserved.
    imputer = SimpleImputer(
        strategy="constant", fill_value=0, keep_empty_features=True
    )
    data_norm = pd.DataFrame(
        imputer.fit_transform(data_norm),
        columns=data_norm.columns,
        index=data_norm.index,
    )
    # Clustering and anomaly detection.
    k = 3
    kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
    clusters = kmeans.fit_predict(data_norm)
    centers = kmeans.cluster_centers_
    distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
    threshold = distances.mean() + 3 * distances.std()
    anomaly_pos = np.where(distances > threshold)[0]
    # For each anomalous row record the sensor furthest from the centre.
    anomaly_details = {}
    for pos in anomaly_pos:
        row_norm = data_norm.iloc[pos]
        cluster_idx = clusters[pos]
        center = centers[cluster_idx]
        diff = abs(row_norm - center)
        main_sensor = diff.idxmax()
        anomaly_details[data.index[pos]] = main_sensor
    # Repair: centred rolling mean (window size tunable).
    data_rolled = data_filled.rolling(window=13, center=True, min_periods=1).mean()
    data_repaired = data_filled.copy()
    for pos in anomaly_pos:
        label = data.index[pos]
        sensor = anomaly_details[label]
        data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]
    if show_plot and len(data.columns) > 0:
        # Configure a CJK-capable font only when a plot is actually requested.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False
        n = len(data)
        time = np.arange(n)  # positional x axis
        plt.figure(figsize=(12, 8))
        for col in data.columns:
            plt.plot(
                time, data[col].values, marker="o", markersize=3, label=col, alpha=0.5
            )
        for col in data_filled.columns:
            plt.plot(
                time,
                data_filled[col].values,
                marker="x",
                markersize=3,
                label=f"{col}_filled",
                linestyle="--",
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_filled.iloc[pos][sensor], "ro", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("压力监测值")
        plt.title("各传感器折线图红色标记主要异常点虚线为0值填充后")
        plt.legend()
        plt.show()
        plt.figure(figsize=(12, 8))
        for col in data_repaired.columns:
            plt.plot(
                time, data_repaired[col].values, marker="o", markersize=3, label=col
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("修复后压力监测值")
        plt.title("修复后各传感器折线图(绿色标记修复值)")
        plt.legend()
        plt.show()
    return data_repaired
# 测试
# if __name__ == "__main__":
# # 默认使用脚本目录下的 pressure_raw_data.csv
# script_dir = os.path.dirname(os.path.abspath(__file__))
# default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
# out_path = clean_pressure_data_km(default_csv, show_plot=False)
# print("保存路径:", out_path)
# 测试 clean_pressure_data_dict_km 函数
if __name__ == "__main__":
    import random

    # Smoke test for clean_pressure_data_df_km: load the SCADA pressure data
    # that sits next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
    data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")
    # Exclude the Time column and sample 5 sensor columns at random.
    columns_to_exclude = ["Time"]
    available_columns = [col for col in data.columns if col not in columns_to_exclude]
    selected_columns = random.sample(available_columns, 5)
    # Bug fix: the cleaning function expects a DataFrame (it calls .ffill /
    # .bfill on it), so build one instead of passing a plain dict.
    data_df = pd.DataFrame({col: data[col].tolist() for col in selected_columns})
    print("选中的列:", selected_columns)
    print("原始数据长度:", len(data_df[selected_columns[0]]))
    # Run the cleaning pipeline with plots enabled.
    cleaned_df = clean_pressure_data_df_km(data_df, show_plot=True)
    print("清洗后的字典键:", list(cleaned_df.keys()))
    print("清洗后的数据长度:", len(cleaned_df[selected_columns[0]]))
    print("测试完成:函数运行正常")

View File

@@ -0,0 +1,3 @@
from .Fdataclean import *
from .Pdataclean import *
from .pipeline_health_analyzer import *

View File

@@ -0,0 +1,109 @@
import wntr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.cluster
import os
class QD_KMeans(object):
    """Select pressure-monitoring junctions via K-Means clustering.

    Junction coordinates are clustered into ``num_monitors`` groups and the
    junction nearest to each cluster centre becomes a monitoring point.
    """

    def __init__(self, wn, num_monitors):
        """
        :param wn: wntr WaterNetworkModel to place monitors on.
        :param num_monitors: number of clusters, i.e. monitoring points.
        """
        self.cluster_num = num_monitors  # one monitor per cluster centre
        self.wn = wn
        self.monitor_nodes = []
        self.coords = []
        self.junction_nodes = {}  # junction name -> (x, y) coordinates

    def get_junctions_coordinates(self):
        """Collect the (x, y) coordinates of every junction in the network."""
        for junction_name in self.wn.junction_name_list:
            junction = self.wn.get_node(junction_name)
            self.junction_nodes[junction_name] = junction.coordinates
            self.coords.append(junction.coordinates)

    def select_monitoring_points(self):
        """Cluster the junctions and return the selected monitor node names."""
        if not self.coords:  # lazily gather coordinates on first use
            self.get_junctions_coordinates()
        coords = np.array(self.coords)
        # Min-max normalisation.  Bug fix: guard a zero span (all junctions
        # sharing one axis value) which previously caused division by zero
        # and NaN coordinates.
        span = coords.max(axis=0) - coords.min(axis=0)
        span = np.where(span == 0, 1.0, span)
        coords_normalized = (coords - coords.min(axis=0)) / span
        kmeans = sklearn.cluster.KMeans(n_clusters=self.cluster_num, random_state=42)
        kmeans.fit(coords_normalized)
        for center in kmeans.cluster_centers_:
            # Pick the junction closest to the cluster centre; coords were
            # appended in junction_name_list order, so argmin maps back safely.
            distances = np.sum((coords_normalized - center) ** 2, axis=1)
            nearest_node = self.wn.junction_name_list[np.argmin(distances)]
            self.monitor_nodes.append(nearest_node)
        return self.monitor_nodes

    def visualize_network(self):
        """Plot the network, highlighting the selected monitoring points."""
        ax = wntr.graphics.plot_network(self.wn,
                                        node_attribute=self.monitor_nodes,
                                        node_size=30,
                                        title='Optimal sensor')
        plt.show()
def kmeans_sensor_placement(name: str, sensor_num: int, min_diameter: int) -> list:
    """Return junction ids selected as sensor locations via K-Means clustering.

    :param name: network model name; resolves to ``./db_inp/<name>.db.inp``.
    :param sensor_num: number of sensors to place.
    :param min_diameter: currently unused — NOTE(review): confirm whether a
        diameter filter was intended here.
    :return: list of selected junction names.
    """
    model_file = f'./db_inp/{name}.db.inp'
    network = wntr.network.WaterNetworkModel(model_file)
    placer = QD_KMeans(network, sensor_num)
    sensor_ids = placer.select_monitoring_points()
    # placer.visualize_network()  # optional visual check
    return sensor_ids
if __name__ == "__main__":
    # sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
    # Smoke test: place 50 sensors on the 'szh' network model
    # (min_diameter is accepted but currently unused by the placement).
    sensorindex = kmeans_sensor_placement(name='szh', sensor_num=50, min_diameter=300)
    print(sensorindex)

View File

@@ -0,0 +1,142 @@
import os
import joblib
import pandas as pd
import matplotlib.pyplot as plt
class PipelineHealthAnalyzer:
    """Pipeline health analyzer backed by a pre-trained random survival forest.

    Wraps model loading and survival prediction so the model can be reused
    from other projects.  Predictions are based on four features:
    Material, Diameter, Flow Velocity and Pressure.
    Runtime dependencies: joblib, pandas, numpy, scikit-survival, matplotlib.
    """

    def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
        """Load the pre-trained random survival forest model.

        :param model_path: path of the joblib model file.
        :raises FileNotFoundError: if the model file does not exist.
        :raises RuntimeError: if deserialising the model fails (the original
            exception is chained as the cause).
        """
        # Create the model directory if missing so a later deployment of the
        # model file has somewhere to land (original behaviour kept).
        model_dir = os.path.dirname(model_path)
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件未找到: {model_path}")
        try:
            self.rsf = joblib.load(model_path)
        except Exception as e:
            # Fix: raise a concrete exception type (still caught by callers
            # that catch Exception) and chain the original cause instead of
            # hiding it behind a bare Exception.
            raise RuntimeError(f"加载模型时出错: {str(e)}") from e
        # Feature columns the model was trained with.
        self.features = [
            "Material",
            "Diameter",
            "Flow Velocity",
            "Pressure",  # 'Temperature', 'Precipitation',
            # 'Location', 'Structural Defects', 'Functional Defects'
        ]

    def predict_survival(self, data: pd.DataFrame) -> list:
        """Predict survival functions for the given samples.

        :param data: DataFrame containing the four required feature columns
            (extra columns are ignored); values must be numeric or coercible.
        :return: list of survival functions, each exposing time points (x)
            and survival probabilities (y).
        :raises ValueError: if required features are missing or non-numeric.
        """
        missing_features = [feat for feat in self.features if feat not in data.columns]
        if missing_features:
            raise ValueError(f"数据缺少必需特征: {missing_features}")
        try:
            x_test = data[self.features].astype(float)  # enforce numeric dtype
        except ValueError as e:
            raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
        survival_functions = self.rsf.predict_survival_function(x_test)
        return list(survival_functions)

    def plot_survival(
        self, survival_functions: list, save_path: str = None, show_plot: bool = True
    ):
        """Plot the predicted survival probability curves.

        :param survival_functions: list returned by predict_survival().
        :param save_path: optional .png path; not saved when None.
        :param show_plot: whether to display the figure interactively.
        """
        plt.figure(figsize=(10, 6))
        for i, sf in enumerate(survival_functions):
            plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
        plt.xlabel("时间(年)")
        plt.ylabel("生存概率")
        plt.title("管道生存概率预测")
        plt.legend()
        plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print(f"图表已保存到: {save_path}")
        if show_plot:
            plt.show()
        else:
            plt.close()
# 调用说明示例
"""
在其他项目中使用PipelineHealthAnalyzer类的步骤
1. 安装依赖在requirements.txt中添加
joblib==1.5.0
pandas==2.2.3
numpy==2.0.2
scikit-survival==0.23.1
matplotlib==3.9.4
2. 导入类:
from pipeline_health_analyzer import PipelineHealthAnalyzer
3. 初始化分析器(替换为实际模型路径):
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
4. 准备数据pandas DataFrame须包含4个必需特征列Material、Diameter、Flow Velocity、Pressure额外列会被忽略
import pandas as pd
data = pd.DataFrame({
'Material': [1, 2], # 示例数据
'Diameter': [100, 150],
'Flow Velocity': [1.5, 2.0],
'Pressure': [50, 60],
'Temperature': [20, 25],
'Precipitation': [0.1, 0.2],
'Location': [1, 2],
'Structural Defects': [0, 1],
'Functional Defects': [0, 0]
})
5. 进行预测:
survival_funcs = analyzer.predict_survival(data)
6. 查看结果(每个样本的生存概率随时间变化):
for i, sf in enumerate(survival_funcs):
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
7. 可视化(可选):
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
注意:
- 数据格式必须匹配特征列表,特征值为数值型。
- 模型文件需从原项目复制或重新训练。
- 如果需要自定义特征或模型参数可修改类中的features列表或继承此类。
"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,654 @@
# 改进灵敏度法
import networkx
import numpy as np
import pandas
import wntr
import pandas as pd
import copy
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.cluster import KMeans
from wntr.epanet.toolkit import EpanetException
from numpy.linalg import slogdet
import random
from tjnetwork import *
from matplotlib.lines import Line2D
from sklearn.cluster import SpectralClustering
import libpysal as ps
from spopt.region import Skater
from shapely.geometry import Point
import geopandas as gpd
from sklearn.metrics import pairwise_distances
import project_info
# 2025/03/12
# Step1: 获取节点坐标
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
    """Return the (x, y) coordinates of every node in the network model.

    :param wn: water network model built by wntr.
    :return: DataFrame indexed by node name with columns ['x', 'y'].
    """
    # Series: index = node names, values = coordinate tuples like (x, y).
    site = wn.query_node_attribute('coordinates')
    # Idiomatic extraction of the coordinate components (replaces the former
    # index-based loop).  NOTE(review): assumes the Series order matches
    # wn.node_name_list, which the original code relied on as well.
    x = [coord[0] for coord in site.values]
    y = [coord[1] for coord in site.values]
    Coor_node = pd.DataFrame({'x': x, 'y': y}, index=wn.node_name_list, columns=['x', 'y'])
    return Coor_node
# 2025/03/12
# Step2: KMeans 聚类
# 将节点用kmeans根据坐标分为k组存入字典g
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
    """Partition node coordinates into *knum* groups with K-Means clustering.

    :param coor: coordinates of all nodes, one row per node.
    :param knum: number of clusters to form.
    :return: dict mapping cluster id -> list of node names in that cluster.
    """
    model = KMeans(n_clusters=knum)
    model.fit(coor)
    # Per-row cluster labels assigned by the fitted model.
    labels = model.labels_
    return {cid: coor.index[labels == cid].tolist() for cid in range(knum)}
def skater_partition(G, n_clusters):
    """Partition graph G into spatially contiguous regions with SKATER.

    Each region is connected in the graph-theoretic sense; the split is
    driven by the nodes' spatial coordinates.

    :param G: networkx graph whose nodes carry a 'pos' attribute (x, y).
    :param n_clusters: desired number of regions.
    :return: dict mapping region label -> list of node names.
    """
    # Node coordinates, in a fixed node order.
    node_pos = nx.get_node_attributes(G, 'pos')
    node_list = list(G.nodes())
    xy = np.array([node_pos[v] for v in node_list])
    # Build a GeoDataFrame with a shapely Point geometry per node.
    frame = pd.DataFrame(xy, columns=['x', 'y'], index=node_list)
    frame['geometry'] = frame.apply(lambda row: Point(row['x'], row['y']), axis=1)
    geo_frame = gpd.GeoDataFrame(frame, geometry='geometry')
    # Spatial weights from 4 nearest neighbours, row-standardised.
    weights = ps.weights.KNN.from_array(xy, k=4)
    weights.transform = 'R'
    # SKATER (new API): pass the GeoDataFrame, the weights and the attribute
    # names ('x' and 'y') to cluster on.
    model = Skater(geo_frame, weights, attrs_name=['x', 'y'], n_clusters=n_clusters)
    model.solve()
    # Collect the labels into {region: [nodes]} form.
    regions = {}
    for region_label, node in zip(model.labels_, node_list):
        regions.setdefault(region_label, []).append(node)
    return regions
def spectral_partition(G, n_clusters):
    """Partition graph G by normalised spectral clustering on node coordinates.

    A Gaussian-kernel affinity matrix is built from the pairwise Euclidean
    distances between node positions (sigma = mean distance) and fed to
    SpectralClustering with a precomputed affinity.

    :param G: networkx graph whose nodes carry a 'pos' attribute (x, y).
    :param n_clusters: desired number of partitions.
    :return: dict mapping cluster label -> list of node names.
    """
    # Node coordinates, in a fixed node order.
    positions = nx.get_node_attributes(G, 'pos')
    node_list = list(G.nodes())
    xy = np.array([positions[v] for v in node_list])
    # Pairwise Euclidean distances between all node coordinates.
    dist = pairwise_distances(xy, metric='euclidean')
    # Kernel width: mean of all pairwise distances.
    sigma = np.mean(dist)
    # Gaussian affinity: A(i, j) = exp(-d(i, j)^2 / (2 * sigma^2)).
    affinity = np.exp(- (dist ** 2) / (2 * sigma ** 2))
    model = SpectralClustering(n_clusters=n_clusters,
                               affinity='precomputed',
                               random_state=0)
    assignments = model.fit_predict(affinity)
    # Collect the labels into {cluster: [nodes]} form.
    partition = {}
    for cluster_label, node in zip(assignments, node_list):
        partition.setdefault(cluster_label, []).append(node)
    return partition
# 2025/03/12
# Step3: wn_func类水力计算
# wn_func 主要用于计算:
# 水力距离hydraulic length即节点之间的水力阻力。
# 灵敏度分析sensitivity analysis用于优化测压点的布置。
# 一些与水力相关的函数,包括 CtoS求水力距离,stafun求状态函数F
# # diff求F对P的导数返回灵敏度矩阵A
# # sensitivity返回灵敏度和总灵敏度
class wn_func(object):
    """Hydraulic helper around a wntr WaterNetworkModel.

    Runs an EPANET simulation once at construction time and provides:
    - CtoS():     hydraulic-distance matrix between nodes (path * headloss);
    - get_Conn(): node/pipe incidence matrices;
    - Jaco():     pressure-sensitivity matrix (w.r.t. pipe roughness),
                  weighted by the hydraulic distance.
    """
    # Step 3.1: initialisation — cache simulation results and precompute the
    # nodes/pipes that must be excluded (reservoirs, tanks, pumps, valves).
    def __init__(self, wn: wntr.network.WaterNetworkModel, min_diameter: int):
        """
        Collect network information from the model.

        :param wn: model generated by wntr.
        :param min_diameter: smallest pipe diameter eligible for installation.
        """
        self.results = wntr.sim.EpanetSimulator(wn).run_sim() # run and cache the hydraulic simulation
        self.wn = wn
        # Pipe flow rates: DataFrame indexed by time step, columns = pipe names.
        self.q = self.results.link['flowrate']
        # Reservoir / tank node name lists.
        ReservoirIndex = wn.reservoir_name_list
        Tankindex = wn.tank_name_list
        # All pipe names (virtual pipes attached to reservoirs are removed below).
        self.pipes = wn.pipe_name_list
        # All node names.
        self.nodes = wn.node_name_list
        # Node coordinates: Series indexed by node name, values = (x, y) tuples.
        self.coordinates = wn.query_node_attribute('coordinates')
        # All pump / valve names.
        allpumps = wn.pump_name_list
        allvalves = wn.valve_name_list
        # Start / end node names of pumps and valves.
        pumpstnode = []
        pumpednode = []
        valvestnode = []
        valveednode = []
        # Pipes and nodes directly attached to reservoirs or tanks.
        Reservoirpipe = []
        Reservoirednode = []
        for pump in allpumps:
            pumpstnode.append(wn.links[pump].start_node.name)
            pumpednode.append(wn.links[pump].end_node.name)
        for valve in allvalves:
            valvestnode.append(wn.links[valve].start_node.name)
            valveednode.append(wn.links[valve].end_node.name)
        # NOTE(review): pipes whose END node is a reservoir are not checked
        # here (only start-node-in-reservoir and both directions for tanks);
        # confirm whether that case can occur in the target networks.
        for pipe in self.pipes:
            if wn.links[pipe].start_node.name in ReservoirIndex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].end_node.name)
            if wn.links[pipe].start_node.name in Tankindex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].end_node.name)
            if wn.links[pipe].end_node.name in Tankindex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].start_node.name)
        # Nodes to exclude: reservoirs, tanks, pump/valve endpoints and nodes
        # adjacent to reservoirs/tanks.
        self.delnodes = list(
            set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
        # Links to exclude: pumps, valves, and pipes touching tanks/reservoirs.
        self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
        self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
        # Lengths (metres) of the remaining pipes.
        self.L = wn.query_link_attribute('length')[self.pipes].tolist()
        self.n = len(self.nodes)
        self.m = len(self.pipes)
        # Unit head loss: first row of the headloss results.
        self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
        # Reservoir + tank nodes only.
        self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
        # Junctions attached to pipes thinner than min_diameter.
        self.less_than_min_diameter_junction_list = []
        for pipe in self.pipes:
            diameter = wn.links[pipe].diameter
            if diameter < min_diameter:
                start_node = wn.links[pipe].start_node.name
                end_node = wn.links[pipe].end_node.name
                self.less_than_min_diameter_junction_list.extend([start_node, end_node])
        # De-duplicate.
        self.less_than_min_diameter_junction_list = list(set(self.less_than_min_diameter_junction_list))
    # Step 3.2: hydraulic distance.
    def CtoS(self):
        """
        Compute the hydraulic-distance matrix.

        Entry [i, j] is the minimum (path length * head loss) cost from the
        control node i to node j, following the flow directions.

        :return: (hydraulic-distance DataFrame, weighted DiGraph)
        """
        nodes = copy.deepcopy(self.nodes)
        pipes = self.pipes
        wn = self.wn
        n = self.n
        m = self.m
        s1 = [0] * m  # NOTE(review): unused; kept for parity with the original
        q = self.q
        L = self.L
        # Head results transposed: rows = node names, columns = time steps.
        H1 = self.results.node['head'].T
        # Absolute head difference across each pipe.
        hh = []
        for p in pipes:
            h1 = self.wn.links[p].start_node.name
            h1 = H1.loc[str(h1)]
            h2 = self.wn.links[p].end_node.name
            h2 = H1.loc[str(h2)]
            hh.append(abs(h1 - h2))
        hh = np.array(hh)
        # Per-pipe head loss matrix (row 0 = first time step).
        headloss = pd.DataFrame(hh, index=pipes).T
        # hf / weightL: node-to-node head loss and length-weighted head loss.
        hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        # Directed graph following the simulated flow direction of each pipe.
        G = nx.DiGraph()
        for i in range(0, m):
            pipe = pipes[i]
            a = wn.links[pipe].start_node.name
            b = wn.links[pipe].end_node.name
            if q.loc[0, pipe] > 0:
                hf.loc[a, b] = headloss.loc[0, pipe]
                weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
                G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
            else:
                hf.loc[b, a] = headloss.loc[0, pipe]
                weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
                G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
        # Shortest weighted paths give the hydraulic distance between nodes.
        hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        for a in nodes:
            if a in G.nodes:
                d = nx.shortest_path_length(G, source=a, weight='weight')
                for b in list(d.keys()):
                    hydraulicL.loc[a, b] = d[b]
        # Drop excluded nodes from both axes.
        hydraulicL = hydraulicL.drop(self.delnodes)
        hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
        return hydraulicL, G
    # Step 3.3: incidence matrices used by the sensitivity computation.
    def get_Conn(self):
        """
        Build the node/pipe incidence matrices.

        Sets self.A / self.A2 (node-pipe incidence, -1 at the start node and
        +1 at the end node; A2 with excluded nodes/pipes dropped) and
        self.A3 (node-node adjacency).
        """
        m = self.wn.num_links
        n = self.wn.num_nodes
        p = self.wn.num_pumps
        v = self.wn.num_valves
        self.nonjunc_index = []
        self.non_link_index = []
        for r in self.wn.reservoirs():
            self.nonjunc_index.append(r[0])
        for t in self.wn.tanks():
            self.nonjunc_index.append(t[0])
        # Node-pipe incidence matrix: rows = nodes, columns = pipes.
        Conn = np.mat(np.zeros([n, m - p - v]))
        # Node-node adjacency: 1 where a pipe connects two nodes.
        NConn = np.mat(np.zeros([n, n]))
        # Pipes excluding pumps and valves.
        pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
        for pipe_name, pipe in pipes:
            start = self.wn.node_name_list.index(pipe.start_node_name)
            end = self.wn.node_name_list.index(pipe.end_node_name)
            p_index = self.wn.link_name_list.index(pipe_name)
            Conn[start, p_index] = -1
            Conn[end, p_index] = 1
            NConn[start, end] = 1
            NConn[end, start] = 1
        self.A = Conn
        link_name_list = [link for link in self.wn.link_name_list if
                          link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
        self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
        self.A2 = self.A2.drop(self.delnodes)
        for pipe in self.delpipes:
            if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
                self.A2 = self.A2.drop(columns=pipe)
        self.junc_list = self.A2.index
        self.A2 = np.mat(self.A2)  # reduced node-pipe incidence as a matrix
        self.A3 = NConn
    def Jaco(self, hL: pandas.DataFrame):
        """
        Compute the sensitivity matrix (node pressure response to roughness
        changes), weighted by the hydraulic distance.

        :param hL: hydraulic-distance matrix from CtoS().
        :return: SS where SS[i, j] = sensitivity of node i times the
            hydraulic distance from node i to node j.
        """
        A = self.A2
        wn = self.wn
        # NOTE(review): if run_sim() raises EpanetException, 'result' is never
        # bound and the finally block fails with NameError instead of a
        # meaningful error — confirm the intended fallback behaviour.
        try:
            result = wntr.sim.EpanetSimulator(wn).run_sim()
        except EpanetException:
            pass
        finally:
            # First-time-step head loss / flow, and static pipe attributes.
            h = result.link['headloss'][self.pipes].values[0]
            q = result.link['flowrate'][self.pipes].values[0]
            l = self.wn.query_link_attribute('length')[self.pipes]
            C = self.wn.query_link_attribute('roughness')[self.pipes]
            headloss = np.array(h)
            # Flip incidence columns so all flows are non-negative.
            for i in range(0, len(q)):
                if q[i] < 0:
                    A[:, i] = -A[:, i]
            q = np.abs(q)
            # Intermediate diagonal matrices of the sensitivity derivation
            # (1.852 = Hazen-Williams exponent; +1e-10 avoids division by zero).
            B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
            S = np.mat(np.diag(q / C))
            # Sensitivity system matrix.
            X = A * B * A.T
            try:
                det = np.linalg.det(X)
            except RuntimeError as e:
                sign, logdet = slogdet(X)  # log-determinant avoids overflow
                det = sign * np.exp(logdet)
            if det != 0:
                J_H_Cw = X.I * A * S
                # J_H_Q = -X.I
                J_q_Cw = S - B * A.T * X.I * A * S  # excluded nodes/pipes removed
                # J_q_Q = B * A.T * X.I
            else:  # X is singular: fall back to the pseudo-inverse
                J_H_Cw = np.linalg.pinv(X) @ A @ S
                # J_H_Q = -np.linalg.pinv(X)
                J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
                # J_q_Q = B * A.T * np.linalg.pinv(X)
            # Per-node pressure sensitivity: row sums of |J_H_Cw|.
            Sen_pressure = []
            S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist()
            for ss in S_pressure:
                Sen_pressure.append(ss[0])
            # Total sensitivity: scale each hydraulic-distance row by the
            # node's sensitivity.
            SS_pressure = copy.deepcopy(hL)
            for i in range(0, len(Sen_pressure)):
                SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
            SS = copy.deepcopy(hL)
            for i in range(0, len(Sen_pressure)):
                SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
            # SS[i, j]: sensitivity of nodes[i] * hydraulic distance to nodes[j].
            return SS
# 2025/03/12
# Step4: 传感器布置优化
# Sensorplacement
# weight分配权重
# sensor传感器布置的位置
class Sensorplacement(wn_func):
    """
    Extends wn_func with sensitivity-based selection of pressure-sensor
    locations.
    """
    def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int):
        """
        :param wn: model built by wntr
        :param sensornum: number of sensors to place
        :param min_diameter: smallest pipe diameter eligible for installation
        """
        wn_func.__init__(self, wn, min_diameter=min_diameter)
        self.sensornum = sensornum
    # Two rankings are produced:
    # 1. global: repeatedly take the node with the largest total weighted
    #    sensitivity and discard its whole cluster;
    # 2. per-cluster: take the top-sensitivity node of each cluster.
    def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
        """
        Pick sensor locations from the sensitivity matrix.

        :param SS: distance-weighted sensitivity matrix (nodes x nodes);
            SS.iloc[i, :] is node i's weighted sensitivity to all others
        :param G: weighted network graph (not used in the current body)
        :param group: cluster id -> list of node names (mutated in place:
            deleted nodes are removed from the lists)
        :return: (sensorindex, sensorindex_2) — global and per-cluster picks
        """
        n = self.n - len(self.delnodes)
        nodes = copy.deepcopy(self.nodes)
        for node in self.delnodes:
            nodes.remove(node)
        # Total sensitivity of each remaining node (row sums of SS).
        sumSS = []
        for i in range(0, n):
            sumSS.append(sum(SS.iloc[i, :]))
        indices = range(0, n)
        # sumSS_ is only kept for the optional CSV export below.
        sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
        # sumSS_.to_csv('sumSS_data.csv')
        # Rank nodes by total sensitivity, descending.
        sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
        sumSS = sumSS.sort_values(by=[0], ascending=[False])
        # Global picks / per-cluster picks.
        sensorindex = []
        sensorindex_2 = []
        # Per-cluster sensitivity sub-matrix and ranked totals.
        group_S = {}
        group_sumSS = {}
        for i in range(0, len(group)):
            for node in self.delnodes:
                # Strip deleted nodes from each cluster before ranking.
                if node in group[i]:
                    group[i].remove(node)
            group_S[i] = SS.loc[group[i], group[i]]
            group_sumSS[i] = []
            for j in range(0, len(group[i])):
                group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
            group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
            group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
            for node in self.less_than_min_diameter_junction_list:
                # NOTE(review): `in` on a DataFrame tests COLUMN labels, so
                # this never matches node names here, and DataFrame has no
                # .remove() — this branch would raise if it ever fired.
                if node in group_sumSS[i]:
                    group_sumSS[i].remove(node)
            pass
        sensornum = self.sensornum
        # NOTE(review): indexing group_sumSS[i] below assumes at least
        # `sensornum` clusters exist — confirm against the caller.
        for i in range(0, sensornum):
            # Most sensitive remaining node overall.
            Smaxnode = sumSS.index[0]
            sensorindex.append(Smaxnode)
            sensorindex_2.append(group_sumSS[i].index[0])
            for key, value in group.items():
                if Smaxnode in value:
                    # Drop the picked node's entire cluster from the ranking.
                    sumSS = sumSS.drop(index=group[key])
                    continue  # no-op: already the end of the loop body
            sumSS = sumSS.sort_values(by=[0], ascending=[False])
        return sensorindex, sensorindex_2
# 2025/03/13
def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
    """
    Place pressure-monitoring points by sensitivity ranking and return the
    chosen node IDs.

    :param name: database name (resolves ./db_inp/<name>.db.inp)
    :param sensor_num: number of monitoring points
    :param min_diameter: smallest pipe diameter eligible for installation
    :return: node IDs of the selected monitoring points
    """
    # EPANET input file exported from the project database.
    inp_file_real = f'./db_inp/{name}.db.inp'
    wn_real = wntr.network.WaterNetworkModel(inp_file_real)  # network with "true" roughness
    sim_real = wntr.sim.EpanetSimulator(wn_real)
    results_real = sim_real.run_sim()
    # NOTE(review): results_real, real_C, coordinates, junctionnum and G1
    # are computed but never used below — candidates for removal.
    real_C = wn_real.query_link_attribute('roughness').tolist()
    # Hydraulic helper: distances, connectivity, sensitivities.
    wn_fun1 = wn_func(wn_real, min_diameter=min_diameter)
    nodes = wn_fun1.nodes
    delnodes = wn_fun1.delnodes
    # Node coordinates, restricted to the reduced (non-deleted) node set.
    Coor_node = getCoor(wn_real)
    Coor_node = Coor_node.drop(wn_fun1.delnodes)
    nodes = [node for node in wn_fun1.nodes if node not in delnodes]
    coordinates = wn_fun1.coordinates
    junctionnum = len(nodes)
    wn_fun1.get_Conn()
    # hL: hydraulic-distance matrix; G: weighted flow-direction digraph.
    hL, G = wn_fun1.CtoS()
    # SS: distance-weighted sensitivity matrix.
    SS = wn_fun1.Jaco(hL)
    G1 = wn_real.to_graph()
    G1 = G1.to_undirected()  # undirected copy (for the disabled partitioners below)
    # Spatial K-means clustering; skater_partition / spectral_partition are
    # alternative partitioners that were tried and left disabled.
    group = kgroup(Coor_node, sensor_num)
    # group = skater_partition(G1, sensor_num)
    # group = spectral_partition(G1, sensor_num)
    wn_fun = Sensorplacement(wn_real, sensor_num, min_diameter=min_diameter)
    # Reuse the already-computed hydraulic state instead of re-simulating.
    wn_fun.__dict__.update(wn_fun1.__dict__)
    sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group)  # ranked placements
    return sensorindex
if __name__ == '__main__':
    # Smoke-run: place 20 sensors on the project's network (pipes >= 300 mm).
    sensorindex = get_ID(name=project_info.name, sensor_num=20, min_diameter=300)
    print(sensorindex)
    # (removed stale commented-out code that exported sensor coordinates to Excel)

View File

@@ -0,0 +1,557 @@
# 改进灵敏度法
import networkx
import numpy as np
import pandas
import wntr
import pandas as pd
import copy
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.cluster import KMeans
from wntr.epanet.toolkit import EpanetException
from numpy.linalg import slogdet
import random
from tjnetwork import *
import project_info
# 2025/03/12
# Step1: 获取节点坐标
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
    """
    Extract node coordinates from a water-network model.

    :param wn: model built by wntr (anything exposing ``query_node_attribute``
               and ``node_name_list`` works)
    :return: DataFrame indexed by node name with float columns ``x`` and ``y``
    """
    # Series: node name -> (x, y) coordinate tuple.
    site = wn.query_node_attribute('coordinates')
    # Split the coordinate pairs into columns in one pass, replacing the
    # original index-based loops over range(len(...)).
    # NOTE(review): like the original, this assumes the Series order matches
    # wn.node_name_list — confirm wntr guarantees that.
    xy = {
        'x': [coord[0] for coord in site.values],
        'y': [coord[1] for coord in site.values],
    }
    return pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
# 2025/03/12
# Step2: KMeans 聚类
# 将节点用kmeans根据坐标分为k组存入字典g
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
    """
    Partition network nodes into *knum* spatial clusters with KMeans.

    :param coor: node coordinates (rows indexed by node name, columns x/y)
    :param knum: number of clusters
    :return: mapping cluster id (0..knum-1) -> list of node names in it
    """
    model = KMeans(n_clusters=knum)
    model.fit(coor)
    # Predicted cluster label for every row of `coor`.
    labels = model.labels_
    return {
        cluster: coor[labels == cluster].index.tolist()
        for cluster in range(knum)
    }
# 2025/03/12
# Step3: wn_func类水力计算
# wn_func 主要用于计算:
# 水力距离hydraulic length即节点之间的水力阻力。
# 灵敏度分析sensitivity analysis用于优化测压点的布置。
# 一些与水力相关的函数,包括 CtoS求水力距离,stafun求状态函数F
# # diff求F对P的导数返回灵敏度矩阵A
# # sensitivity返回灵敏度和总灵敏度
class wn_func(object):
    """
    Hydraulic helper built on a wntr model: computes hydraulic distances
    (CtoS), the node-pipe incidence structure (get_Conn) and the pressure
    sensitivity matrix (Jaco) used for sensor placement.
    """
    # Step3.1: initialization
    def __init__(self, wn: wntr.network.WaterNetworkModel):
        """
        Run a baseline EPANET simulation and derive the reduced network
        (reservoirs/tanks, pumps/valves and their neighbouring nodes removed).

        :param wn: model built by wntr
        """
        # Baseline simulation results (pressure, flow, head, ...).
        self.results = wntr.sim.EpanetSimulator(wn).run_sim()
        self.wn = wn
        # Pipe flowrates, indexed by timestep, columns = pipe names.
        self.q = self.results.link['flowrate']
        # Reservoir / tank node names (excluded from placement).
        ReservoirIndex = wn.reservoir_name_list
        Tankindex = wn.tank_name_list
        # All pipe / node names (pipes are filtered below).
        self.pipes = wn.pipe_name_list
        self.nodes = wn.node_name_list
        # Node name -> (x, y) tuple.
        self.coordinates = wn.query_node_attribute('coordinates')
        allpumps = wn.pump_name_list
        allvalves = wn.valve_name_list
        # Start/end nodes of pumps and valves (to be excluded).
        pumpstnode = []
        pumpednode = []
        valvestnode = []
        valveednode = []
        # Pipes touching a reservoir/tank, and the junctions they reach.
        Reservoirpipe = []
        Reservoirednode = []
        for pump in allpumps:
            pumpstnode.append(wn.links[pump].start_node.name)
            pumpednode.append(wn.links[pump].end_node.name)
        for valve in allvalves:
            valvestnode.append(wn.links[valve].start_node.name)
            valveednode.append(wn.links[valve].end_node.name)
        for pipe in self.pipes:
            if wn.links[pipe].start_node.name in ReservoirIndex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].end_node.name)
            if wn.links[pipe].start_node.name in Tankindex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].end_node.name)
            if wn.links[pipe].end_node.name in Tankindex:
                Reservoirpipe.append(pipe)
                Reservoirednode.append(wn.links[pipe].start_node.name)
        # NOTE(review): a pipe whose END node is a reservoir is not caught
        # above (only tank end-nodes are) — confirm that is intentional.
        # Nodes to drop: reservoirs, tanks, pump/valve endpoints, and
        # junctions directly connected to reservoirs/tanks.
        self.delnodes = list(
            set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
        # Links to drop: pumps, valves, and pipes touching reservoirs/tanks.
        self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
        self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
        # Lengths (m) of the remaining pipes, in self.pipes order.
        self.L = wn.query_link_attribute('length')[self.pipes].tolist()
        self.n = len(self.nodes)
        self.m = len(self.pipes)
        # Head loss of every link at the first timestep.
        self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
        # Reservoirs + tanks only (kept separately from delnodes).
        self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
    # Step3.2: hydraulic distances
    def CtoS(self):
        """
        Build the hydraulic-distance matrix: entry [a, b] is the minimum
        (headloss * length)-weighted path cost from node a to node b along
        the flow direction at timestep 0.

        :return: (hydraulicL DataFrame, weighted DiGraph G)
        """
        nodes = copy.deepcopy(self.nodes)
        pipes = self.pipes
        wn = self.wn
        n = self.n
        m = self.m
        # s1 is unused — candidate for removal.
        s1 = [0] * m
        q = self.q
        L = self.L
        # Node heads, transposed so rows are node names.
        H1 = self.results.node['head'].T
        # Per-pipe head loss = |head(start) - head(end)| over all timesteps.
        hh = []
        for p in pipes:
            h1 = self.wn.links[p].start_node.name
            h1 = H1.loc[str(h1)]
            h2 = self.wn.links[p].end_node.name
            h2 = H1.loc[str(h2)]
            hh.append(abs(h1 - h2))
        hh = np.array(hh)
        headloss = pd.DataFrame(hh, index=pipes).T
        # hf[a, b]: head loss of the pipe a->b; weightL: head loss * length.
        hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        # Directed graph following the flow direction at timestep 0.
        G = nx.DiGraph()
        for i in range(0, m):
            pipe = pipes[i]
            a = wn.links[pipe].start_node.name
            b = wn.links[pipe].end_node.name
            if q.loc[0, pipe] > 0:
                hf.loc[a, b] = headloss.loc[0, pipe]
                weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
                G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
            else:
                hf.loc[b, a] = headloss.loc[0, pipe]
                weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
                G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
        # All-pairs weighted shortest-path lengths from each reachable node.
        hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
        for a in nodes:
            if a in G.nodes:
                d = nx.shortest_path_length(G, source=a, weight='weight')
                for b in list(d.keys()):
                    hydraulicL.loc[a, b] = d[b]
        # Restrict to the reduced node set.
        hydraulicL = hydraulicL.drop(self.delnodes)
        hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
        return hydraulicL, G
    # Step3.3: connectivity matrices
    def get_Conn(self):
        """
        Build connectivity structures: self.A / self.A2 (node-pipe incidence,
        start = -1, end = 1; A2 restricted to the reduced network) and
        self.A3 (node-node adjacency).
        """
        m = self.wn.num_links
        n = self.wn.num_nodes
        p = self.wn.num_pumps
        v = self.wn.num_valves
        # Names of non-junction nodes (reservoirs, tanks).
        self.nonjunc_index = []
        self.non_link_index = []
        for r in self.wn.reservoirs():
            self.nonjunc_index.append(r[0])
        for t in self.wn.tanks():
            self.nonjunc_index.append(t[0])
        # Incidence over pipes only (width m - pumps - valves).
        Conn = np.mat(np.zeros([n, m - p - v]))
        # Node adjacency: 1 where a pipe connects the pair.
        NConn = np.mat(np.zeros([n, n]))
        # NOTE(review): wn.pipes() yields (name, object) tuples, so the
        # filter below relies on pipes() already excluding pumps/valves.
        pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
        for pipe_name, pipe in pipes:
            start = self.wn.node_name_list.index(pipe.start_node_name)
            end = self.wn.node_name_list.index(pipe.end_node_name)
            # NOTE(review): p_index comes from the FULL link list (0..m-1)
            # while Conn has only m - p - v columns — this can overflow if
            # pumps/valves precede pipes in link_name_list; confirm ordering.
            p_index = self.wn.link_name_list.index(pipe_name)
            Conn[start, p_index] = -1
            Conn[end, p_index] = 1
            NConn[start, end] = 1
            NConn[end, start] = 1
        self.A = Conn
        link_name_list = [link for link in self.wn.link_name_list if
                          link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
        self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
        self.A2 = self.A2.drop(self.delnodes)
        for pipe in self.delpipes:
            if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
                self.A2 = self.A2.drop(columns=pipe)
        # Remaining junction names, aligned with A2's rows.
        self.junc_list = self.A2.index
        self.A2 = np.mat(self.A2)  # reduced node-pipe incidence
        self.A3 = NConn
    def Jaco(self, hL: pandas.DataFrame):
        """
        Sensitivity matrix: node-pressure response to pipe-roughness change,
        weighted by hydraulic distance.

        :param hL: hydraulic-distance matrix from CtoS()
        :return: SS with SS[i, j] = sensitivity(node i) * hL[i, j]
        """
        A = self.A2
        wn = self.wn
        # NOTE(review): if run_sim() raises EpanetException it is swallowed
        # and the finally-block then fails with NameError on `result` —
        # restructure to re-raise instead of pass.
        try:
            result = wntr.sim.EpanetSimulator(wn).run_sim()
        except EpanetException:
            pass
        finally:
            h = result.link['headloss'][self.pipes].values[0]
            q = result.link['flowrate'][self.pipes].values[0]
            # l is unused — candidate for removal.
            l = self.wn.query_link_attribute('length')[self.pipes]
            C = self.wn.query_link_attribute('roughness')[self.pipes]
            # Head loss per pipe at the first timestep.
            headloss = np.array(h)
            # Orient flows positive. NOTE(review): this flips columns of
            # self.A2 in place, so repeated calls compound the sign flips.
            for i in range(0, len(q)):
                if q[i] < 0:
                    A[:, i] = -A[:, i]
            q = np.abs(q)
            # Hazen-Williams linearization; +1e-10 avoids division by zero.
            B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
            S = np.mat(np.diag(q / C))
            X = A * B * A.T
            try:
                det = np.linalg.det(X)
            except RuntimeError as e:
                sign, logdet = slogdet(X)  # overflow-safe determinant
                det = sign * np.exp(logdet)
            if det != 0:
                J_H_Cw = X.I * A * S
                # J_q_Cw is computed but unused below.
                J_q_Cw = S - B * A.T * X.I * A * S
            else:  # X singular: fall back to the pseudo-inverse
                J_H_Cw = np.linalg.pinv(X) @ A @ S
                J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
            # Per-node sensitivity = absolute row sums of dH/dC.
            Sen_pressure = []
            S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist()
            for ss in S_pressure:
                Sen_pressure.append(ss[0])
            # SS_pressure is computed but unused — candidate for removal.
            SS_pressure = copy.deepcopy(hL)
            for i in range(0, len(Sen_pressure)):
                SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
            SS = copy.deepcopy(hL)
            for i in range(0, len(Sen_pressure)):
                SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
            # SS[i, j] = node-i sensitivity * hydraulic distance i -> j.
            return SS
# 2025/03/12
# Step4: 传感器布置优化
# Sensorplacement
# weight分配权重
# sensor传感器布置的位置
class Sensorplacement(wn_func):
    """
    Extends wn_func with sensitivity-based selection of pressure-sensor
    locations.
    """
    def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int):
        """
        :param wn: model built by wntr
        :param sensornum: number of sensors to place
        """
        wn_func.__init__(self, wn)
        self.sensornum = sensornum
    # Two rankings are produced:
    # 1. global: repeatedly take the node with the largest total weighted
    #    sensitivity and discard its whole cluster;
    # 2. per-cluster: take the top-sensitivity node of each cluster.
    def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
        """
        Pick sensor locations from the sensitivity matrix.

        Side effect: writes the per-node totals to 'sumSS_data.csv' in the
        working directory.

        :param SS: distance-weighted sensitivity matrix (nodes x nodes)
        :param G: weighted network graph (not used in the current body)
        :param group: cluster id -> list of node names (mutated in place:
            deleted nodes are removed from the lists)
        :return: (sensorindex, sensorindex_2) — global and per-cluster picks
        """
        n = self.n - len(self.delnodes)
        nodes = copy.deepcopy(self.nodes)
        for node in self.delnodes:
            nodes.remove(node)
        # Total sensitivity of each remaining node (row sums of SS).
        sumSS = []
        for i in range(0, n):
            sumSS.append(sum(SS.iloc[i, :]))
        indices = range(0, n)
        sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
        sumSS_.to_csv('sumSS_data.csv')  # export node totals for inspection
        # Rank nodes by total sensitivity, descending.
        sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
        sumSS = sumSS.sort_values(by=[0], ascending=[False])
        # Global picks / per-cluster picks.
        sensorindex = []
        sensorindex_2 = []
        # Per-cluster sensitivity sub-matrix and ranked totals.
        group_S = {}
        group_sumSS = {}
        for i in range(0, len(group)):
            for node in self.delnodes:
                # Strip deleted nodes from each cluster before ranking.
                if node in group[i]:
                    group[i].remove(node)
            group_S[i] = SS.loc[group[i], group[i]]
            group_sumSS[i] = []
            for j in range(0, len(group[i])):
                group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
            group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
            group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
        sensornum = self.sensornum
        # NOTE(review): indexing group_sumSS[i] below assumes at least
        # `sensornum` clusters exist — confirm against the caller.
        for i in range(0, sensornum):
            # Most sensitive remaining node overall.
            Smaxnode = sumSS.index[0]
            sensorindex.append(Smaxnode)
            sensorindex_2.append(group_sumSS[i].index[0])
            for key, value in group.items():
                if Smaxnode in value:
                    # Drop the picked node's entire cluster from the ranking.
                    sumSS = sumSS.drop(index=group[key])
                    continue  # no-op: already the end of the loop body
            sumSS = sumSS.sort_values(by=[0], ascending=[False])
        return sensorindex, sensorindex_2
# 2025/03/13
def get_sensor_coord(name: str, sensor_num: int) -> dict[str, float]:
    """
    Place pressure-monitoring points by sensitivity ranking and return their
    coordinates, looked up from the project database.

    :param name: database name (resolves ./db_inp/<name>.db.inp)
    :param sensor_num: number of monitoring points
    :return: node id -> coordinate, as returned by get_node_coord
    """
    # EPANET input file exported from the project database.
    inp_file_real = f'./db_inp/{name}.db.inp'
    wn_real = wntr.network.WaterNetworkModel(inp_file_real)  # network with "true" roughness
    sim_real = wntr.sim.EpanetSimulator(wn_real)
    results_real = sim_real.run_sim()
    # NOTE(review): results_real, real_C, coordinates and junctionnum are
    # computed but never used below — candidates for removal.
    real_C = wn_real.query_link_attribute('roughness').tolist()
    # Hydraulic helper: distances, connectivity, sensitivities.
    wn_fun1 = wn_func(wn_real)
    nodes = wn_fun1.nodes
    delnodes = wn_fun1.delnodes
    # Node coordinates, restricted to the reduced (non-deleted) node set.
    Coor_node = getCoor(wn_real)
    Coor_node = Coor_node.drop(wn_fun1.delnodes)
    nodes = [node for node in wn_fun1.nodes if node not in delnodes]
    coordinates = wn_fun1.coordinates
    junctionnum = len(nodes)
    wn_fun1.get_Conn()
    # hL: hydraulic-distance matrix; G: weighted flow-direction digraph.
    hL, G = wn_fun1.CtoS()
    # SS: distance-weighted sensitivity matrix.
    SS = wn_fun1.Jaco(hL)
    # Spatial K-means clustering; roughly one sensor per cluster.
    group = kgroup(Coor_node, sensor_num)
    wn_fun = Sensorplacement(wn_real, sensor_num)
    # Reuse the already-computed hydraulic state instead of re-simulating.
    wn_fun.__dict__.update(wn_fun1.__dict__)
    sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group)  # ranked placements
    sensor_coord = {}
    # Re-open the project DB to resolve node coordinates.
    if is_project_open(name=name):
        close_project(name=name)
    open_project(name=name)
    for node_id in sensorindex:
        sensor_coord[node_id] = get_node_coord(name=name, node_id=node_id)
    close_project(name=name)
    return sensor_coord
if __name__ == '__main__':
    # Smoke-run: place 20 sensors for the current project and print their
    # database coordinates.
    sensor_coord = get_sensor_coord(name=project_info.name, sensor_num=20)
    print(sensor_coord)
    # (removed stale commented-out experiment: standalone placement on
    # ./db_inp/bb.db.inp plus matplotlib plots of the clusters and sensors)

0
app/api/__init__.py Normal file
View File

0
app/api/v1/__init__.py Normal file
View File

View File

View File

View File

View File

View File

View File

View File

View File

20
app/api/v1/router.py Normal file
View File

@@ -0,0 +1,20 @@
from fastapi import APIRouter
from app.api.v1.endpoints import (
auth,
project,
network_elements,
simulation,
scada,
extension,
snapshots
)
# Version-1 API router: mounts every endpoint module under its URL prefix.
api_router = APIRouter()

# (router, url prefix, OpenAPI tag) — order defines route precedence.
_ROUTES = (
    (auth.router, "/auth", "auth"),
    (project.router, "/projects", "projects"),
    (network_elements.router, "/elements", "network-elements"),
    (simulation.router, "/simulation", "simulation"),
    (scada.router, "/scada", "scada"),
    (extension.router, "/extension", "extension"),
    (snapshots.router, "/snapshots", "snapshots"),
)

for _router, _prefix, _tag in _ROUTES:
    api_router.include_router(_router, prefix=_prefix, tags=[_tag])

0
app/audit/__init__.py Normal file
View File

0
app/auth/__init__.py Normal file
View File

21
app/auth/dependencies.py Normal file
View File

@@ -0,0 +1,21 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from app.core.config import settings
from jose import jwt, JWTError
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """
    FastAPI dependency: decode the bearer JWT and return its subject.

    Raises HTTP 401 if the token cannot be decoded or carries no "sub" claim.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        username = claims.get("sub")
    except JWTError:
        username = None
    if username is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return username

0
app/core/__init__.py Normal file
View File

3
app/core/audit.py Normal file
View File

@@ -0,0 +1,3 @@
# Placeholder for audit logic
async def log_audit_event(event_type: str, user_id: str, details: dict):
    """Record an audit event. Stub: intentionally a no-op until an audit backend exists."""
    pass

30
app/core/config.py Normal file
View File

@@ -0,0 +1,30 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application configuration, overridable via environment variables / .env."""
    PROJECT_NAME: str = "TJWater Server"
    API_V1_STR: str = "/api/v1"
    # SECURITY(review): this default secret must be overridden via the
    # environment in any real deployment — never ship the placeholder.
    SECRET_KEY: str = "your-secret-key-here"  # Change in production
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    # Database Config (PostgreSQL)
    DB_NAME: str = "tjwater"
    DB_HOST: str = "localhost"
    DB_PORT: str = "5432"
    DB_USER: str = "postgres"
    DB_PASSWORD: str = "password"
    # InfluxDB connection defaults
    INFLUXDB_URL: str = "http://localhost:8086"
    INFLUXDB_TOKEN: str = "token"
    INFLUXDB_ORG: str = "org"
    INFLUXDB_BUCKET: str = "bucket"
    @property
    def SQLALCHEMY_DATABASE_URI(self) -> str:
        """Assemble the PostgreSQL DSN from the individual DB_* fields."""
        return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
    class Config:
        # pydantic-settings: also read values from a local .env file.
        env_file = ".env"
# Shared singleton settings instance.
settings = Settings()

9
app/core/encryption.py Normal file
View File

@@ -0,0 +1,9 @@
# Placeholder for encryption logic
class Encryptor:
    """Identity pass-through stub — NOT real encryption; replace before storing sensitive data."""
    def encrypt(self, data: str) -> str:
        return data  # Implement actual encryption
    def decrypt(self, data: str) -> str:
        return data  # Implement actual decryption
# Shared module-level instance.
encryptor = Encryptor()

23
app/core/security.py Normal file
View File

@@ -0,0 +1,23 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Optional, Union

from jose import jwt
from passlib.context import CryptContext

from app.core.config import settings
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT whose "sub" claim is str(subject).

    :param subject: token subject (typically the username / user id)
    :param expires_delta: custom lifetime; defaults to
        ACCESS_TOKEN_EXPIRE_MINUTES from settings
    :return: the encoded JWT string
    """
    # Use timezone-aware UTC: datetime.utcnow() is naive and deprecated.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {"exp": expire, "sub": str(subject)}
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if plain_password matches the stored passlib hash."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the configured scheme (bcrypt)."""
    return pwd_context.hash(password)

0
app/crypto/__init__.py Normal file
View File

0
app/domain/__init__.py Normal file
View File

View File

View File

0
app/infra/__init__.py Normal file
View File

View File

0
app/infra/cache/__init__.py vendored Normal file
View File

0
app/infra/db/__init__.py Normal file
View File

View File

4964
app/infra/db/influxdb/api.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
# InfluxDB connection settings.
url = "http://127.0.0.1:8086"  # InfluxDB instance address
# SECURITY(review): a live-looking API token is committed here in plain
# text — rotate it and load it from the environment or a secrets store.
token = "kMPX2V5HsbzPpUT2B9HPBu1sTG1Emf-lPlT2UjxYnGAuocpXq_f_0lK4HHs-TbbKyjsZpICkMsyXG_V2D7P7yQ=="  # InfluxDB API token
# (removed stale commented-out base64-decoded variant of the token)
org = "TJWATERORG"  # organization name

View File

@@ -0,0 +1,33 @@
# Example script: query the last hour of data from InfluxDB and print it.
from influxdb_client import InfluxDBClient, Point, WriteOptions
from influxdb_client.client.query_api import QueryApi
import influxdb_info
# InfluxDB connection settings (from the shared config module)
url = influxdb_info.url
token = influxdb_info.token
org = influxdb_info.org
bucket = "SCADA_data"
# Create the InfluxDB client
client = InfluxDBClient(url=url, token=token, org=org)
# Create the query API object
query_api = client.query_api()
# Build the Flux query: everything in the bucket from the last hour
query = f'''
from(bucket: "{bucket}")
|> range(start: -1h)
'''
# Run the query
result = query_api.query(query)
print(result)
# Walk the result tables and print each record
for table in result:
    for record in table.records:
        print(f"Time: {record.get_time()}, Value: {record.get_value()}, Measurement: {record.get_measurement()}, Field: {record.get_field()}")
# Close the client connection
client.close()

View File

@@ -0,0 +1 @@
from .router import router

View File

@@ -0,0 +1,108 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.native.api.postgresql_info as postgresql_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
    """Async PostgreSQL connection-pool wrapper built on psycopg_pool."""

    def __init__(self, db_name=None):
        self.pool = None  # psycopg_pool.AsyncConnectionPool, created by init_pool()
        self.db_name = db_name  # target database name; None -> config default

    def init_pool(self, db_name=None):
        """Initialize the connection pool.

        Uses *db_name* if given, else the name from the constructor,
        else the default database from postgresql_info.
        """
        # BUG FIX: db_name was previously ignored, so every pool connected to
        # the default database regardless of the requested one.
        target_db = db_name or self.db_name
        conn_string = (
            postgresql_info.get_pgconn_string(db_name=target_db)
            if target_db
            else postgresql_info.get_pgconn_string()
        )
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=5,
                max_size=20,
                open=False,  # Don't open immediately, wait for startup
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries
            )
            logger.info(
                "PostgreSQL connection pool initialized for database: %s",
                target_db or "default",
            )
        except Exception as e:
            logger.error(f"Failed to initialize postgresql connection pool: {e}")
            raise

    async def open(self):
        """Open the pool (no-op when init_pool was never called)."""
        if self.pool:
            await self.pool.open()

    async def close(self):
        """Close the connection pool."""
        if self.pool:
            await self.pool.close()
            logger.info("PostgreSQL connection pool closed.")

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Yield a connection borrowed from the pool."""
        if not self.pool:
            raise Exception("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn
# Default database instance (uses the configured default database).
db = Database()
# Cache of per-database instances — avoids creating duplicate pools.
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
    """Factory for a Database bound to *db_name* (pool not yet initialized)."""
    return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Get or create a cached Database instance for *db_name*.

    Passing None (or "") returns the shared default instance. A first request
    for a named database initializes and opens its pool before caching it.

    NOTE(review): creation is not guarded by a lock — two concurrent calls for
    the same new db_name could race and build two pools; confirm acceptable.
    """
    if not db_name:
        return db  # shared default database instance
    if db_name not in _database_instances:
        # First request for this database: build, open and cache its pool.
        instance = create_database_instance(db_name)
        instance.init_pool()
        await instance.open()
        _database_instances[db_name] = instance
        logger.info(f"Created new database instance for: {db_name}")
    return _database_instances[db_name]
async def get_db_connection():
    """FastAPI dependency: yield a connection from the default pool."""
    async with db.get_connection() as conn:
        yield conn
async def get_database_connection(db_name: Optional[str] = None):
    """
    FastAPI dependency that yields a database connection, optionally for a
    specific database name.

    Usage: ``conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))``
    or in a route: ``conn: AsyncConnection = Depends(get_database_connection)``
    """
    instance = await get_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn
async def cleanup_database_instances():
    """Close every cached Database instance plus the default one.

    Intended as an application-shutdown hook.
    """
    for db_name in list(_database_instances):
        instance = _database_instances[db_name]
        await instance.close()
        logger.info(f"Closed database instance for: {db_name}")
    _database_instances.clear()
    # Finally shut down the default database pool as well.
    await db.close()
    logger.info("All database instances cleaned up.")

View File

@@ -0,0 +1,83 @@
import time
from typing import List, Optional
from fastapi.logger import logger
import postgresql_info
import psycopg
class InternalQueries:
    """Synchronous internal queries against the PostgreSQL network database."""

    @staticmethod
    def get_links_by_property(
        fields: Optional[List[str]] = None,
        property_conditions: Optional[dict] = None,
        db_name: str = None,
        max_retries: int = 3,
    ) -> List[dict]:
        """
        Query selected fields of ``public.pipes``, optionally filtered.

        :param fields: field names to select, e.g. ["id", "diameter", "status"];
            defaults to the full standard field set
        :param property_conditions: optional equality filters,
            e.g. {"status": "Open"} or {"diameter": 300}
        :param db_name: database name (None -> default database)
        :param max_retries: maximum number of attempts on failure
        :return: list of records, one dict per row
        """
        # Default to the full standard field set when none is requested.
        if not fields:
            fields = [
                "id",
                "node1",
                "node2",
                "length",
                "diameter",
                "roughness",
                "minor_loss",
                "status",
            ]
        for attempt in range(max_retries):
            try:
                conn_string = (
                    postgresql_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else postgresql_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    with conn.cursor() as cur:
                        # NOTE: field/condition names are interpolated into the
                        # SQL text — callers must not pass untrusted input here;
                        # condition *values* are parameterized.
                        select_fields = ", ".join(fields)
                        base_query = f"""
                        SELECT {select_fields}
                        FROM public.pipes
                        """
                        if property_conditions:
                            conditions = []
                            params = []
                            for key, value in property_conditions.items():
                                conditions.append(f"{key} = %s")
                                params.append(value)
                            query = base_query + " WHERE " + " AND ".join(conditions)
                            cur.execute(query, params)
                        else:
                            cur.execute(base_query)
                        records = cur.fetchall()
                        # Pair each row tuple with the selected field names.
                        # (Also removes an unreachable `break` that followed
                        # the return in the previous version.)
                        return [dict(zip(fields, record)) for record in records]
            except Exception as e:
                logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    raise

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的数据库名称,为空时使用默认数据库"
    )
):
    """FastAPI dependency: yield a pooled connection, optionally for the
    database named by the ``db_name`` query parameter (None -> default)."""
    instance = await get_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn
@router.get("/scada-info")
async def get_scada_info_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有SCADA信息
"""
try:
# 使用ScadaRepository查询SCADA信息
scada_data = await ScadaRepository.get_scadas(conn)
return {"success": True, "data": scada_data, "count": len(scada_data)}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
)
@router.get("/scheme-list")
async def get_scheme_list_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有方案信息
"""
try:
# 使用SchemeRepository查询方案信息
scheme_data = await SchemeRepository.get_schemes(conn)
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
except Exception as e:
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
@router.get("/burst-locate-result")
async def get_burst_locate_result_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有爆管定位结果
"""
try:
# 使用SchemeRepository查询爆管定位结果
burst_data = await SchemeRepository.get_burst_locate_results(conn)
return {"success": True, "data": burst_data, "count": len(burst_data)}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}"
)
@router.get("/burst-locate-result/{burst_incident}")
async def get_burst_locate_result_by_incident(
burst_incident: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""
根据 burst_incident 查询爆管定位结果
"""
try:
# 使用SchemeRepository查询爆管定位结果
return await SchemeRepository.get_burst_locate_result_by_incident(
conn, burst_incident
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}",
)

View File

@@ -0,0 +1,36 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class ScadaRepository:
    """Read-only queries for the ``public.scada_info`` table."""

    # Column names returned by get_scadas, in SELECT order.
    _FIELDS = (
        "id",
        "type",
        "associated_element_id",
        "transmission_mode",
        "transmission_frequency",
        "reliability",
    )

    @staticmethod
    async def get_scadas(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.scada_info``.

        :param conn: async database connection (rows come back as dicts)
        :return: list of records, one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, associated_element_id, transmission_mode, transmission_frequency, reliability
                FROM public.scada_info
                """
            )
            rows = await cur.fetchall()
        return [{f: row[f] for f in ScadaRepository._FIELDS} for row in rows]

View File

@@ -0,0 +1,104 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class SchemeRepository:
    """Read-only queries for scheme metadata and burst-location results."""

    # Column names, in SELECT order, for each query below.
    _SCHEME_FIELDS = (
        "scheme_id",
        "scheme_name",
        "scheme_type",
        "username",
        "create_time",
        "scheme_start_time",
        "scheme_detail",
    )
    _BURST_FIELDS = (
        "id",
        "type",
        "burst_incident",
        "leakage",
        "detect_time",
        "locate_result",
    )

    @staticmethod
    async def get_schemes(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.scheme_list``.

        :param conn: async database connection (rows come back as dicts)
        :return: list of records, one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail
                FROM public.scheme_list
                """
            )
            rows = await cur.fetchall()
        return [
            {f: row[f] for f in SchemeRepository._SCHEME_FIELDS} for row in rows
        ]

    @staticmethod
    async def get_burst_locate_results(conn: AsyncConnection) -> List[dict]:
        """
        Fetch every row of ``public.burst_locate_result``.

        :param conn: async database connection (rows come back as dicts)
        :return: list of records, one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, burst_incident, leakage, detect_time, locate_result
                FROM public.burst_locate_result
                """
            )
            rows = await cur.fetchall()
        return [
            {f: row[f] for f in SchemeRepository._BURST_FIELDS} for row in rows
        ]

    @staticmethod
    async def get_burst_locate_result_by_incident(
        conn: AsyncConnection, burst_incident: str
    ) -> List[dict]:
        """
        Fetch burst-location results matching *burst_incident*.

        :param conn: async database connection (rows come back as dicts)
        :param burst_incident: burst incident identifier
        :return: list of matching records, one dict per row
        """
        async with conn.cursor() as cur:
            await cur.execute(
                """
                SELECT id, type, burst_incident, leakage, detect_time, locate_result
                FROM public.burst_locate_result
                WHERE burst_incident = %s
                """,
                (burst_incident,),
            )
            rows = await cur.fetchall()
        return [
            {f: row[f] for f in SchemeRepository._BURST_FIELDS} for row in rows
        ]

View File

@@ -0,0 +1,4 @@
from .router import router
from .database import *
from .timescaledb_info import *
from .composite_queries import CompositeQueries

View File

@@ -0,0 +1,606 @@
import time
from typing import List, Optional, Any, Dict, Tuple
from datetime import datetime, timedelta
from psycopg import AsyncConnection
import pandas as pd
import numpy as np
from api_ex.Fdataclean import clean_flow_data_df_kf
from api_ex.Pdataclean import clean_pressure_data_df_km
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
from postgresql.internal_queries import InternalQueries
from postgresql.scada_info import ScadaRepository as PostgreScadaRepository
from timescaledb.schemas.realtime import RealtimeRepository
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.scada import ScadaRepository
class CompositeQueries:
"""
复合查询类,提供跨表查询功能
"""
@staticmethod
async def get_scada_associated_realtime_simulation_data(
    timescale_conn: AsyncConnection,
    postgres_conn: AsyncConnection,
    device_ids: List[str],
    start_time: datetime,
    end_time: datetime,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetch realtime simulation values for the links/nodes associated with
    the given SCADA devices.

    Args:
        timescale_conn: TimescaleDB async connection
        postgres_conn: PostgreSQL async connection
        device_ids: SCADA device id list
        start_time: range start
        end_time: range end

    Returns:
        {device_id: [{"time": ..., "value": ..., "scada_id": ...}, ...]}

    Raises:
        ValueError: if a device id is unknown or its SCADA type is unsupported
    """
    result = {}
    # Load SCADA metadata once and index it by device id — replaces the
    # previous O(devices * scadas) linear scan per requested device.
    # (assumes SCADA ids are unique — TODO confirm against the schema)
    scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
    scada_by_id = {scada["id"]: scada for scada in scada_infos}
    for device_id in device_ids:
        target_scada = scada_by_id.get(device_id)
        if not target_scada:
            raise ValueError(f"SCADA device {device_id} not found")
        element_id = target_scada["associated_element_id"]
        scada_type = target_scada["type"]
        if scada_type.lower() == "pipe_flow":
            # Links report flow.
            res = await RealtimeRepository.get_link_field_by_time_range(
                timescale_conn, start_time, end_time, element_id, "flow"
            )
        elif scada_type.lower() == "pressure":
            # Nodes report pressure.
            res = await RealtimeRepository.get_node_field_by_time_range(
                timescale_conn, start_time, end_time, element_id, "pressure"
            )
        else:
            raise ValueError(f"Unknown SCADA type: {scada_type}")
        # Tag every record with the originating SCADA device id.
        for item in res:
            item["scada_id"] = device_id
        result[device_id] = res
    return result
@staticmethod
async def get_scada_associated_scheme_simulation_data(
    timescale_conn: AsyncConnection,
    postgres_conn: AsyncConnection,
    device_ids: List[str],
    start_time: datetime,
    end_time: datetime,
    scheme_type: str,
    scheme_name: str,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetch scheme simulation values for the links/nodes associated with
    the given SCADA devices.

    Args:
        timescale_conn: TimescaleDB async connection
        postgres_conn: PostgreSQL async connection
        device_ids: SCADA device id list
        start_time: range start
        end_time: range end
        scheme_type: scheme (working-condition) type
        scheme_name: scheme name

    Returns:
        {device_id: [{"time": ..., "value": ..., "scada_id": ...}, ...]}

    Raises:
        ValueError: if a device id is unknown or its SCADA type is unsupported
    """
    result = {}
    # Load SCADA metadata once and index it by device id — replaces the
    # previous linear scan per requested device.
    # (assumes SCADA ids are unique — TODO confirm against the schema)
    scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
    scada_by_id = {scada["id"]: scada for scada in scada_infos}
    for device_id in device_ids:
        target_scada = scada_by_id.get(device_id)
        if not target_scada:
            raise ValueError(f"SCADA device {device_id} not found")
        element_id = target_scada["associated_element_id"]
        scada_type = target_scada["type"]
        if scada_type.lower() == "pipe_flow":
            # Links report flow.
            res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
                timescale_conn,
                scheme_type,
                scheme_name,
                start_time,
                end_time,
                element_id,
                "flow",
            )
        elif scada_type.lower() == "pressure":
            # Nodes report pressure.
            res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
                timescale_conn,
                scheme_type,
                scheme_name,
                start_time,
                end_time,
                element_id,
                "pressure",
            )
        else:
            raise ValueError(f"Unknown SCADA type: {scada_type}")
        # Tag every record with the originating SCADA device id.
        for item in res:
            item["scada_id"] = device_id
        result[device_id] = res
    return result
@staticmethod
async def get_realtime_simulation_data(
    timescale_conn: AsyncConnection,
    featureInfos: List[Tuple[str, str]],
    start_time: datetime,
    end_time: datetime,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetch realtime simulation values for the given links/nodes.

    Args:
        timescale_conn: TimescaleDB async connection
        featureInfos: list of (element_id, element_type) tuples
        start_time: range start
        end_time: range end

    Returns:
        {feature_id: [{"time": ..., "value": ..., "feature_id": ...}, ...]}

    Raises:
        ValueError: if an element type is neither "pipe" nor "junction"
    """
    result = {}
    # Loop variable renamed from `type` to avoid shadowing the builtin.
    for feature_id, feature_type in featureInfos:
        if feature_type.lower() == "pipe":
            # Links report flow.
            res = await RealtimeRepository.get_link_field_by_time_range(
                timescale_conn, start_time, end_time, feature_id, "flow"
            )
        elif feature_type.lower() == "junction":
            # Nodes report pressure.
            res = await RealtimeRepository.get_node_field_by_time_range(
                timescale_conn, start_time, end_time, feature_id, "pressure"
            )
        else:
            raise ValueError(f"Unknown type: {feature_type}")
        # Tag every record with its feature id (the previous comment here
        # incorrectly said "scada_id").
        for item in res:
            item["feature_id"] = feature_id
        result[feature_id] = res
    return result
@staticmethod
async def get_scheme_simulation_data(
    timescale_conn: AsyncConnection,
    featureInfos: List[Tuple[str, str]],
    start_time: datetime,
    end_time: datetime,
    scheme_type: str,
    scheme_name: str,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetch scheme simulation values for the given links/nodes.

    Args:
        timescale_conn: TimescaleDB async connection
        featureInfos: list of (element_id, element_type) tuples
        start_time: range start
        end_time: range end
        scheme_type: scheme (working-condition) type
        scheme_name: scheme name

    Returns:
        {feature_id: [{"time": ..., "value": ..., "feature_id": ...}, ...]}

    Raises:
        ValueError: if an element type is neither "pipe" nor "junction"
    """
    result = {}
    # Loop variable renamed from `type` to avoid shadowing the builtin.
    for feature_id, feature_type in featureInfos:
        if feature_type.lower() == "pipe":
            # Links report flow.
            res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
                timescale_conn,
                scheme_type,
                scheme_name,
                start_time,
                end_time,
                feature_id,
                "flow",
            )
        elif feature_type.lower() == "junction":
            # Nodes report pressure.
            res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
                timescale_conn,
                scheme_type,
                scheme_name,
                start_time,
                end_time,
                feature_id,
                "pressure",
            )
        else:
            raise ValueError(f"Unknown type: {feature_type}")
        # Tag every record with its feature id.
        for item in res:
            item["feature_id"] = feature_id
        result[feature_id] = res
    return result
@staticmethod
async def get_element_associated_scada_data(
    timescale_conn: AsyncConnection,
    postgres_conn: AsyncConnection,
    element_id: str,
    start_time: datetime,
    end_time: datetime,
    use_cleaned: bool = False,
) -> Optional[Any]:
    """
    Fetch the SCADA monitoring data associated with a link/node.

    Matches *element_id* against SCADA metadata; when an associated device
    exists, its measurements are returned keyed by the element id.

    Args:
        timescale_conn: TimescaleDB async connection
        postgres_conn: PostgreSQL async connection
        element_id: link or node id
        start_time: range start
        end_time: range end
        use_cleaned: read "cleaned_value" instead of "monitored_value"

    Returns:
        {element_id: [records]} or None when no device is associated.
    """
    # Find the first SCADA device whose associated element matches.
    scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
    associated_scada = next(
        (s for s in scada_infos if s["associated_element_id"] == element_id),
        None,
    )
    if associated_scada is None:
        return None
    device_id = associated_scada["id"]
    # Pick the raw or cleaned measurement column.
    data_field = "cleaned_value" if use_cleaned else "monitored_value"
    # The repository expects device ids as a list.
    res = await ScadaRepository.get_scada_field_by_id_time_range(
        timescale_conn, [device_id], start_time, end_time, data_field
    )
    # Re-key the result by the element id rather than the device id.
    return {element_id: res.get(device_id, [])}
@staticmethod
async def clean_scada_data(
    timescale_conn: AsyncConnection,
    postgres_conn: AsyncConnection,
    device_ids: List[str],
    start_time: datetime,
    end_time: datetime,
) -> str:
    """
    Clean SCADA data.

    Reads ``monitored_value`` for *device_ids*, runs the type-specific
    cleaning routine, and writes the results back into ``cleaned_value``.

    Args:
        timescale_conn: TimescaleDB connection
        postgres_conn: PostgreSQL connection
        device_ids: device id list (empty -> every SCADA device)
        start_time: range start
        end_time: range end

    Returns:
        "success", or an "error: ..." message string (never raises).
    """
    try:
        # Fetch all SCADA metadata.
        scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
        # Index metadata by device id.
        scada_device_info_dict = {info["id"]: info for info in scada_infos}
        # An empty device list means: process every SCADA device.
        if not device_ids:
            device_ids = list(scada_device_info_dict.keys())
        # Batch-query the raw values for all requested devices.
        data = await ScadaRepository.get_scada_field_by_id_time_range(
            timescale_conn, device_ids, start_time, end_time, "monitored_value"
        )
        if not data:
            return "error: fetch none scada data"  # no data, bail out
        # Flatten the nested dict into long-format records.
        # data format: {device_id: [{"time": "...", "value": ...}, ...]}
        all_records = []
        for device_id, records in data.items():
            for record in records:
                all_records.append(
                    {
                        "time": record["time"],
                        "device_id": device_id,
                        "value": record["value"],
                    }
                )
        if not all_records:
            return "error: fetch none scada data"  # no data, bail out
        # Build a DataFrame and pivot so each device_id becomes a column.
        df_long = pd.DataFrame(all_records)
        df = df_long.pivot(index="time", columns="device_id", values="value")
        # Split the devices by SCADA type.
        pressure_ids = [
            id
            for id in df.columns
            if scada_device_info_dict.get(id, {}).get("type") == "pressure"
        ]
        flow_ids = [
            id
            for id in df.columns
            if scada_device_info_dict.get(id, {}).get("type") == "pipe_flow"
        ]
        # Clean the pressure series.
        if pressure_ids:
            pressure_df = df[pressure_ids]
            # Turn the time index back into an ordinary column.
            pressure_df = pressure_df.reset_index()
            # Drop the time column before handing values to the cleaner.
            value_df = pressure_df.drop(columns=["time"])
            # Run the pressure-cleaning routine.
            cleaned_value_df = clean_pressure_data_df_km(value_df)
            # Re-attach the time column as the first column.
            cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
            # Write the cleaned values back to the database row by row.
            for device_id in pressure_ids:
                if device_id in cleaned_df.columns:
                    cleaned_values = cleaned_df[device_id].tolist()
                    time_values = cleaned_df["time"].tolist()
                    for i, time_str in enumerate(time_values):
                        time_dt = datetime.fromisoformat(time_str)
                        value = cleaned_values[i]
                        await ScadaRepository.update_scada_field(
                            timescale_conn,
                            time_dt,
                            device_id,
                            "cleaned_value",
                            value,
                        )
        # Clean the flow series.
        if flow_ids:
            flow_df = df[flow_ids]
            # Turn the time index back into an ordinary column.
            flow_df = flow_df.reset_index()
            # Drop the time column before handing values to the cleaner.
            value_df = flow_df.drop(columns=["time"])
            # Run the flow-cleaning routine.
            cleaned_value_df = clean_flow_data_df_kf(value_df)
            # Re-attach the time column as the first column.
            cleaned_df = pd.concat([flow_df["time"], cleaned_value_df], axis=1)
            # Write the cleaned values back to the database row by row.
            for device_id in flow_ids:
                if device_id in cleaned_df.columns:
                    cleaned_values = cleaned_df[device_id].tolist()
                    time_values = cleaned_df["time"].tolist()
                    for i, time_str in enumerate(time_values):
                        time_dt = datetime.fromisoformat(time_str)
                        value = cleaned_values[i]
                        await ScadaRepository.update_scada_field(
                            timescale_conn,
                            time_dt,
                            device_id,
                            "cleaned_value",
                            value,
                        )
        return "success"
    except Exception as e:
        return f"error: {str(e)}"
@staticmethod
async def predict_pipeline_health(
    timescale_conn: AsyncConnection,
    network_name: str,
    query_time: datetime,
) -> List[Dict[str, Any]]:
    """
    Predict pipeline health.

    Looks up pipe attributes and realtime data around *query_time* and runs
    a random survival forest model to predict each pipe's survival function.

    Args:
        timescale_conn: TimescaleDB async connection
        network_name: network database name
        query_time: point in time to sample realtime data at

    Returns:
        List of {"link_id": ..., "survival_function": {"x": [...], "y": [...]}}

    Raises:
        ValueError: on invalid parameters, missing data, or any internal
            failure (all exceptions are re-raised wrapped in ValueError).
    """
    try:
        # 1. Sample a +/- 1 second window around the query time.
        start_time = query_time - timedelta(seconds=1)
        end_time = query_time + timedelta(seconds=1)
        # 2. Query velocity first to learn which pipes have data at all.
        velocity_data = await RealtimeRepository.get_links_field_by_time_range(
            timescale_conn, start_time, end_time, "velocity"
        )
        if not velocity_data:
            raise ValueError("未找到流速数据")
        # 3. Fetch basic attributes only for pipes that have velocity data.
        valid_link_ids = list(velocity_data.keys())
        fields = ["id", "diameter", "node1", "node2"]
        all_links = InternalQueries.get_links_by_property(
            fields=fields,
            db_name=network_name,
        )
        # Index pipes by id for O(1) lookup.
        links_dict = {link["id"]: link for link in all_links}
        # Collect the end-node ids of every valid pipe.
        # NOTE(review): node_ids is built but never used below —
        # get_nodes_field_by_time_range queries all nodes; confirm intent.
        node_ids = set()
        for link_id in valid_link_ids:
            if link_id in links_dict:
                link = links_dict[link_id]
                node_ids.add(link["node1"])
                node_ids.add(link["node2"])
        # 4. Batch-query node pressures.
        pressure_data = await RealtimeRepository.get_nodes_field_by_time_range(
            timescale_conn, start_time, end_time, "pressure"
        )
        # 5. Assemble the model features per pipe.
        materials = []
        diameters = []
        velocities = []
        pressures = []
        link_ids = []
        for link_id in valid_link_ids:
            # Skip ids that are not pipes (pumps and other elements).
            if link_id not in links_dict:
                continue
            link = links_dict[link_id]
            diameter = link["diameter"]
            node1 = link["node1"]
            node2 = link["node2"]
            # Latest velocity sample in the window (0 when missing).
            velocity_values = velocity_data[link_id]
            velocity = velocity_values[-1]["value"] if velocity_values else 0
            # Average the pressures at both end nodes (0 when missing).
            node1_pressure = 0
            node2_pressure = 0
            if node1 in pressure_data and pressure_data[node1]:
                pressure_values = pressure_data[node1]
                node1_pressure = (
                    pressure_values[-1]["value"] if pressure_values else 0
                )
            if node2 in pressure_data and pressure_data[node2]:
                pressure_values = pressure_data[node2]
                node2_pressure = (
                    pressure_values[-1]["value"] if pressure_values else 0
                )
            avg_pressure = (node1_pressure + node2_pressure) / 2
            link_ids.append(link_id)
            materials.append(7)  # default material type 7 — adjust as needed
            diameters.append(diameter)
            velocities.append(velocity)
            pressures.append(avg_pressure)
        if not link_ids:
            raise ValueError("没有找到有效的管道数据用于预测")
        # 6. Build the feature DataFrame expected by the model.
        data = pd.DataFrame(
            {
                "Material": materials,
                "Diameter": diameters,
                "Flow Velocity": velocities,
                "Pressure": pressures,
            }
        )
        # 7. Run the survival-forest prediction.
        analyzer = PipelineHealthAnalyzer(
            model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
        )
        survival_functions = analyzer.predict_survival(data)
        # 8. Pair each pipe id with its survival function.
        results = []
        for i, link_id in enumerate(link_ids):
            sf = survival_functions[i]
            results.append(
                {
                    "link_id": link_id,
                    "survival_function": {
                        "x": sf.x.tolist(),  # time points (years)
                        "y": sf.y.tolist(),  # survival probabilities
                    },
                }
            )
        return results
    except Exception as e:
        raise ValueError(f"管道健康预测失败: {str(e)}")

View File

@@ -0,0 +1,115 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
    """Async TimescaleDB connection-pool wrapper built on psycopg_pool."""

    def __init__(self, db_name=None):
        self.pool = None  # psycopg_pool.AsyncConnectionPool, created by init_pool()
        self.db_name = db_name  # target database name; None -> config default

    def init_pool(self, db_name=None):
        """Initialize the connection pool.

        Uses *db_name* if given, else the name from the constructor,
        else the default database from timescaledb_info.
        """
        # BUG FIX: db_name was previously ignored, so every pool connected to
        # the default database regardless of the requested one.
        target_db = db_name or self.db_name
        conn_string = timescaledb_info.get_pgconn_string(db_name=target_db)
        try:
            self.pool = psycopg_pool.AsyncConnectionPool(
                conninfo=conn_string,
                min_size=5,
                max_size=20,
                open=False,  # Don't open immediately, wait for startup
                kwargs={"row_factory": dict_row},  # Return rows as dictionaries
            )
            logger.info(
                "TimescaleDB connection pool initialized for database: %s",
                target_db or "default",
            )
        except Exception as e:
            logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
            raise

    async def open(self):
        """Open the pool (no-op when init_pool was never called)."""
        if self.pool:
            await self.pool.open()

    async def close(self):
        """Close the connection pool."""
        if self.pool:
            await self.pool.close()
            logger.info("TimescaleDB connection pool closed.")

    def get_pgconn_string(self, db_name=None):
        """Get the TimescaleDB connection string for *db_name* (or this
        instance's database)."""
        target_db_name = db_name or self.db_name
        return timescaledb_info.get_pgconn_string(db_name=target_db_name)

    @asynccontextmanager
    async def get_connection(self) -> AsyncGenerator:
        """Yield a connection borrowed from the pool."""
        if not self.pool:
            raise Exception("Database pool is not initialized.")
        async with self.pool.connection() as conn:
            yield conn
# Default database instance (configured default TimescaleDB database).
db = Database()
# Cache of per-database instances — avoids building duplicate pools.
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
    """Factory for a Database bound to *db_name* (pool not yet initialized)."""
    return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
    """Get or create a cached Database instance for *db_name*.

    Passing None (or "") returns the shared default instance. A first request
    for a named database initializes and opens its pool before caching it.

    NOTE(review): creation is not guarded by a lock — two concurrent calls for
    the same new db_name could race and build two pools; confirm acceptable.
    """
    if not db_name:
        return db  # shared default database instance
    if db_name not in _database_instances:
        # First request for this database: build, open and cache its pool.
        instance = create_database_instance(db_name)
        instance.init_pool()
        await instance.open()
        _database_instances[db_name] = instance
        logger.info(f"Created new database instance for: {db_name}")
    return _database_instances[db_name]
async def get_db_connection():
    """FastAPI dependency: yield a connection from the default pool."""
    async with db.get_connection() as conn:
        yield conn
async def get_database_connection(db_name: Optional[str] = None):
    """
    FastAPI dependency that yields a database connection, optionally for a
    specific database name.

    Usage: ``conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))``
    or in a route: ``conn: AsyncConnection = Depends(get_database_connection)``
    """
    instance = await get_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn
async def cleanup_database_instances():
    """Close every cached Database instance plus the default one.

    Intended as an application-shutdown hook.
    """
    for db_name in list(_database_instances):
        instance = _database_instances[db_name]
        await instance.close()
        logger.info(f"Closed database instance for: {db_name}")
    _database_instances.clear()
    # Finally shut down the default database pool as well.
    await db.close()
    logger.info("All database instances cleaned up.")

View File

@@ -0,0 +1,122 @@
from typing import List
from fastapi.logger import logger
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.realtime import RealtimeRepository
import timescaledb.timescaledb_info as timescaledb_info
from datetime import datetime, timedelta
from timescaledb.schemas.scada import ScadaRepository
import psycopg
import time
class InternalStorage:
    """Synchronous helpers for writing simulation results to TimescaleDB,
    with a simple fixed-delay retry policy."""

    @staticmethod
    def store_realtime_simulation(
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        db_name: str = None,
        max_retries: int = 3,
    ):
        """Store realtime simulation results.

        :param node_result_list: node result records
        :param link_result_list: link result records
        :param result_start_time: timestamp the results start at
        :param db_name: database name (None -> default database)
        :param max_retries: maximum number of attempts before re-raising
        """
        for attempt in range(max_retries):
            try:
                conn_string = (
                    timescaledb_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else timescaledb_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    RealtimeRepository.store_realtime_simulation_result_sync(
                        conn, node_result_list, link_result_list, result_start_time
                    )
                break  # success
            except Exception as e:
                logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)  # wait before retrying
                else:
                    raise  # re-raise after exhausting retries

    @staticmethod
    def store_scheme_simulation(
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[dict],
        link_result_list: List[dict],
        result_start_time: str,
        num_periods: int = 1,
        db_name: str = None,
        max_retries: int = 3,
    ):
        """Store scheme simulation results.

        :param scheme_type: scheme (working-condition) type
        :param scheme_name: scheme name
        :param node_result_list: node result records
        :param link_result_list: link result records
        :param result_start_time: timestamp the results start at
        :param num_periods: number of simulation periods
        :param db_name: database name (None -> default database)
        :param max_retries: maximum number of attempts before re-raising
        """
        for attempt in range(max_retries):
            try:
                conn_string = (
                    timescaledb_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else timescaledb_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    SchemeRepository.store_scheme_simulation_result_sync(
                        conn,
                        scheme_type,
                        scheme_name,
                        node_result_list,
                        link_result_list,
                        result_start_time,
                        num_periods,
                    )
                break  # success
            except Exception as e:
                logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)  # wait before retrying
                else:
                    raise  # re-raise after exhausting retries
class InternalQueries:
    """Synchronous internal queries against TimescaleDB."""

    @staticmethod
    def query_scada_by_ids_time(
        device_ids: List[str],
        query_time: str,
        db_name: str = None,
        max_retries: int = 3,
    ) -> dict:
        """
        Query SCADA monitored values at a single point in time.

        :param device_ids: SCADA device ids to look up
        :param query_time: ISO timestamp (assumed Beijing time — TODO confirm)
        :param db_name: database name (None -> default database)
        :param max_retries: maximum number of attempts before re-raising
        :return: {device_id: monitored_value or None when no row matched}
        """
        # Parse the timestamp and query a +/- 1 second window around it.
        beijing_time = datetime.fromisoformat(query_time)
        start_time = beijing_time - timedelta(seconds=1)
        end_time = beijing_time + timedelta(seconds=1)
        for attempt in range(max_retries):
            try:
                conn_string = (
                    timescaledb_info.get_pgconn_string(db_name=db_name)
                    if db_name
                    else timescaledb_info.get_pgconn_string()
                )
                with psycopg.Connection.connect(conn_string) as conn:
                    rows = ScadaRepository.get_scada_by_ids_time_range_sync(
                        conn, device_ids, start_time, end_time
                    )
                    # Keep the first value seen per device. Single pass over
                    # the rows instead of one full scan per device id.
                    first_value = {}
                    for row in rows:
                        row_device_id = row["device_id"]
                        if row_device_id not in first_value:
                            first_value[row_device_id] = row["monitored_value"]
                    return {
                        device_id: first_value.get(device_id)
                        for device_id in device_ids
                    }
            except Exception as e:
                logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    raise

View File

@@ -0,0 +1,627 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from postgresql.database import get_database_instance as get_postgres_database_instance
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的数据库名称,为空时使用默认数据库"
    )
):
    """FastAPI dependency: yield a pooled TimescaleDB connection, optionally
    for the database named by the ``db_name`` query parameter."""
    instance = await get_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn
# PostgreSQL 数据库连接依赖函数
async def get_postgres_connection(
    db_name: Optional[str] = Query(
        None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
    )
):
    """FastAPI dependency: yield a pooled PostgreSQL connection, optionally
    for the database named by the ``db_name`` query parameter."""
    instance = await get_postgres_database_instance(db_name)
    async with instance.get_connection() as conn:
        yield conn
# --- Realtime Endpoints ---
@router.post("/realtime/links/batch", status_code=201)
async def insert_realtime_links(
    data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
    """Bulk-insert realtime link simulation records."""
    await RealtimeRepository.insert_links_batch(conn, data)
    inserted = len(data)
    return {"message": f"Inserted {inserted} records"}
@router.get("/realtime/links")
async def get_realtime_links(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """List realtime link simulation records within [start_time, end_time]."""
    records = await RealtimeRepository.get_links_by_time_range(
        conn, start_time, end_time
    )
    return records
@router.delete("/realtime/links")
async def delete_realtime_links(
    start_time: datetime,
    end_time: datetime,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Delete realtime link simulation records within [start_time, end_time]."""
    await RealtimeRepository.delete_links_by_time_range(
        conn, start_time, end_time
    )
    return {"message": "Deleted successfully"}
@router.patch("/realtime/links/{link_id}/field")
async def update_realtime_link_field(
    link_id: str,
    time: datetime,
    field: str,
    value: float,  # FastAPI needs a concrete type; float covers the numeric fields
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Update one field of a single realtime link record; 400 on invalid field."""
    try:
        await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"message": "Updated successfully"}
@router.post("/realtime/nodes/batch", status_code=201)
async def insert_realtime_nodes(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await RealtimeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/realtime/nodes")
async def get_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
return await RealtimeRepository.get_nodes_by_time_range(conn, start_time, end_time)
@router.delete("/realtime/nodes")
async def delete_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.post("/realtime/simulation/store", status_code=201)
async def store_realtime_simulation_result(
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Store realtime simulation results to TimescaleDB"""
await RealtimeRepository.store_realtime_simulation_result(
conn, node_result_list, link_result_list, result_start_time
)
return {"message": "Simulation results stored successfully"}
@router.get("/realtime/query/by-time-property")
async def query_realtime_records_by_time_property(
query_time: str,
type: str,
property: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query all realtime records by time and property"""
try:
results = await RealtimeRepository.query_all_record_by_time_property(
conn, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/realtime/query/by-id-time")
async def query_realtime_simulation_by_id_time(
id: str,
type: str,
query_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query realtime simulation results by id and time"""
try:
results = await RealtimeRepository.query_simulation_result_by_id_time(
conn, id, type, query_time
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# --- Scheme Endpoints ---
@router.post("/scheme/links/batch", status_code=201)
async def insert_scheme_links(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/links")
async def get_scheme_links(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
return await SchemeRepository.get_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
@router.get("/scheme/links/{link_id}/field")
async def get_scheme_link_field(
scheme_type: str,
scheme_name: str,
link_id: str,
start_time: datetime,
end_time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_link_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, link_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/links/{link_id}/field")
async def update_scheme_link_field(
scheme_type: str,
scheme_name: str,
link_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_link_field(
conn, time, scheme_type, scheme_name, link_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/links")
async def delete_scheme_links(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/nodes/batch", status_code=201)
async def insert_scheme_nodes(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/nodes/{node_id}/field")
async def get_scheme_node_field(
scheme_type: str,
scheme_name: str,
node_id: str,
start_time: datetime,
end_time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_node_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, node_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/nodes/{node_id}/field")
async def update_scheme_node_field(
scheme_type: str,
scheme_name: str,
node_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_node_field(
conn, time, scheme_type, scheme_name, node_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/nodes")
async def delete_scheme_nodes(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_nodes_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/simulation/store", status_code=201)
async def store_scheme_simulation_result(
scheme_type: str,
scheme_name: str,
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Store scheme simulation results to TimescaleDB"""
await SchemeRepository.store_scheme_simulation_result(
conn,
scheme_type,
scheme_name,
node_result_list,
link_result_list,
result_start_time,
)
return {"message": "Scheme simulation results stored successfully"}
@router.get("/scheme/query/by-scheme-time-property")
async def query_scheme_records_by_scheme_time_property(
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query all scheme records by scheme, time and property"""
try:
results = await SchemeRepository.query_all_record_by_scheme_time_property(
conn, scheme_type, scheme_name, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/scheme/query/by-id-time")
async def query_scheme_simulation_by_id_time(
scheme_type: str,
scheme_name: str,
id: str,
type: str,
query_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query scheme simulation results by id and time"""
try:
result = await SchemeRepository.query_scheme_simulation_result_by_id_time(
conn, scheme_type, scheme_name, id, type, query_time
)
return {"result": result}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# --- SCADA Endpoints ---
@router.post("/scada/batch", status_code=201)
async def insert_scada_data(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await ScadaRepository.insert_scada_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scada/by-ids-time-range")
async def get_scada_by_ids_time_range(
    start_time: datetime,
    end_time: datetime,
    device_ids: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return raw SCADA rows for the given devices within [start_time, end_time].

    Args:
        device_ids: Comma-separated device ids; blank entries are ignored.
    """
    # Split the comma-separated list; use `device_id` to avoid shadowing the
    # builtin `id`. ("".split(",") yields [""] which the filter drops, so the
    # previous `if device_ids else []` guard was redundant.)
    device_ids_list = [
        device_id.strip() for device_id in device_ids.split(",") if device_id.strip()
    ]
    return await ScadaRepository.get_scada_by_ids_time_range(
        conn, device_ids_list, start_time, end_time
    )
@router.get("/scada/by-ids-field-time-range")
async def get_scada_field_by_ids_time_range(
    start_time: datetime,
    end_time: datetime,
    field: str,
    device_ids: str,
    conn: AsyncConnection = Depends(get_database_connection),
):
    """Return one SCADA column per device within [start_time, end_time].

    Args:
        field: Column to read; the repository validates it against a whitelist.
        device_ids: Comma-separated device ids; blank entries are ignored.

    Raises:
        HTTPException: 400 when the repository rejects *field*.
    """
    try:
        # `device_id` instead of `id` — avoid shadowing the builtin.
        device_ids_list = [
            device_id.strip()
            for device_id in device_ids.split(",")
            if device_id.strip()
        ]
        return await ScadaRepository.get_scada_field_by_id_time_range(
            conn, device_ids_list, start_time, end_time, field
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scada/{device_id}/field")
async def update_scada_field(
device_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scada/by-id-time-range")
async def delete_scada_data(
device_id: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await ScadaRepository.delete_scada_by_id_time_range(
conn, device_id, start_time, end_time
)
return {"message": "Deleted successfully"}
# --- Composite Query Endpoints ---
@router.get("/composite/scada-simulation")
async def get_scada_associated_simulation_data(
    start_time: datetime,
    end_time: datetime,
    device_ids: str,
    # Fixed copy-paste: this description previously duplicated scheme_name's.
    scheme_type: str = Query(None, description="指定方案类型,若为空则查询实时数据"),
    scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """
    获取 SCADA 关联的 link/node 模拟值
    根据传入的 SCADA device_ids找到关联的 link/node
    并根据对应的 type查询对应的模拟数据
    """
    try:
        # Parse device_ids into List[str]; `device_id` avoids shadowing builtin `id`.
        device_ids_list = [
            device_id.strip()
            for device_id in device_ids.split(",")
            if device_id.strip()
        ]
        # Scheme results only when BOTH scheme coordinates are supplied;
        # otherwise fall back to realtime results.
        if scheme_type and scheme_name:
            result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
                timescale_conn,
                postgres_conn,
                device_ids_list,
                start_time,
                end_time,
                scheme_type,
                scheme_name,
            )
        else:
            result = (
                await CompositeQueries.get_scada_associated_realtime_simulation_data(
                    timescale_conn,
                    postgres_conn,
                    device_ids_list,
                    start_time,
                    end_time,
                )
            )
        if result is None:
            raise HTTPException(status_code=404, detail="No simulation data found")
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-simulation")
async def get_feature_simulation_data(
    start_time: datetime,
    end_time: datetime,
    feature_infos: str = Query(
        ..., description="特征信息,格式: id1:type1,id2:type2type为pipe或junction"
    ),
    scheme_type: str = Query(None, description="指定方案类型,若为空则查询实时数据"),
    scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
):
    """Fetch simulated values for links/nodes.

    Parses *feature_infos* ("element_id1:type1,element_id2:type2", e.g.
    "P1:pipe,J1:junction"), then queries scheme results when both
    scheme_type and scheme_name are given, realtime results otherwise.
    """
    try:
        # Parse "id:type" pairs; entries without a colon are ignored.
        parsed_features = []
        for raw in feature_infos.split(","):
            raw = raw.strip()
            if ":" not in raw:
                continue
            element_id, element_type = raw.split(":", 1)
            parsed_features.append((element_id.strip(), element_type.strip()))
        if not parsed_features:
            raise HTTPException(status_code=400, detail="feature_infos cannot be empty")
        if scheme_type and scheme_name:
            result = await CompositeQueries.get_scheme_simulation_data(
                timescale_conn,
                parsed_features,
                start_time,
                end_time,
                scheme_type,
                scheme_name,
            )
        else:
            result = await CompositeQueries.get_realtime_simulation_data(
                timescale_conn,
                parsed_features,
                start_time,
                end_time,
            )
        if result is None:
            raise HTTPException(status_code=404, detail="No simulation data found")
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-scada")
async def get_element_associated_scada_data(
    element_id: str,
    start_time: datetime,
    end_time: datetime,
    use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """Fetch the SCADA measurements associated with a link/node.

    Matches SCADA metadata by the element id; when an associated device_id
    exists, returns its measured data for the time window.
    """
    try:
        scada_data = await CompositeQueries.get_element_associated_scada_data(
            timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    if scada_data is None:
        raise HTTPException(
            status_code=404, detail="No associated SCADA data found"
        )
    return scada_data
@router.post("/composite/clean-scada")
async def clean_scada_data(
    device_ids: str,
    start_time: datetime = Query(...),
    end_time: datetime = Query(...),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """
    清洗 SCADA 数据
    根据 device_ids 查询 monitored_value清洗后更新 cleaned_value

    Passing device_ids="all" selects every device (an empty list is the
    repository's "all devices" convention).
    """
    try:
        if device_ids == "all":
            device_ids_list = []
        else:
            # `device_id` avoids shadowing builtin `id`; the old
            # `if device_ids else []` guard was redundant since the filter
            # already drops the empty string produced by "".split(",").
            device_ids_list = [
                device_id.strip()
                for device_id in device_ids.split(",")
                if device_id.strip()
            ]
        return await CompositeQueries.clean_scada_data(
            timescale_conn, postgres_conn, device_ids_list, start_time, end_time
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/pipeline-health-prediction")
async def predict_pipeline_health(
    query_time: datetime = Query(..., description="查询时间"),
    network_name: str = Query(..., description="管网数据库名称"),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
):
    """Predict pipeline health.

    Looks up pipe attributes and realtime data for the named network at the
    given time, then scores each pipe's survival probability with a random
    survival forest model (delegated to CompositeQueries).

    Args:
        query_time: Point in time to evaluate at.
        network_name: Name of the pipe-network database.
            (The original docstring documented this as ``db_name``, which
            does not match the actual parameter.)

    Returns:
        A list of predictions, each containing a link_id and its survival
        function.

    Raises:
        HTTPException: 400 on invalid input, 404 when the model file is
            missing, 500 on any other failure.
    """
    try:
        return await CompositeQueries.predict_pipeline_health(
            timescale_conn, network_name, query_time
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        # Catch-all boundary: surface unexpected failures as a 500 with detail.
        raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")

View File

@@ -0,0 +1,647 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
# UTC+8 timezone (China Standard Time) — all simulation timestamps are stored
# as this zone's wall-clock time.
UTC_8 = timezone(timedelta(hours=8))


class RealtimeRepository:
    """Data access for realtime.link_simulation and realtime.node_simulation.

    All methods are static and take an externally managed psycopg connection
    (async variants for the API layer, sync variants for background jobs).
    Dynamic column names are validated against whitelists and interpolated
    with ``sql.Identifier`` to prevent SQL injection.
    """

    # Columns that may be addressed dynamically in field queries/updates.
    _LINK_FIELDS = {
        "flow",
        "friction",
        "headloss",
        "quality",
        "reaction",
        "setting",
        "status",
        "velocity",
    }
    _NODE_FIELDS = {"actual_demand", "total_head", "pressure", "quality"}

    # ------------------------------------------------------------------
    # Internal helpers (deduplicate logic previously copied 2-4 times)
    # ------------------------------------------------------------------

    @staticmethod
    def _to_utc8(value) -> datetime:
        """Normalize an ISO-8601 string or datetime to an aware UTC+8 datetime.

        A trailing ``Z`` is treated as UTC and converted to UTC+8; naive
        inputs are assumed to already be UTC+8 wall-clock time. Aware
        non-``Z`` inputs are returned unchanged.
        """
        if isinstance(value, str):
            if value.endswith("Z"):
                utc_time = datetime.fromisoformat(value.replace("Z", "+00:00"))
                return utc_time.astimezone(UTC_8)
            parsed = datetime.fromisoformat(value)
        else:
            parsed = value
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=UTC_8)
        return parsed

    @staticmethod
    def _link_copy_row(item: dict) -> tuple:
        """Tuple in realtime.link_simulation COPY column order."""
        return (
            item["time"],
            item["id"],
            item.get("flow"),
            item.get("friction"),
            item.get("headloss"),
            item.get("quality"),
            item.get("reaction"),
            item.get("setting"),
            item.get("status"),
            item.get("velocity"),
        )

    @staticmethod
    def _node_copy_row(item: dict) -> tuple:
        """Tuple in realtime.node_simulation COPY column order."""
        return (
            item["time"],
            item["id"],
            item.get("actual_demand"),
            item.get("total_head"),
            item.get("pressure"),
            item.get("quality"),
        )

    @staticmethod
    def _node_rows(node_result_list: List[Dict[str, Any]], simulation_time) -> List[dict]:
        """Map raw node simulation results to node_simulation row dicts."""
        rows = []
        for node_result in node_result_list:
            periods = node_result.get("result", [])
            if not periods:
                # Defensive: skip nodes with no output instead of IndexError.
                continue
            data = periods[0]  # realtime simulation produces exactly one period
            rows.append(
                {
                    "time": simulation_time,
                    "id": node_result.get("node"),
                    "actual_demand": data.get("demand"),
                    "total_head": data.get("head"),
                    "pressure": data.get("pressure"),
                    "quality": data.get("quality"),
                }
            )
        return rows

    @staticmethod
    def _link_rows(link_result_list: List[Dict[str, Any]], simulation_time) -> List[dict]:
        """Map raw link simulation results to link_simulation row dicts."""
        rows = []
        for link_result in link_result_list:
            periods = link_result.get("result", [])
            if not periods:
                # Defensive: skip links with no output instead of IndexError.
                continue
            data = periods[0]
            rows.append(
                {
                    "time": simulation_time,
                    "id": link_result.get("link"),
                    "flow": data.get("flow"),
                    "friction": data.get("friction"),
                    "headloss": data.get("headloss"),
                    "quality": data.get("quality"),
                    "reaction": data.get("reaction"),
                    "setting": data.get("setting"),
                    "status": data.get("status"),
                    "velocity": data.get("velocity"),
                }
            )
        return rows

    # --- Link Simulation ---

    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.link_simulation: DELETE then COPY.

        Assumes every record in the batch shares the same timestamp; existing
        rows at that timestamp are deleted first so re-runs are idempotent.
        The whole operation runs in one transaction for atomicity.
        """
        if not data:
            return
        target_time = data[0]["time"]
        async with conn.transaction():
            async with conn.cursor() as cur:
                await cur.execute(
                    "DELETE FROM realtime.link_simulation WHERE time = %s",
                    (target_time,),
                )
                async with cur.copy(
                    "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
                ) as copy:
                    for item in data:
                        await copy.write_row(RealtimeRepository._link_copy_row(item))

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Sync variant of :meth:`insert_links_batch` (DELETE then COPY)."""
        if not data:
            return
        target_time = data[0]["time"]
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM realtime.link_simulation WHERE time = %s",
                    (target_time,),
                )
                with cur.copy(
                    "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(RealtimeRepository._link_copy_row(item))

    @staticmethod
    async def get_link_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
    ) -> List[dict]:
        """Return all rows for one link within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all link rows within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return [{"time", "value"}] for one link column over a time range.

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, link_id))
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_links_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return {link_id: [{"time", "value"}]} for one column over a time range.

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set one whitelisted column on the link row keyed by (time, id).

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, link_id))

    @staticmethod
    async def delete_links_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete all link rows within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Node Simulation ---

    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for realtime.node_simulation: DELETE then COPY.

        Same idempotent delete-then-copy strategy as :meth:`insert_links_batch`.
        """
        if not data:
            return
        target_time = data[0]["time"]
        async with conn.transaction():
            async with conn.cursor() as cur:
                await cur.execute(
                    "DELETE FROM realtime.node_simulation WHERE time = %s",
                    (target_time,),
                )
                async with cur.copy(
                    "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        await copy.write_row(RealtimeRepository._node_copy_row(item))

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Sync variant of :meth:`insert_nodes_batch` (DELETE then COPY)."""
        if not data:
            return
        target_time = data[0]["time"]
        with conn.transaction():
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM realtime.node_simulation WHERE time = %s",
                    (target_time,),
                )
                with cur.copy(
                    "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(RealtimeRepository._node_copy_row(item))

    @staticmethod
    async def get_node_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
    ) -> List[dict]:
        """Return all rows for one node within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
                (start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ) -> List[dict]:
        """Return all node rows within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_time_range(
        conn: AsyncConnection,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return [{"time", "value"}] for one node column over a time range.

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, node_id))
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_nodes_field_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
    ) -> dict:
        """Return {node_id: [{"time", "value"}]} for one column over a time range.

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set one whitelisted column on the node row keyed by (time, id).

        Raises:
            ValueError: If *field* is not a whitelisted column.
        """
        if field not in RealtimeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, node_id))

    @staticmethod
    async def delete_nodes_by_time_range(
        conn: AsyncConnection, start_time: datetime, end_time: datetime
    ):
        """Delete all node rows within [start_time, end_time]."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
                (start_time, end_time),
            )

    # --- Composite operations ---

    @staticmethod
    async def store_realtime_simulation_result(
        conn: AsyncConnection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Store realtime simulation results to TimescaleDB.

        Args:
            conn: Database connection.
            node_result_list: Node simulation results ({"node", "result"}).
            link_result_list: Link simulation results ({"link", "result"}).
            result_start_time: Result timestamp (ISO string or datetime);
                normalized to UTC+8 by :meth:`_to_utc8`.
        """
        simulation_time = RealtimeRepository._to_utc8(result_start_time)
        node_data = RealtimeRepository._node_rows(node_result_list, simulation_time)
        link_data = RealtimeRepository._link_rows(link_result_list, simulation_time)
        if node_data:
            await RealtimeRepository.insert_nodes_batch(conn, node_data)
        if link_data:
            await RealtimeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_realtime_simulation_result_sync(
        conn: Connection,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
    ):
        """Sync variant of :meth:`store_realtime_simulation_result`."""
        simulation_time = RealtimeRepository._to_utc8(result_start_time)
        node_data = RealtimeRepository._node_rows(node_result_list, simulation_time)
        link_data = RealtimeRepository._link_rows(link_result_list, simulation_time)
        if node_data:
            RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
        if link_data:
            RealtimeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_time_property(
        conn: AsyncConnection,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """Query all records by time and property.

        Args:
            conn: Database connection.
            query_time: Time to query (ISO string or datetime); matched with
                a ±1 second window to tolerate timestamp rounding.
            type: "node" or "link" (parameter name kept for API compatibility
                despite shadowing the builtin).
            property: Column/field to read.

        Returns:
            List of {"ID", "value"} records.

        Raises:
            ValueError: On an unknown *type* or invalid *property*.
        """
        target_time = RealtimeRepository._to_utc8(query_time)
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)
        kind = type.lower()
        if kind == "node":
            data = await RealtimeRepository.get_nodes_field_by_time_range(
                conn, start_time, end_time, property
            )
        elif kind == "link":
            data = await RealtimeRepository.get_links_field_by_time_range(
                conn, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
        return [
            {"ID": element_id, "value": item["value"]}
            for element_id, items in data.items()
            for item in items
        ]

    @staticmethod
    async def query_simulation_result_by_id_time(
        conn: AsyncConnection,
        id: str,
        type: str,
        query_time: str,
    ) -> List[dict]:
        """Query simulation results for one element around a point in time.

        Args:
            conn: Database connection.
            id: Node or link id (parameter names kept for API compatibility
                despite shadowing builtins).
            type: "node" or "link".
            query_time: Time to query (ISO string or datetime); matched with
                a ±1 second window.

        Returns:
            List of matching rows.

        Raises:
            ValueError: On an unknown *type*.
        """
        target_time = RealtimeRepository._to_utc8(query_time)
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)
        kind = type.lower()
        if kind == "node":
            return await RealtimeRepository.get_node_by_time_range(
                conn, start_time, end_time, id
            )
        if kind == "link":
            return await RealtimeRepository.get_link_by_time_range(
                conn, start_time, end_time, id
            )
        raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

View File

@@ -0,0 +1,106 @@
from typing import List, Any
from datetime import datetime
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
class ScadaRepository:
    """Data-access layer for the ``scada.scada_data`` hypertable."""

    @staticmethod
    async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
        """Bulk-load SCADA readings via COPY; a no-op for an empty batch."""
        if not data:
            return
        async with conn.cursor() as cursor:
            async with cursor.copy(
                "COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN"
            ) as copy:
                for record in data:
                    row = (
                        record["time"],
                        record["device_id"],
                        record.get("monitored_value"),
                        record.get("cleaned_value"),
                    )
                    await copy.write_row(row)

    @staticmethod
    async def get_scada_by_ids_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all rows for the given devices within [start_time, end_time]."""
        params = (device_ids, start_time, end_time)
        async with conn.cursor() as cursor:
            await cursor.execute(
                "SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
                params,
            )
            return await cursor.fetchall()

    @staticmethod
    def get_scada_by_ids_time_range_sync(
        conn: Connection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Synchronous variant of :meth:`get_scada_by_ids_time_range`."""
        params = (device_ids, start_time, end_time)
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
                params,
            )
            return cursor.fetchall()

    @staticmethod
    async def get_scada_field_by_id_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one value column grouped per device.

        Only whitelisted column names may be interpolated into the query
        (guards ``sql.Identifier`` against SQL injection).

        Returns:
            Mapping of device_id -> list of {"time": iso_str, "value": ...}.
        """
        if field not in ("monitored_value", "cleaned_value"):
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT device_id, time, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cursor:
            await cursor.execute(query, (start_time, end_time, device_ids))
            rows = await cursor.fetchall()
        grouped = defaultdict(list)
        for row in rows:
            grouped[row["device_id"]].append(
                {"time": row["time"].isoformat(), "value": row[field]}
            )
        return dict(grouped)

    @staticmethod
    async def update_scada_field(
        conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
    ):
        """Set one value column for a single (time, device_id) row."""
        if field not in ("monitored_value", "cleaned_value"):
            raise ValueError(f"Invalid field: {field}")
        statement = sql.SQL(
            "UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cursor:
            await cursor.execute(statement, (value, time, device_id))

    @staticmethod
    async def delete_scada_by_id_time_range(
        conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
    ):
        """Delete one device's rows within [start_time, end_time]."""
        async with conn.cursor() as cursor:
            await cursor.execute(
                "DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
                (device_id, start_time, end_time),
            )

View File

@@ -0,0 +1,760 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
import globals
# UTC+8 timezone (China Standard Time); all stored timestamps are normalized to it.
UTC_8 = timezone(timedelta(hours=8))
class SchemeRepository:
    """Data-access layer for scheme simulation results.

    Covers ``scheme.link_simulation`` and ``scheme.node_simulation``. Batch
    writes use DELETE-then-COPY inside a transaction, so re-running a
    simulation for the same (time, scheme_type, scheme_name) replaces the
    earlier results atomically.
    """

    # Whitelists of columns that may be interpolated via sql.Identifier
    # (SQL-injection guard for the dynamic-field methods below).
    _LINK_FIELDS = frozenset(
        {
            "flow",
            "friction",
            "headloss",
            "quality",
            "reaction",
            "setting",
            "status",
            "velocity",
        }
    )
    _NODE_FIELDS = frozenset({"actual_demand", "total_head", "pressure", "quality"})

    # --- Shared helpers ---
    @staticmethod
    def _to_utc8(value) -> datetime:
        """Normalize an ISO-8601 string or datetime to an aware UTC+8 datetime.

        Strings ending in "Z" are parsed as UTC and converted to UTC+8;
        naive strings/datetimes are assumed to already be UTC+8.
        """
        if isinstance(value, str):
            if value.endswith("Z"):
                return datetime.fromisoformat(
                    value.replace("Z", "+00:00")
                ).astimezone(UTC_8)
            value = datetime.fromisoformat(value)
        if value.tzinfo is None:
            value = value.replace(tzinfo=UTC_8)
        return value

    @staticmethod
    def _hydraulic_timestep() -> timedelta:
        """Parse ``globals.hydraulic_timestep`` ("HH:MM:SS") into a timedelta."""
        hours, minutes, seconds = (
            int(part) for part in globals.hydraulic_timestep.split(":")
        )
        return timedelta(hours=hours, minutes=minutes, seconds=seconds)

    @staticmethod
    def _build_node_rows(
        node_result_list, scheme_type, scheme_name, start_time, timestep, num_periods
    ) -> List[dict]:
        """Flatten per-node simulation results into node_simulation row dicts."""
        rows = []
        for node_result in node_result_list:
            node_id = node_result.get("node")
            results = node_result.get("result", [])
            for period_index in range(num_periods):
                data = results[period_index]
                rows.append(
                    {
                        "time": start_time + timestep * period_index,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": node_id,
                        "actual_demand": data.get("demand"),
                        "total_head": data.get("head"),
                        "pressure": data.get("pressure"),
                        "quality": data.get("quality"),
                    }
                )
        return rows

    @staticmethod
    def _build_link_rows(
        link_result_list, scheme_type, scheme_name, start_time, timestep, num_periods
    ) -> List[dict]:
        """Flatten per-link simulation results into link_simulation row dicts."""
        rows = []
        for link_result in link_result_list:
            link_id = link_result.get("link")
            results = link_result.get("result", [])
            for period_index in range(num_periods):
                data = results[period_index]
                rows.append(
                    {
                        "time": start_time + timestep * period_index,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": link_id,
                        "flow": data.get("flow"),
                        "friction": data.get("friction"),
                        "headloss": data.get("headloss"),
                        "quality": data.get("quality"),
                        "reaction": data.get("reaction"),
                        "setting": data.get("setting"),
                        "status": data.get("status"),
                        "velocity": data.get("velocity"),
                    }
                )
        return rows

    # --- Link Simulation ---
    @staticmethod
    async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.link_simulation using DELETE then COPY for performance."""
        if not data:
            return
        # All distinct timestamps in this batch; the batch is assumed to
        # belong to a single (scheme_type, scheme_name).
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]
        # Transaction makes delete+copy atomic.
        async with conn.transaction():
            async with conn.cursor() as cur:
                # 1. Remove stale rows for the affected timestamps/scheme.
                await cur.execute(
                    "DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )
                # 2. COPY the fresh rows in.
                async with cur.copy(
                    "COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
                ) as copy:
                    for item in data:
                        await copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("flow"),
                                item.get("friction"),
                                item.get("headloss"),
                                item.get("quality"),
                                item.get("reaction"),
                                item.get("setting"),
                                item.get("status"),
                                item.get("velocity"),
                            )
                        )

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Batch insert for scheme.link_simulation using DELETE then COPY for performance (sync version)."""
        if not data:
            return
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]
        # Transaction makes delete+copy atomic.
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Remove stale rows for the affected timestamps/scheme.
                cur.execute(
                    "DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )
                # 2. COPY the fresh rows in.
                with cur.copy(
                    "COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("flow"),
                                item.get("friction"),
                                item.get("headloss"),
                                item.get("quality"),
                                item.get("reaction"),
                                item.get("setting"),
                                item.get("status"),
                                item.get("velocity"),
                            )
                        )

    @staticmethod
    async def get_link_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
    ) -> List[dict]:
        """Return all columns for one link of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all columns for every link of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one value column for a single link as [{"time", "value"}, ...].

        Raises:
            ValueError: If ``field`` is not a known link column.
        """
        if field not in SchemeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, link_id)
            )
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_links_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one value column for all links, grouped per link id.

        Raises:
            ValueError: If ``field`` is not a known link column.
        """
        if field not in SchemeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        link_id: str,
        field: str,
        value: Any,
    ):
        """Set one value column for a single link row.

        Raises:
            ValueError: If ``field`` is not a known link column.
        """
        if field not in SchemeRepository._LINK_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))

    @staticmethod
    async def delete_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete all link rows of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Node Simulation ---
    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        """Batch insert for scheme.node_simulation using DELETE then COPY for performance."""
        if not data:
            return
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]
        # Transaction makes delete+copy atomic.
        async with conn.transaction():
            async with conn.cursor() as cur:
                # 1. Remove stale rows for the affected timestamps/scheme.
                await cur.execute(
                    "DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )
                # 2. COPY the fresh rows in.
                async with cur.copy(
                    "COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        await copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("actual_demand"),
                                item.get("total_head"),
                                item.get("pressure"),
                                item.get("quality"),
                            )
                        )

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        """Batch insert for scheme.node_simulation using DELETE then COPY (sync version)."""
        if not data:
            return
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]
        # Transaction makes delete+copy atomic.
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Remove stale rows for the affected timestamps/scheme.
                cur.execute(
                    "DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )
                # 2. COPY the fresh rows in.
                with cur.copy(
                    "COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("actual_demand"),
                                item.get("total_head"),
                                item.get("pressure"),
                                item.get("quality"),
                            )
                        )

    @staticmethod
    async def get_node_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
    ) -> List[dict]:
        """Return all columns for one node of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all columns for every node of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        """Return one value column for a single node as [{"time", "value"}, ...].

        Raises:
            ValueError: If ``field`` is not a known node column.
        """
        if field not in SchemeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, node_id)
            )
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_nodes_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one value column for all nodes, grouped per node id.

        Raises:
            ValueError: If ``field`` is not a known node column.
        """
        if field not in SchemeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        node_id: str,
        field: str,
        value: Any,
    ):
        """Set one value column for a single node row.

        Raises:
            ValueError: If ``field`` is not a known node column.
        """
        if field not in SchemeRepository._NODE_FIELDS:
            raise ValueError(f"Invalid field: {field}")
        query = sql.SQL(
            "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))

    @staticmethod
    async def delete_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        """Delete all node rows of a scheme within a time range."""
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Composite operations ---
    @staticmethod
    async def store_scheme_simulation_result(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """Store scheme simulation results to TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
            num_periods: Number of simulation periods to store per element
        """
        simulation_time = SchemeRepository._to_utc8(result_start_time)
        timestep = SchemeRepository._hydraulic_timestep()
        node_data = SchemeRepository._build_node_rows(
            node_result_list, scheme_type, scheme_name, simulation_time, timestep, num_periods
        )
        link_data = SchemeRepository._build_link_rows(
            link_result_list, scheme_type, scheme_name, simulation_time, timestep, num_periods
        )
        if node_data:
            await SchemeRepository.insert_nodes_batch(conn, node_data)
        if link_data:
            await SchemeRepository.insert_links_batch(conn, link_data)

    @staticmethod
    def store_scheme_simulation_result_sync(
        conn: Connection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """Store scheme simulation results to TimescaleDB (sync version).

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
            num_periods: Number of simulation periods to store per element
        """
        simulation_time = SchemeRepository._to_utc8(result_start_time)
        timestep = SchemeRepository._hydraulic_timestep()
        node_data = SchemeRepository._build_node_rows(
            node_result_list, scheme_type, scheme_name, simulation_time, timestep, num_periods
        )
        link_data = SchemeRepository._build_link_rows(
            link_result_list, scheme_type, scheme_name, simulation_time, timestep, num_periods
        )
        if node_data:
            SchemeRepository.insert_nodes_batch_sync(conn, node_data)
        if link_data:
            SchemeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_scheme_time_property(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """Query all records by scheme, time and property from TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            query_time: Time to query (ISO format string)
            type: Type of data ("node" or "link")
            property: Property/field to query

        Returns:
            List of {"ID", "value"} records matching the criteria

        Raises:
            ValueError: If ``type`` is neither "node" nor "link".
        """
        target_time = SchemeRepository._to_utc8(query_time)
        # Match with a +/- 1 second window around the target instant.
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)
        if type.lower() == "node":
            data = await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        elif type.lower() == "link":
            data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
        # Flatten {id: [{"time", "value"}, ...]} into [{"ID", "value"}, ...].
        result = []
        for id, items in data.items():
            for item in items:
                result.append({"ID": id, "value": item["value"]})
        return result

    @staticmethod
    async def query_scheme_simulation_result_by_id_time(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        id: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """Query scheme simulation results by id and time from TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            id: The id of the node or link
            type: Type of data ("node" or "link")
            query_time: Time to query (ISO format string)

        Returns:
            List of records matching the criteria

        Raises:
            ValueError: If ``type`` is neither "node" nor "link".
        """
        target_time = SchemeRepository._to_utc8(query_time)
        # Match with a +/- 1 second window around the target instant.
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)
        if type.lower() == "node":
            return await SchemeRepository.get_node_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        elif type.lower() == "link":
            return await SchemeRepository.get_link_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

View File

@@ -0,0 +1,36 @@
from dotenv import load_dotenv
import os
# Load variables from a local .env file into the process environment.
load_dotenv()
# TimescaleDB connection settings, read once at import time.
# NOTE(review): any of these may be None if the variable is unset — callers
# building connection strings should verify the .env is complete.
pg_name = os.getenv("TIMESCALEDB_DB_NAME")
pg_host = os.getenv("TIMESCALEDB_DB_HOST")
pg_port = os.getenv("TIMESCALEDB_DB_PORT")
pg_user = os.getenv("TIMESCALEDB_DB_USER")
pg_password = os.getenv("TIMESCALEDB_DB_PASSWORD")
def get_pgconn_string(
    db_name=pg_name,
    db_host=pg_host,
    db_port=pg_port,
    db_user=pg_user,
    db_password=pg_password,
):
    """Return a libpq keyword/value PostgreSQL connection string.

    Defaults come from the module-level environment settings; any argument
    may be overridden per call.
    """
    settings = {
        "dbname": db_name,
        "host": db_host,
        "port": db_port,
        "user": db_user,
        "password": db_password,
    }
    return " ".join(f"{key}={value}" for key, value in settings.items())
def get_pg_config():
    """Return the non-secret PostgreSQL settings as a dict (password excluded)."""
    return dict(
        name=pg_name,
        host=pg_host,
        port=pg_port,
        user=pg_user,
    )
def get_pg_password():
    """Return the database password (handle with care; avoid logging it)."""
    return pg_password

View File

4244
app/main.py Normal file

File diff suppressed because it is too large Load Diff

0
app/native/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Some files were not shown because too many files have changed in this diff Show More