Refactor and modernize the FastAPI backend project framework
0
app/__init__.py
Normal file
0
app/algorithms/__init__.py
Normal file
289
app/algorithms/api_ex/Fdataclean.py
Normal file
@@ -0,0 +1,289 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pykalman import KalmanFilter
import os


def clean_flow_data_kf(input_csv_path: str, show_plot: bool = False) -> str:
    """
    Read every time-series column in input_csv_path, smooth it with a 1-D Kalman
    filter, and replace outliers detected on the residuals (IQR rule) with the
    filtered predictions.
    The output is saved as <input_filename>_cleaned.xlsx in the same directory
    as the input; the absolute path of the output file is returned.
    Only the input file path is kept as a parameter (as required).
    """
    # Read the CSV
    data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
    # Holds the Kalman-smoothed results
    data_kf = pd.DataFrame(index=data.index, columns=data.columns)
    # Smooth each column
    for col in data.columns:
        observations = pd.Series(data[col].values).ffill().bfill()
        if observations.isna().any():
            observations = observations.fillna(observations.mean())
        obs = observations.values.astype(float)

        kf = KalmanFilter(
            transition_matrices=[1],
            observation_matrices=[1],
            initial_state_mean=float(obs[0]),
            initial_state_covariance=1,
            observation_covariance=1,
            transition_covariance=0.01,
        )
        # Skip EM learning and use fixed parameters for better performance
        state_means, _ = kf.smooth(obs)
        data_kf[col] = state_means.flatten()

    # Compute residuals and detect outliers with the IQR rule (more robust)
    residuals = data - data_kf
    residual_thresholds = {}
    for col in data.columns:
        res_values = residuals[col].dropna().values  # drop NaN before computing the IQR
        q1 = np.percentile(res_values, 25)
        q3 = np.percentile(res_values, 75)
        iqr = q3 - q1
        lower_threshold = q1 - 1.5 * iqr
        upper_threshold = q3 + 1.5 * iqr
        residual_thresholds[col] = (lower_threshold, upper_threshold)

    cleaned_data = data.copy()
    anomalies_info = {}
    for col in data.columns:
        lower, upper = residual_thresholds[col]
        sensor_residuals = residuals[col]
        anomaly_mask = (sensor_residuals < lower) | (sensor_residuals > upper)
        anomaly_idx = data.index[anomaly_mask.fillna(False)]
        anomalies_info[col] = pd.DataFrame(
            {
                "Observed": data.loc[anomaly_idx, col],
                "Kalman_Predicted": data_kf.loc[anomaly_idx, col],
                "Residual": sensor_residuals.loc[anomaly_idx],
            }
        )
        # Write the repaired values into a separate <col>_cleaned column
        cleaned_data.loc[anomaly_idx, f"{col}_cleaned"] = data_kf.loc[anomaly_idx, col]

    # Build the output file name: input base name plus the suffix _cleaned.xlsx
    input_dir = os.path.dirname(os.path.abspath(input_csv_path))
    input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
    output_filename = f"{input_base}_cleaned.xlsx"
    output_path = os.path.join(input_dir, output_filename)

    # Overwrite an existing file of the same name
    if os.path.exists(output_path):
        os.remove(output_path)
    cleaned_data.to_excel(output_path, index=False)

    # Optional visualization (first sensor only)
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["axes.unicode_minus"] = False
    if show_plot and len(data.columns) > 0:
        sensor_to_plot = data.columns[0]
        plt.figure(figsize=(12, 6))
        plt.plot(
            data.index,
            data[sensor_to_plot],
            label="监测值",
            marker="o",
            markersize=3,
            alpha=0.7,
        )
        plt.plot(
            data.index, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
        )
        anomaly_idx = anomalies_info[sensor_to_plot].index
        if len(anomaly_idx) > 0:
            plt.plot(
                anomaly_idx,
                data[sensor_to_plot].loc[anomaly_idx],
                "ro",
                markersize=8,
                label="监测值异常点",
            )
            plt.plot(
                anomaly_idx,
                data_kf[sensor_to_plot].loc[anomaly_idx],
                "go",
                markersize=8,
                label="Kalman修复值",
            )
        plt.xlabel("时间点(序号)")
        plt.ylabel("监测值")
        plt.title(f"{sensor_to_plot}:观测值与Kalman滤波预测值(异常点标记)")
        plt.legend()
        plt.show()

    # Return the absolute path of the output file
    return os.path.abspath(output_path)


def clean_flow_data_df_kf(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
    """
    Take a DataFrame, smooth each column with a 1-D Kalman filter, and replace
    outliers detected on the residuals (IQR rule) with the filtered predictions.
    Distinguishes plausible zero values (flow switching) from abnormal zeros
    (runs of zeros or isolated zeros).
    Returns the fully cleaned DataFrame.
    """
    # Work on a copy of the input DataFrame
    data = data.copy()
    # Replace zeros with NaN so they can be interpolated
    data_filled = data.replace(0, np.nan)

    # Interpolate the abnormal zeros linearly, then ffill/bfill any remaining NaN
    data_filled = data_filled.interpolate(method="linear", limit_direction="both")

    # Handle whatever NaN values are left
    data_filled = data_filled.ffill().bfill()

    # Holds the Kalman-smoothed results
    data_kf = pd.DataFrame(index=data_filled.index, columns=data_filled.columns)
    # Smooth each column
    for col in data_filled.columns:
        observations = pd.Series(data_filled[col].values).ffill().bfill()
        if observations.isna().any():
            observations = observations.fillna(observations.mean())
        obs = observations.values.astype(float)

        kf = KalmanFilter(
            transition_matrices=[1],
            observation_matrices=[1],
            initial_state_mean=float(obs[0]),
            initial_state_covariance=1,
            observation_covariance=10,
            transition_covariance=10,
        )
        state_means, _ = kf.smooth(obs)
        data_kf[col] = state_means.flatten()

    # Compute residuals and detect outliers with the IQR rule
    residuals = data_filled - data_kf
    residual_thresholds = {}
    for col in data_filled.columns:
        res_values = residuals[col].dropna().values
        q1 = np.percentile(res_values, 25)
        q3 = np.percentile(res_values, 75)
        iqr = q3 - q1
        lower_threshold = q1 - 1.5 * iqr
        upper_threshold = q3 + 1.5 * iqr
        residual_thresholds[col] = (lower_threshold, upper_threshold)

    # Build the fully repaired data
    cleaned_data = data_filled.copy()
    anomalies_info = {}

    for col in data_filled.columns:
        lower, upper = residual_thresholds[col]
        sensor_residuals = residuals[col]
        anomaly_mask = (sensor_residuals < lower) | (sensor_residuals > upper)
        anomaly_idx = data_filled.index[anomaly_mask.fillna(False)]

        anomalies_info[col] = pd.DataFrame(
            {
                "Observed": data_filled.loc[anomaly_idx, col],
                "Kalman_Predicted": data_kf.loc[anomaly_idx, col],
                "Residual": sensor_residuals.loc[anomaly_idx],
            }
        )

        # Replace the anomalous values in place with the Kalman predictions
        cleaned_data.loc[anomaly_idx, col] = data_kf.loc[anomaly_idx, col]

    # Optional visualization
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["axes.unicode_minus"] = False
    if show_plot and len(data.columns) > 0:
        sensor_to_plot = data.columns[0]
        plt.figure(figsize=(12, 8))

        plt.subplot(2, 1, 1)
        plt.plot(
            data.index,
            data[sensor_to_plot],
            label="原始监测值",
            marker="o",
            markersize=3,
            alpha=0.7,
        )
        abnormal_zero_idx = data.index[data_filled[sensor_to_plot].isna()]
        if len(abnormal_zero_idx) > 0:
            plt.plot(
                abnormal_zero_idx,
                data[sensor_to_plot].loc[abnormal_zero_idx],
                "mo",
                markersize=8,
                label="异常0值",
            )
        plt.plot(
            data.index, data_kf[sensor_to_plot], label="Kalman滤波预测值", linewidth=2
        )
        anomaly_idx = anomalies_info[sensor_to_plot].index
        if len(anomaly_idx) > 0:
            plt.plot(
                anomaly_idx,
                data_filled[sensor_to_plot].loc[anomaly_idx],
                "ro",
                markersize=8,
                label="IQR异常点",
            )
        plt.xlabel("时间点(序号)")
        plt.ylabel("流量值")
        plt.title(f"{sensor_to_plot}:原始数据与异常检测")
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(
            data.index,
            cleaned_data[sensor_to_plot],
            label="修复后监测值",
            marker="o",
            markersize=3,
            color="green",
        )
        plt.xlabel("时间点(序号)")
        plt.ylabel("流量值")
        plt.title(f"{sensor_to_plot}:修复后数据")
        plt.legend()

        plt.tight_layout()
        plt.show()

    # Return the fully repaired DataFrame
    return cleaned_data


# # Test for clean_flow_data_kf
# if __name__ == "__main__":
#     # Default: CSV file in the same directory as this script
#     script_dir = os.path.dirname(os.path.abspath(__file__))
#     default_csv = os.path.join(script_dir, "pipe_flow_data_to_clean2.0.csv")
#     out = clean_flow_data_kf(default_csv)
#     print("清洗后的数据已保存到:", out)

# Test for clean_flow_data_df_kf
if __name__ == "__main__":
    import random

    # Read szh_flow_scada.csv
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, "szh_flow_scada.csv")
    data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")

    # Exclude the Time column and randomly pick one column
    columns_to_exclude = ["Time"]
    available_columns = [col for col in data.columns if col not in columns_to_exclude]
    selected_columns = random.sample(available_columns, 1)

    # Build a DataFrame from the selected columns (the function expects a DataFrame)
    data_df = pd.DataFrame({col: data[col].tolist() for col in selected_columns})

    print("选中的列:", selected_columns)
    print("原始数据长度:", len(data_df[selected_columns[0]]))

    # Run the cleaning function
    cleaned_df = clean_flow_data_df_kf(data_df, show_plot=True)
    # Write the cleaned data back to CSV
    out_csv = os.path.join(script_dir, f"{selected_columns[0]}_clean.csv")
    cleaned_df.to_csv(out_csv, index=False, encoding="utf-8-sig")
    print("已保存清洗结果到:", out_csv)
    print("清洗后的列:", list(cleaned_df.columns))
    print("清洗后的数据长度:", len(cleaned_df[selected_columns[0]]))
    print("测试完成:函数运行正常")
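A minimal usage sketch for clean_flow_data_df_kf. The column name flow_1 and the synthetic series are assumptions made only for illustration; only pandas, numpy and the module above are needed.

import numpy as np
import pandas as pd

from app.algorithms.api_ex.Fdataclean import clean_flow_data_df_kf

# Build a small synthetic flow series with one spike and a short dropout of zeros.
rng = np.random.default_rng(0)
flow = 50 + 5 * np.sin(np.linspace(0, 6 * np.pi, 200)) + rng.normal(0, 0.5, 200)
flow[37] = 120.0   # spike that the IQR residual rule should flag
flow[80:85] = 0.0  # abnormal zeros, interpolated before Kalman smoothing
df = pd.DataFrame({"flow_1": flow})

cleaned = clean_flow_data_df_kf(df, show_plot=False)
print(cleaned["flow_1"].iloc[35:40])  # the spike should now be close to the Kalman prediction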
238
app/algorithms/api_ex/Pdataclean.py
Normal file
@@ -0,0 +1,238 @@
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.impute import SimpleImputer
import os


def clean_pressure_data_km(input_csv_path: str, show_plot: bool = False) -> str:
    """
    Read the input CSV, detect anomalies with KMeans clustering and repair them
    with a rolling mean. The output is <input_basename>_cleaned.xlsx in the same
    directory: raw data in sheet 'raw_pressure_data', processed data in sheet
    'cleaned_pressure_data'.
    Returns the absolute path of the output file.
    """
    # Read the CSV
    input_csv_path = os.path.abspath(input_csv_path)
    data = pd.read_csv(input_csv_path, header=0, index_col=None, encoding="utf-8")
    # Standardize
    data_norm = (data - data.mean()) / data.std()

    # Clustering and anomaly detection
    k = 3
    kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
    clusters = kmeans.fit_predict(data_norm)
    centers = kmeans.cluster_centers_

    distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
    threshold = distances.mean() + 3 * distances.std()

    anomaly_pos = np.where(distances > threshold)[0]
    anomaly_indices = data.index[anomaly_pos]

    # For each anomalous row, find the sensor that deviates most from its cluster centre
    anomaly_details = {}
    for pos in anomaly_pos:
        row_norm = data_norm.iloc[pos]
        cluster_idx = clusters[pos]
        center = centers[cluster_idx]
        diff = abs(row_norm - center)
        main_sensor = diff.idxmax()
        anomaly_details[data.index[pos]] = main_sensor

    # Repair: rolling mean (window size is tunable)
    data_rolled = data.rolling(window=13, center=True, min_periods=1).mean()
    data_repaired = data.copy()
    for pos in anomaly_pos:
        label = data.index[pos]
        sensor = anomaly_details[label]
        data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]

    # Optional visualization (positions used as the x axis)
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["axes.unicode_minus"] = False

    if show_plot and len(data.columns) > 0:
        n = len(data)
        time = np.arange(n)
        plt.figure(figsize=(12, 8))
        for col in data.columns:
            plt.plot(time, data[col].values, marker="o", markersize=3, label=col)
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data.iloc[pos][sensor], "ro", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("压力监测值")
        plt.title("各传感器折线图(红色标记主要异常点)")
        plt.legend()
        plt.show()

        plt.figure(figsize=(12, 8))
        for col in data_repaired.columns:
            plt.plot(
                time, data_repaired[col].values, marker="o", markersize=3, label=col
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("修复后压力监测值")
        plt.title("修复后各传感器折线图(绿色标记修复值)")
        plt.legend()
        plt.show()

    # Save to Excel with two sheets
    input_dir = os.path.dirname(os.path.abspath(input_csv_path))
    input_base = os.path.splitext(os.path.basename(input_csv_path))[0]
    output_filename = f"{input_base}_cleaned.xlsx"
    output_path = os.path.join(input_dir, output_filename)

    if os.path.exists(output_path):
        os.remove(output_path)  # overwrite an existing file of the same name
    with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
        data.to_excel(writer, sheet_name="raw_pressure_data", index=False)
        data_repaired.to_excel(writer, sheet_name="cleaned_pressure_data", index=False)

    # Return the absolute path of the output file
    return os.path.abspath(output_path)


def clean_pressure_data_df_km(data: pd.DataFrame, show_plot: bool = False) -> pd.DataFrame:
    """
    Take a DataFrame, detect anomalies with KMeans clustering and repair them
    with a rolling mean.
    Returns the cleaned DataFrame.
    """
    # Work on a copy of the input DataFrame
    data = data.copy()
    # Fill NaN values
    data = data.ffill().bfill()
    # Outlier pre-processing: replace zeros with NaN, then fill by linear interpolation
    data_filled = data.replace(0, np.nan)
    data_filled = data_filled.interpolate(method="linear", limit_direction="both")
    # If NaN remains (all-zero columns), forward/backward fill
    data_filled = data_filled.ffill().bfill()

    # Standardize (using the filled data)
    data_norm = (data_filled - data_filled.mean()) / data_filled.std()

    # Handle NaN produced by standardization (e.g. zero-variance columns where every
    # value in the period is identical), which would otherwise break the clustering
    imputer = SimpleImputer(
        strategy="constant", fill_value=0, keep_empty_features=True
    )  # fill NaN (including all-NaN columns) with 0 and keep empty features
    data_norm = pd.DataFrame(
        imputer.fit_transform(data_norm),
        columns=data_norm.columns,
        index=data_norm.index,
    )

    # Clustering and anomaly detection
    k = 3
    kmeans = KMeans(n_clusters=k, init="k-means++", n_init=50, random_state=42)
    clusters = kmeans.fit_predict(data_norm)
    centers = kmeans.cluster_centers_

    distances = np.linalg.norm(data_norm.values - centers[clusters], axis=1)
    threshold = distances.mean() + 3 * distances.std()

    anomaly_pos = np.where(distances > threshold)[0]
    anomaly_indices = data.index[anomaly_pos]

    # For each anomalous row, find the sensor that deviates most from its cluster centre
    anomaly_details = {}
    for pos in anomaly_pos:
        row_norm = data_norm.iloc[pos]
        cluster_idx = clusters[pos]
        center = centers[cluster_idx]
        diff = abs(row_norm - center)
        main_sensor = diff.idxmax()
        anomaly_details[data.index[pos]] = main_sensor

    # Repair: rolling mean (window size is tunable)
    data_rolled = data_filled.rolling(window=13, center=True, min_periods=1).mean()
    data_repaired = data_filled.copy()
    for pos in anomaly_pos:
        label = data.index[pos]
        sensor = anomaly_details[label]
        data_repaired.loc[label, sensor] = data_rolled.loc[label, sensor]

    # Optional visualization (positions used as the x axis)
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["axes.unicode_minus"] = False

    if show_plot and len(data.columns) > 0:
        n = len(data)
        time = np.arange(n)
        plt.figure(figsize=(12, 8))
        for col in data.columns:
            plt.plot(
                time, data[col].values, marker="o", markersize=3, label=col, alpha=0.5
            )
        for col in data_filled.columns:
            plt.plot(
                time,
                data_filled[col].values,
                marker="x",
                markersize=3,
                label=f"{col}_filled",
                linestyle="--",
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_filled.iloc[pos][sensor], "ro", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("压力监测值")
        plt.title("各传感器折线图(红色标记主要异常点,虚线为0值填充后)")
        plt.legend()
        plt.show()

        plt.figure(figsize=(12, 8))
        for col in data_repaired.columns:
            plt.plot(
                time, data_repaired[col].values, marker="o", markersize=3, label=col
            )
        for pos in anomaly_pos:
            sensor = anomaly_details[data.index[pos]]
            plt.plot(pos, data_repaired.iloc[pos][sensor], "go", markersize=8)
        plt.xlabel("时间点(序号)")
        plt.ylabel("修复后压力监测值")
        plt.title("修复后各传感器折线图(绿色标记修复值)")
        plt.legend()
        plt.show()

    # Return the cleaned DataFrame
    return data_repaired


# Test for clean_pressure_data_km
# if __name__ == "__main__":
#     # Default: pressure_raw_data.csv in the script directory
#     script_dir = os.path.dirname(os.path.abspath(__file__))
#     default_csv = os.path.join(script_dir, "pressure_raw_data.csv")
#     out_path = clean_pressure_data_km(default_csv, show_plot=False)
#     print("保存路径:", out_path)

# Test for clean_pressure_data_df_km
if __name__ == "__main__":
    import random

    # Read szh_pressure_scada.csv
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, "szh_pressure_scada.csv")
    data = pd.read_csv(csv_path, header=0, index_col=None, encoding="utf-8")

    # Exclude the Time column and randomly pick 5 columns
    columns_to_exclude = ["Time"]
    available_columns = [col for col in data.columns if col not in columns_to_exclude]
    selected_columns = random.sample(available_columns, 5)

    # Build a DataFrame from the selected columns (the function expects a DataFrame)
    data_df = pd.DataFrame({col: data[col].tolist() for col in selected_columns})

    print("选中的列:", selected_columns)
    print("原始数据长度:", len(data_df[selected_columns[0]]))

    # Run the cleaning function
    cleaned_df = clean_pressure_data_df_km(data_df, show_plot=True)

    print("清洗后的列:", list(cleaned_df.columns))
    print("清洗后的数据长度:", len(cleaned_df[selected_columns[0]]))
    print("测试完成:函数运行正常")
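A minimal usage sketch for clean_pressure_data_df_km. The column names p1/p2/p3 and the synthetic data are assumptions made only for illustration.

import numpy as np
import pandas as pd

from app.algorithms.api_ex.Pdataclean import clean_pressure_data_df_km

# Three correlated pressure sensors with one injected fault on sensor p2.
rng = np.random.default_rng(1)
base = 30 + 2 * np.sin(np.linspace(0, 4 * np.pi, 300))
df = pd.DataFrame({
    "p1": base + rng.normal(0, 0.1, 300),
    "p2": base + rng.normal(0, 0.1, 300),
    "p3": base + rng.normal(0, 0.1, 300),
})
df.loc[150, "p2"] = 80.0  # far from its cluster centre, should exceed the 3-sigma distance threshold

repaired = clean_pressure_data_df_km(df, show_plot=False)
print(repaired.loc[148:152, "p2"])  # the faulty sample should be replaced by the rolling mean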
3
app/algorithms/api_ex/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .Fdataclean import *
from .Pdataclean import *
from .pipeline_health_analyzer import *
109
app/algorithms/api_ex/kmeans_sensor.py
Normal file
@@ -0,0 +1,109 @@
import wntr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.cluster
import os


class QD_KMeans(object):
    def __init__(self, wn, num_monitors):
        # self.inp = inp
        self.cluster_num = num_monitors  # number of cluster centres, i.e. number of pressure monitoring points
        self.wn = wn
        self.monitor_nodes = []
        self.coords = []
        self.junction_nodes = {}  # added missing initialization

    def get_junctions_coordinates(self):
        # Collect the (x, y) coordinates of every junction in the network
        for junction_name in self.wn.junction_name_list:
            junction = self.wn.get_node(junction_name)
            self.junction_nodes[junction_name] = junction.coordinates
            self.coords.append(junction.coordinates)

    def select_monitoring_points(self):
        if not self.coords:  # make sure the coordinates have been collected
            self.get_junctions_coordinates()
        coords = np.array(self.coords)
        # Min-max normalize the coordinates before clustering
        coords_normalized = (coords - coords.min(axis=0)) / (coords.max(axis=0) - coords.min(axis=0))
        kmeans = sklearn.cluster.KMeans(n_clusters=self.cluster_num, random_state=42)
        kmeans.fit(coords_normalized)

        # For each cluster centre, pick the nearest junction as a monitoring point
        for center in kmeans.cluster_centers_:
            distances = np.sum((coords_normalized - center) ** 2, axis=1)
            nearest_node = self.wn.junction_name_list[np.argmin(distances)]
            self.monitor_nodes.append(nearest_node)

        return self.monitor_nodes

    def visualize_network(self):
        """Visualize the network with the selected monitoring points."""
        ax = wntr.graphics.plot_network(
            self.wn,
            node_attribute=self.monitor_nodes,
            node_size=30,
            title='Optimal sensor',
        )
        plt.show()


def kmeans_sensor_placement(name: str, sensor_num: int, min_diameter: int) -> list:
    # min_diameter is accepted for interface compatibility but is not used here
    inp_name = f'./db_inp/{name}.db.inp'
    wn = wntr.network.WaterNetworkModel(inp_name)
    wn_cluster = QD_KMeans(wn, sensor_num)

    # Select monitoring points
    sensor_ids = wn_cluster.select_monitoring_points()

    # wn_cluster.visualize_network()

    return sensor_ids


if __name__ == "__main__":
    # sensorindex = get_ID(name='suzhouhe_2024_cloud_0817', sensor_num=30, min_diameter=500)
    sensorindex = kmeans_sensor_placement(name='szh', sensor_num=50, min_diameter=300)
    print(sensorindex)
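A minimal usage sketch for QD_KMeans. The path "Net3.inp" is a placeholder assumption; substitute any EPANET .inp file available locally.

import wntr

from app.algorithms.api_ex.kmeans_sensor import QD_KMeans

# Load a network model and cluster its junction coordinates into 5 groups,
# taking the junction nearest to each cluster centre as a monitoring point.
wn = wntr.network.WaterNetworkModel("Net3.inp")
placer = QD_KMeans(wn, num_monitors=5)
monitor_nodes = placer.select_monitoring_points()
print(monitor_nodes)

# Optional: plot the network with the selected nodes highlighted
# placer.visualize_network()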
BIN
app/algorithms/api_ex/model/my_survival_forest_model_quxi.zip
Normal file
Binary file not shown.
142
app/algorithms/api_ex/pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,142 @@
import os
import joblib
import pandas as pd
import matplotlib.pyplot as plt


class PipelineHealthAnalyzer:
    """
    Pipeline health analyzer that uses a random survival forest model to predict
    pipe survival probabilities.

    The class wraps model loading and prediction so it can be reused in other projects.
    Predictions are based on 4 features: material, diameter, flow velocity and pressure.

    Required dependencies: joblib, pandas, numpy, scikit-survival, matplotlib.
    """

    def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
        """
        Initialize the analyzer and load the pre-trained random survival forest model.

        :param model_path: path to the model file (defaults to the relative path
            'model/my_survival_forest_model_quxi.joblib').
        :raises FileNotFoundError: if the model file does not exist.
        :raises Exception: if loading the model fails.
        """
        # Make sure the model directory exists
        model_dir = os.path.dirname(model_path)
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件未找到: {model_path}")

        try:
            self.rsf = joblib.load(model_path)
            self.features = [
                "Material",
                "Diameter",
                "Flow Velocity",
                "Pressure",  # 'Temperature', 'Precipitation',
                # 'Location', 'Structural Defects', 'Functional Defects'
            ]
        except Exception as e:
            raise Exception(f"加载模型时出错: {str(e)}")

    def predict_survival(self, data: pd.DataFrame) -> list:
        """
        Predict survival functions for the input data.

        :param data: pandas DataFrame containing the 4 required feature columns;
            values should be numeric or convertible to numeric.
        :return: list of survival functions, one per sample (each exposes time
            points x and survival probabilities y).
        :raises ValueError: if required features are missing or the data is malformed.
        """
        # Check that the required features are present
        missing_features = [feat for feat in self.features if feat not in data.columns]
        if missing_features:
            raise ValueError(f"数据缺少必需特征: {missing_features}")

        # Extract the feature data
        try:
            x_test = data[self.features].astype(float)  # ensure numeric dtype
        except ValueError as e:
            raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")

        # Run the prediction
        survival_functions = self.rsf.predict_survival_function(x_test)
        return list(survival_functions)

    def plot_survival(
        self, survival_functions: list, save_path: str = None, show_plot: bool = True
    ):
        """
        Visualize the survival functions as a survival-probability chart.

        :param survival_functions: list returned by predict_survival.
        :param save_path: optional path (.png) to save the chart; not saved if None.
        :param show_plot: whether to display the chart (in interactive environments).
        """
        plt.figure(figsize=(10, 6))
        for i, sf in enumerate(survival_functions):
            plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
        plt.xlabel("时间(年)")
        plt.ylabel("生存概率")
        plt.title("管道生存概率预测")
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print(f"图表已保存到: {save_path}")

        if show_plot:
            plt.show()
        else:
            plt.close()


# Usage notes
"""
Steps for using PipelineHealthAnalyzer in another project:

1. Install the dependencies (add to requirements.txt):
   joblib==1.5.0
   pandas==2.2.3
   numpy==2.0.2
   scikit-survival==0.23.1
   matplotlib==3.9.4

2. Import the class:
   from pipeline_health_analyzer import PipelineHealthAnalyzer

3. Initialize the analyzer (replace with the actual model path):
   analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')

4. Prepare the data (a pandas DataFrame with at least the 4 required feature
   columns; extra columns are ignored):
   import pandas as pd
   data = pd.DataFrame({
       'Material': [1, 2],  # example data
       'Diameter': [100, 150],
       'Flow Velocity': [1.5, 2.0],
       'Pressure': [50, 60],
       'Temperature': [20, 25],
       'Precipitation': [0.1, 0.2],
       'Location': [1, 2],
       'Structural Defects': [0, 1],
       'Functional Defects': [0, 0]
   })

5. Run the prediction:
   survival_funcs = analyzer.predict_survival(data)

6. Inspect the results (survival probability over time for each sample):
   for i, sf in enumerate(survival_funcs):
       print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")

7. Visualize (optional):
   analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')

Notes:
- The data must match the feature list and the feature values must be numeric.
- The model file has to be copied from the original project or retrained.
- To customize the features or model parameters, edit the features list or subclass this class.
"""
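A runnable sketch of the workflow documented above. The model path is an assumption: the committed artifact is a .zip, so the .joblib file must be extracted (or retrained) first.

import pandas as pd

from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer

analyzer = PipelineHealthAnalyzer(
    model_path="app/algorithms/api_ex/model/my_survival_forest_model_quxi.joblib"
)

# Two example pipes described by the 4 required features
pipes = pd.DataFrame({
    "Material": [1, 2],
    "Diameter": [100, 150],
    "Flow Velocity": [1.5, 2.0],
    "Pressure": [50, 60],
})

survival_funcs = analyzer.predict_survival(pipes)
for i, sf in enumerate(survival_funcs):
    print(f"pipe {i + 1}: first survival probabilities {sf.y[:3]}")

analyzer.plot_survival(survival_funcs, save_path="survival_plot.png", show_plot=False)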
1568
app/algorithms/online_Analysis.py
Normal file
File diff suppressed because it is too large
654
app/algorithms/sensitivity.py
Normal file
@@ -0,0 +1,654 @@
|
||||
# 改进灵敏度法
|
||||
import networkx
|
||||
import numpy as np
|
||||
import pandas
|
||||
import wntr
|
||||
import pandas as pd
|
||||
import copy
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from tjnetwork import *
|
||||
from matplotlib.lines import Line2D
|
||||
from sklearn.cluster import SpectralClustering
|
||||
import libpysal as ps
|
||||
from spopt.region import Skater
|
||||
from shapely.geometry import Point
|
||||
import geopandas as gpd
|
||||
from sklearn.metrics import pairwise_distances
|
||||
import project_info
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
"""
|
||||
获取管网模型的节点坐标
|
||||
:param wn: 由wntr生成的模型
|
||||
:return: 节点坐标
|
||||
"""
|
||||
# site: pandas.Series
|
||||
# index:节点名称(wn.node_name_list)
|
||||
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z))
|
||||
site = wn.query_node_attribute('coordinates')
|
||||
# Coor: pandas.Series
|
||||
# index:与site相同(节点名称)。
|
||||
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3]))
|
||||
Coor = site.apply(lambda x: np.array(x)) # 将节点坐标转换为numpy数组
|
||||
# x, y: list[float]
|
||||
x = [] # 存储所有节点的 x 坐标
|
||||
y = [] # 存储所有节点的 y 坐标
|
||||
for i in range(0, len(Coor)):
|
||||
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
|
||||
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
|
||||
# xy: dict[str, list], x、y 坐标的字典
|
||||
xy = {'x': x, 'y': y}
|
||||
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
|
||||
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
|
||||
return Coor_node
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step2: KMeans 聚类
|
||||
# 将节点用kmeans根据坐标分为k组,存入字典g
|
||||
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
|
||||
"""
|
||||
使用KMeans聚类,将节点坐标分组
|
||||
:param coor: 存储所有节点的坐标数据
|
||||
:param knum: 需要分成的聚类数
|
||||
:return: 聚类结果字典
|
||||
"""
|
||||
g = {}
|
||||
# estimator: sklearn.cluster.KMeans,KMeans 聚类模型
|
||||
estimator = KMeans(n_clusters=knum)
|
||||
estimator.fit(coor)
|
||||
# label_pred: numpy.ndarray(int),每个点的类别标签
|
||||
label_pred = estimator.labels_
|
||||
for i in range(0, knum):
|
||||
g[i] = coor[label_pred == i].index.tolist()
|
||||
return g
|
||||
|
||||
|
||||
def skater_partition(G, n_clusters):
|
||||
"""
|
||||
使用 SKATER 算法对输入的无向图 G 进行区域划分,
|
||||
保证每个划分区域在图论意义上是连通的,
|
||||
同时依据节点坐标的空间信息进行划分。
|
||||
|
||||
参数:
|
||||
G: networkx.Graph
|
||||
带有节点坐标属性(键为 'pos')的无向图。
|
||||
n_clusters: int
|
||||
希望划分的区域数量。
|
||||
|
||||
返回:
|
||||
groups: dict
|
||||
字典形式的聚类结果,键为区域编号,值为该区域内的节点列表。
|
||||
"""
|
||||
# 1. 获取所有节点坐标,假设每个节点都有 'pos' 属性
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
nodes = list(G.nodes())
|
||||
# 构造坐标数组:每行为 [x, y]
|
||||
coords = np.array([pos[node] for node in nodes])
|
||||
|
||||
# 2. 构造 GeoDataFrame:创建 DataFrame 并生成 geometry 列
|
||||
df = pd.DataFrame(coords, columns=['x', 'y'], index=nodes)
|
||||
# 利用 shapely 的 Point 构造空间位置
|
||||
df['geometry'] = df.apply(lambda row: Point(row['x'], row['y']), axis=1)
|
||||
gdf = gpd.GeoDataFrame(df, geometry='geometry')
|
||||
|
||||
# 3. 构造空间权重矩阵,使用 4 近邻方法(k=4,可根据实际情况调整)
|
||||
w = ps.weights.KNN.from_array(coords, k=4)
|
||||
w.transform = 'R'
|
||||
|
||||
# 4. 调用 SKATER:新版本 API 要求传入 gdf, w 以及 attrs_name(这里使用 'x' 和 'y' 作为属性)
|
||||
skater = Skater(gdf, w, attrs_name=['x', 'y'], n_clusters=n_clusters)
|
||||
skater.solve()
|
||||
|
||||
# 5. 获取聚类标签,构造成字典格式
|
||||
labels = skater.labels_
|
||||
groups = {}
|
||||
for label, node in zip(labels, nodes):
|
||||
groups.setdefault(label, []).append(node)
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
def spectral_partition(G, n_clusters):
|
||||
"""
|
||||
利用谱聚类算法对图 G 进行分区:
|
||||
1. 根据所有节点的空间坐标计算欧氏距离矩阵;
|
||||
2. 利用高斯核函数构造相似度矩阵;
|
||||
3. 使用 SpectralClustering 进行归一化割,返回分区结果。
|
||||
|
||||
参数:
|
||||
G: networkx.Graph
|
||||
每个节点需要有 'pos' 属性,其值为 (x, y) 坐标。
|
||||
n_clusters: int
|
||||
希望划分的聚类数目。
|
||||
|
||||
返回:
|
||||
groups: dict
|
||||
键为聚类标签,值为该聚类对应的节点列表。
|
||||
"""
|
||||
# 1. 获取节点空间坐标,注意保证每个节点都有 'pos' 属性
|
||||
pos_dict = nx.get_node_attributes(G, 'pos')
|
||||
nodes = list(G.nodes())
|
||||
coords = np.array([pos_dict[node] for node in nodes])
|
||||
|
||||
# 2. 计算节点之间的欧氏距离矩阵
|
||||
D = pairwise_distances(coords, metric='euclidean')
|
||||
|
||||
# 3. 计算 sigma 值:这里取所有距离的均值,当然也可以根据实际情况调整
|
||||
sigma = np.mean(D)
|
||||
|
||||
# 4. 构造相似度矩阵:使用高斯核函数
|
||||
# A(i, j) = exp( -d(i,j)^2 / (2*sigma^2) )
|
||||
A = np.exp(- (D ** 2) / (2 * sigma ** 2))
|
||||
|
||||
# 5. 使用谱聚类进行图分区
|
||||
clustering = SpectralClustering(n_clusters=n_clusters,
|
||||
affinity='precomputed',
|
||||
random_state=0)
|
||||
labels = clustering.fit_predict(A)
|
||||
|
||||
# 6. 构造字典形式的分区结果
|
||||
groups = {}
|
||||
for label, node in zip(labels, nodes):
|
||||
groups.setdefault(label, []).append(node)
|
||||
|
||||
return groups
|
||||
|
||||
# 2025/03/12
|
||||
# Step3: wn_func类,水力计算
|
||||
# wn_func 主要用于计算:
|
||||
# 水力距离(hydraulic length):即节点之间的水力阻力。
|
||||
# 灵敏度分析(sensitivity analysis):用于优化测压点的布置。
|
||||
# 一些与水力相关的函数,包括 CtoS:求水力距离,stafun:求状态函数F
|
||||
# # diff:求F对P的导数,返回灵敏度矩阵A
|
||||
# # sensitivity:返回灵敏度和总灵敏度
|
||||
class wn_func(object):
|
||||
|
||||
# Step3.1: 初始化
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel, min_diameter: int):
|
||||
"""
|
||||
获取管网模型信息
|
||||
:param wn: 由wntr生成的模型
|
||||
:param min_diameter: 安装的最小管径
|
||||
"""
|
||||
# self.results: wntr.sim.results.SimulationResults,仿真结果,包含压力、流量、水头等数据
|
||||
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
|
||||
self.wn = wn
|
||||
# self.q:pandas.DataFrame,管道流量,索引为时间步长,列为管道名称
|
||||
self.q = self.results.link['flowrate']
|
||||
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
|
||||
ReservoirIndex = wn.reservoir_name_list
|
||||
Tankindex = wn.tank_name_list
|
||||
# 删除水库节点,删除与直接水库相连的虚拟管道
|
||||
# self.pipes: list[str],所有管道的名称
|
||||
self.pipes = wn.pipe_name_list
|
||||
# self.nodes: list[str],所有节点的名称
|
||||
self.nodes = wn.node_name_list
|
||||
# self.coordinates:pandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
|
||||
self.coordinates = wn.query_node_attribute('coordinates')
|
||||
# allpumps / allvalves: list[str],所有泵/阀门名称列表
|
||||
allpumps = wn.pump_name_list
|
||||
allvalves = wn.valve_name_list
|
||||
# pumpstnode / pumpednode / valvestnode / valveednode: list[str],存储泵和阀门 起终点节点的名称
|
||||
pumpstnode = []
|
||||
pumpednode = []
|
||||
valvestnode = []
|
||||
valveednode = []
|
||||
# Reservoirpipe / Reservoirednode: list[str],记录与水库相关的管道和节点
|
||||
Reservoirpipe = []
|
||||
Reservoirednode = []
|
||||
for pump in allpumps:
|
||||
pumpstnode.append(wn.links[pump].start_node.name)
|
||||
pumpednode.append(wn.links[pump].end_node.name)
|
||||
for valve in allvalves:
|
||||
valvestnode.append(wn.links[valve].start_node.name)
|
||||
valveednode.append(wn.links[valve].end_node.name)
|
||||
for pipe in self.pipes:
|
||||
if wn.links[pipe].start_node.name in ReservoirIndex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].start_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].end_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].start_node.name)
|
||||
# 泵的起终点、tank、reservoir
|
||||
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
|
||||
self.delnodes = list(
|
||||
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
|
||||
# 泵、起终点为tank、reservoir的管道
|
||||
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
|
||||
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
|
||||
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
|
||||
# self.L: list[float],所有管道的长度(以米为单位)
|
||||
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
|
||||
self.n = len(self.nodes)
|
||||
self.m = len(self.pipes)
|
||||
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
|
||||
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
|
||||
##
|
||||
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
|
||||
|
||||
# === 改动新增部分:筛选管径小于 min_diameter 的管道节点 ===
|
||||
self.less_than_min_diameter_junction_list = []
|
||||
for pipe in self.pipes:
|
||||
diameter = wn.links[pipe].diameter
|
||||
if diameter < min_diameter:
|
||||
start_node = wn.links[pipe].start_node.name
|
||||
end_node = wn.links[pipe].end_node.name
|
||||
self.less_than_min_diameter_junction_list.extend([start_node, end_node])
|
||||
# 去重
|
||||
self.less_than_min_diameter_junction_list = list(set(self.less_than_min_diameter_junction_list))
|
||||
|
||||
# Step3.2: 计算水力距离
|
||||
def CtoS(self):
|
||||
"""
|
||||
计算水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# 水力距离:当行索引对应的节点为控制点时,列索引对应的节点距离控制点的(路径*水头损失)的最小值
|
||||
# nodes:list[str](节点名称)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
# pipes:list[str](管道名称)
|
||||
pipes = self.pipes
|
||||
wn = self.wn
|
||||
# n / m:int(节点数 / 管道数)
|
||||
n = self.n
|
||||
m = self.m
|
||||
s1 = [0] * m
|
||||
q = self.q
|
||||
L = self.L
|
||||
# H1:pandas.DataFrame,水头数据,索引为时间步长,列为节点名
|
||||
H1 = self.results.node['head'].T
|
||||
# hh:list[float],计算管道两端水头之差
|
||||
hh = []
|
||||
# 水头损失
|
||||
for p in pipes:
|
||||
h1 = self.wn.links[p].start_node.name
|
||||
h1 = H1.loc[str(h1)]
|
||||
h2 = self.wn.links[p].end_node.name
|
||||
h2 = H1.loc[str(h2)]
|
||||
hh.append(abs(h1 - h2))
|
||||
hh = np.array(hh)
|
||||
# headloss:pandas.DataFrame,管道水头损失矩阵
|
||||
headloss = pd.DataFrame(hh, index=pipes).T
|
||||
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
|
||||
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
|
||||
G = nx.DiGraph()
|
||||
for i in range(0, m):
|
||||
pipe = pipes[i]
|
||||
a = wn.links[pipe].start_node.name
|
||||
b = wn.links[pipe].end_node.name
|
||||
if q.loc[0, pipe] > 0:
|
||||
hf.loc[a, b] = headloss.loc[0, pipe]
|
||||
weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
|
||||
|
||||
else:
|
||||
hf.loc[b, a] = headloss.loc[0, pipe]
|
||||
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
|
||||
|
||||
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
|
||||
for a in nodes:
|
||||
if a in G.nodes:
|
||||
d = nx.shortest_path_length(G, source=a, weight='weight')
|
||||
for b in list(d.keys()):
|
||||
hydraulicL.loc[a, b] = d[b]
|
||||
|
||||
hydraulicL = hydraulicL.drop(self.delnodes)
|
||||
hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
|
||||
|
||||
# 求加权水力距离
|
||||
return hydraulicL, G
|
||||
|
||||
# Step3.3: 计算灵敏度矩阵
|
||||
# 获取关系矩阵
|
||||
def get_Conn(self):
|
||||
"""
|
||||
计算管网连接关系矩阵
|
||||
:return:
|
||||
"""
|
||||
m = self.wn.num_links
|
||||
n = self.wn.num_nodes
|
||||
p = self.wn.num_pumps
|
||||
v = self.wn.num_valves
|
||||
|
||||
self.nonjunc_index = []
|
||||
self.non_link_index = []
|
||||
for r in self.wn.reservoirs():
|
||||
self.nonjunc_index.append(r[0])
|
||||
for t in self.wn.tanks():
|
||||
self.nonjunc_index.append(t[0])
|
||||
# Conn:numpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
|
||||
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
|
||||
# NConn:numpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
|
||||
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
|
||||
# pipes:list[str],去除泵和阀门的管道列表
|
||||
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
|
||||
for pipe_name, pipe in pipes:
|
||||
start = self.wn.node_name_list.index(pipe.start_node_name)
|
||||
end = self.wn.node_name_list.index(pipe.end_node_name)
|
||||
p_index = self.wn.link_name_list.index(pipe_name)
|
||||
Conn[start, p_index] = -1
|
||||
Conn[end, p_index] = 1
|
||||
NConn[start, end] = 1
|
||||
NConn[end, start] = 1
|
||||
self.A = Conn
|
||||
link_name_list = [link for link in self.wn.link_name_list if
|
||||
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
|
||||
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
|
||||
self.A2 = self.A2.drop(self.delnodes)
|
||||
for pipe in self.delpipes:
|
||||
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
|
||||
self.A2 = self.A2.drop(columns=pipe)
|
||||
self.junc_list = self.A2.index
|
||||
self.A2 = np.mat(self.A2) # 节点管道关系
|
||||
self.A3 = NConn
|
||||
|
||||
def Jaco(self, hL: pandas.DataFrame):
|
||||
"""
|
||||
计算灵敏度矩阵(节点压力对粗糙度变化的响应)
|
||||
:param hL: 水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# global result
|
||||
# A:numpy.matrix, 节点-管道关系矩阵
|
||||
A = self.A2
|
||||
wn = self.wn
|
||||
|
||||
try:
|
||||
result = wntr.sim.EpanetSimulator(wn).run_sim()
|
||||
except EpanetException:
|
||||
pass
|
||||
finally:
|
||||
h = result.link['headloss'][self.pipes].values[0]
|
||||
q = result.link['flowrate'][self.pipes].values[0]
|
||||
l = self.wn.query_link_attribute('length')[self.pipes]
|
||||
C = self.wn.query_link_attribute('roughness')[self.pipes]
|
||||
# headloss:numpy.ndarray,水头损失数组
|
||||
headloss = np.array(h)
|
||||
# 调整流量方向
|
||||
for i in range(0, len(q)):
|
||||
if q[i] < 0:
|
||||
A[:, i] = -A[:, i]
|
||||
# q:numpy.ndarray,流量数组
|
||||
q = np.abs(q)
|
||||
# 两个灵敏度矩阵
|
||||
# B / S:numpy.matrix,灵敏度计算的中间矩阵
|
||||
B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
|
||||
S = np.mat(np.diag(q / C))
|
||||
# X:numpy.matrix, 灵敏度矩阵
|
||||
X = A * B * A.T
|
||||
try:
|
||||
det = np.linalg.det(X)
|
||||
except RuntimeError as e:
|
||||
sign, logdet = slogdet(X) # 防止溢出
|
||||
det = sign * np.exp(logdet)
|
||||
if det != 0:
|
||||
J_H_Cw = X.I * A * S
|
||||
# J_H_Q = -X.I
|
||||
J_q_Cw = S - B * A.T * X.I * A * S # 去掉了delnodes和delpipes
|
||||
# J_q_Q = B * A.T * X.I
|
||||
else: # 当X不可逆
|
||||
J_H_Cw = np.linalg.pinv(X) @ A @ S
|
||||
# J_H_Q = -np.linalg.pinv(X)
|
||||
J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
|
||||
# J_q_Q = B * A.T * np.linalg.pinv(X)
|
||||
|
||||
Sen_pressure = []
|
||||
S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist() # 修改为绝对值
|
||||
for ss in S_pressure:
|
||||
Sen_pressure.append(ss[0])
|
||||
# 求总灵敏度
|
||||
SS_pressure = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
|
||||
SS = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
|
||||
# SS[i,j]:节点nodes[i]的灵敏度*该节点到nodes[j]的水力距离
|
||||
return SS
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step4: 传感器布置优化
|
||||
# Sensorplacement
|
||||
# weight:分配权重
|
||||
# sensor:传感器布置的位置
|
||||
class Sensorplacement(wn_func):
|
||||
"""
|
||||
Sensorplacement 类继承了 wn_func 类,并且用于计算和优化传感器布置的位置。
|
||||
"""
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int, min_diameter: int):
|
||||
"""
|
||||
|
||||
:param wn: 由wntr生成的模型
|
||||
:param sensornum: 传感器的数量
|
||||
:param min_diameter: 安装的最小管径
|
||||
"""
|
||||
wn_func.__init__(self, wn, min_diameter=min_diameter)
|
||||
self.sensornum = sensornum
|
||||
|
||||
# 1.某个节点到所有节点的加权距离之和
|
||||
# 2.某个节点到该组内所有节点的加权距离之和
|
||||
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
|
||||
"""
|
||||
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
|
||||
:param SS: 灵敏度矩阵,每个节点的行和列代表不同节点,矩阵元素表示节点间的灵敏度。SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
|
||||
:param G: 加权图,表示管网的拓扑结构,每个节点通过管道连接。图的边的权重通常是根据水力距离或者流量等计算的
|
||||
:param group: 节点分组,字典的键是分组编号,值是该组的节点名称列表
|
||||
:return:
|
||||
"""
|
||||
# 传感器布置个数以及位置
|
||||
# W = self.weight()
|
||||
n = self.n - len(self.delnodes)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
for node in self.delnodes:
|
||||
nodes.remove(node)
|
||||
# sumSS:list[float],每个节点到其他节点的灵敏度之和。SS.iloc[i, :] 返回第 i 个节点与所有其他节点的灵敏度值,sum(SS.iloc[i, :]) 计算这些灵敏度值的总和。
|
||||
sumSS = []
|
||||
for i in range(0, n):
|
||||
sumSS.append(sum(SS.iloc[i, :]))
|
||||
# 一个整数范围,表示每个节点的索引,用作sumSS_ DataFrame的索引
|
||||
indices = range(0, n)
|
||||
# sumSS_:pandas.DataFrame,将 sumSS 转换成 DataFrame 格式,并且将节点的总灵敏度保存到 CSV 文件 sumSS_data.csv 中
|
||||
sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
|
||||
# sumSS_.to_csv('sumSS_data.csv') # 存储节点总灵敏度
|
||||
|
||||
# sumSS:pandas.DataFrame,sumSS 被转换为 DataFrame 类型,并且按总灵敏度(即灵敏度之和)降序排列。此时,sumSS 是按节点的灵敏度之和排序的 DataFrame
|
||||
sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
# sensorindex:list[str],用于存储根据灵敏度排序选出的传感器位置的节点名称,存储根据总灵敏度排序的节点列表,用于传感器布置
|
||||
sensorindex = []
|
||||
# sensorindex_2:list[str],用于存储每组内根据灵敏度排序选出的传感器位置的节点名称,存储每个组内根据灵敏度排序选择的传感器节点
|
||||
sensorindex_2 = []
|
||||
# group_S:dict[int, pandas.DataFrame],存储每个组内的灵敏度矩阵
|
||||
group_S = {}
|
||||
# group_sumSS:dict[int, list[float]],存储每个组内节点的总灵敏度,值为每个组内节点灵敏度之和的列表
|
||||
group_sumSS = {}
|
||||
|
||||
# 改动
|
||||
for i in range(0, len(group)):
|
||||
for node in self.delnodes:
|
||||
# 这里的group[i]是每个组的节点列表,代码首先去除已经被标记为删除的节点self.delnodes
|
||||
if node in group[i]:
|
||||
group[i].remove(node)
|
||||
group_S[i] = SS.loc[group[i], group[i]]
|
||||
# 对每个组内的节点,计算组内节点的总灵敏度(group_sumSS[i])。它将每个组内节点的灵敏度值相加,并且按灵敏度降序排序
|
||||
group_sumSS[i] = []
|
||||
for j in range(0, len(group[i])):
|
||||
group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
|
||||
group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
|
||||
group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
|
||||
for node in self.less_than_min_diameter_junction_list:
|
||||
# 这里的group_sumSS[i]是每个分组的灵敏度节点排序列表,去除已经被标记为删除的节点self.less_than_min_diameter_junction_list
|
||||
if node in group_sumSS[i]:
|
||||
group_sumSS[i].remove(node)
|
||||
pass
|
||||
|
||||
# 1.选sumSS最大的节点,然后把这个节点所在的那个组删掉,就可以不再从这个组选点。再重新排序选sumSS最大的;
|
||||
# 2.在每组内选group_sumSS最大的节点
|
||||
# 在这个循环中,首先选择灵敏度最高的节点Smaxnode并添加到sensorindex。然后根据灵敏度排序,删除已选的节点并继续选择下一个灵敏度最大的节点。这个过程用于选择传感器的位置
|
||||
sensornum = self.sensornum
|
||||
for i in range(0, sensornum):
|
||||
# Smaxnode:str,最大灵敏度节点,sumSS.index[0] 表示灵敏度最高的节点
|
||||
Smaxnode = sumSS.index[0]
|
||||
sensorindex.append(Smaxnode)
|
||||
sensorindex_2.append(group_sumSS[i].index[0])
|
||||
|
||||
for key, value in group.items():
|
||||
if Smaxnode in value:
|
||||
sumSS = sumSS.drop(index=group[key])
|
||||
continue
|
||||
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
|
||||
return sensorindex, sensorindex_2
|
||||
|
||||
|
||||
# 2025/03/13
|
||||
def get_ID(name: str, sensor_num: int, min_diameter: int) -> list[str]:
|
||||
"""
|
||||
获取布置测压点的坐标,初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
|
||||
:param name: 数据库名称
|
||||
:param sensor_num: 测压点数目
|
||||
:param min_diameter: 安装的最小管径
|
||||
:return: 测压点节点ID
|
||||
"""
|
||||
# inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
inp_file_real = f'./db_inp/{name}.db.inp'
|
||||
# sensornum:int,需要布置的传感器数量
|
||||
# sensornum = sensor_num
|
||||
# wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
|
||||
# sim_real:wntr.sim.EpanetSimulator,创建一个水力仿真器对象
|
||||
sim_real = wntr.sim.EpanetSimulator(wn_real)
|
||||
# results_real:wntr.sim.results.SimulationResults,运行仿真并返回结果
|
||||
results_real = sim_real.run_sim()
|
||||
|
||||
# real_C:list[float],包含所有管道粗糙度的列表
|
||||
real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
# wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
wn_fun1 = wn_func(wn_real, min_diameter=min_diameter)
|
||||
# nodes:list[str],管网的节点名称列表
|
||||
nodes = wn_fun1.nodes
|
||||
# delnodes:list[str],被删除的节点(如水库、泵、阀门连接的节点等)
|
||||
delnodes = wn_fun1.delnodes
|
||||
# Coor_node:pandas.DataFrame
|
||||
Coor_node = getCoor(wn_real)
|
||||
Coor_node = Coor_node.drop(wn_fun1.delnodes)
|
||||
nodes = [node for node in wn_fun1.nodes if node not in delnodes]
|
||||
# coordinates:pandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
|
||||
coordinates = wn_fun1.coordinates
|
||||
|
||||
# 随机产生监测点
|
||||
# junctionnum:int,nodes 的长度,表示节点的数量
|
||||
junctionnum = len(nodes)
|
||||
# random_numbers:list[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
|
||||
# random_numbers = random.sample(range(junctionnum), sensor_num)
|
||||
# for i in range(sensor_num):
|
||||
# # print(random_numbers[i])
|
||||
|
||||
wn_fun1.get_Conn()
|
||||
# hL:pandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
|
||||
# G:networkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
|
||||
hL, G = wn_fun1.CtoS()
|
||||
# SS:pandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
|
||||
SS = wn_fun1.Jaco(hL)
|
||||
# group:dict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
|
||||
|
||||
G1 = wn_real.to_graph()
|
||||
G1 = G1.to_undirected() # 变为无向图
|
||||
|
||||
group = kgroup(Coor_node, sensor_num)
|
||||
# group = skater_partition(G1, sensor_num)
|
||||
# group = spectral_partition(G1, sensor_num)
|
||||
|
||||
# print(group)
|
||||
# --------------------- 保存 group 数据 ---------------------
|
||||
# 将 group 数据转换为一个“长格式”的 DataFrame,
|
||||
# 每一行记录一个节点及其所属的分组
|
||||
# group_data = []
|
||||
# for group_id, node_list in group.items():
|
||||
# for node in node_list:
|
||||
# group_data.append({"Group": group_id, "Node": node})
|
||||
#
|
||||
# df_group = pd.DataFrame(group_data)
|
||||
#
|
||||
# # 保存为 Excel 文件,文件名为 "group.xlsx";index=False 表示不保存行索引
|
||||
# df_group.to_excel("group.xlsx", index=False)
|
||||
|
||||
# wn_fun:Sensorplacement(继承自wn_func)
|
||||
# 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
|
||||
wn_fun = Sensorplacement(wn_real, sensor_num, min_diameter=min_diameter)
|
||||
wn_fun.__dict__.update(wn_fun1.__dict__)
|
||||
# sensorindex:list[str],初始传感器布置位置的节点名称
|
||||
# sensorindex_2:list[str],根据分组选择的传感器位置
|
||||
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
|
||||
|
||||
|
||||
# 重新打开数据库
|
||||
# if is_project_open(name=name):
|
||||
# close_project(name=name)
|
||||
# open_project(name=name)
|
||||
# for node_id in sensorindex :
|
||||
# sensor_coord[node_id] = get_node_coord(name=name, node_id=node_id)
|
||||
# close_project(name=name)
|
||||
# print(sensor_coord)
|
||||
# # 分区画图
|
||||
# colorlist = ['lightpink', 'coral', 'rosybrown', 'olive', 'powderblue', 'lightskyblue', 'steelblue', 'peachpuff','brown','silver','indigo','lime','gold','violet','maroon','navy','teal','magenta','cyan',
|
||||
# 'burlywood', 'tan', 'slategrey', 'thistle', 'lightseagreen', 'lightgreen', 'red','blue','yellow','orange','purple','grey','green','pink','lightblue','beige','chartreuse','turquoise','lavender','fuchsia','coral']
|
||||
# G = wn_real.to_graph()
|
||||
# G = G.to_undirected() # 变为无向图
|
||||
# pos = nx.get_node_attributes(G, 'pos')
|
||||
# pass
|
||||
#
|
||||
# for i in range(0, sensor_num):
|
||||
# ax = plt.gca()
|
||||
# ax.set_title(inp_file_real + str(sensor_num))
|
||||
# nodes = nx.draw_networkx_nodes(G, pos, nodelist=group[i], node_color=colorlist[i], node_size=10)
|
||||
# nodes = nx.draw_networkx_nodes(G, pos,
|
||||
# nodelist=sensorindex_2, node_color='red', node_size=70, node_shape='*'
|
||||
# )
|
||||
# edges = nx.draw_networkx_edges(G, pos)
|
||||
# ax.spines['top'].set_visible(False)
|
||||
# ax.spines['right'].set_visible(False)
|
||||
# ax.spines['bottom'].set_visible(False)
|
||||
# ax.spines['left'].set_visible(False)
|
||||
# plt.savefig(inp_file_real + str(sensor_num) + ".png", dpi=300)
|
||||
# plt.show()
|
||||
#
|
||||
# wntr.graphics.plot_network(wn_real, node_attribute=sensorindex_2, node_size=50, node_labels=False,
|
||||
# title=inp_file_real + '_Projetion' + str(sensor_num))
|
||||
# plt.savefig(inp_file_real + '_S' + str(sensor_num) + ".png", dpi=300)
|
||||
# plt.show()
|
||||
return sensorindex
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sensorindex = get_ID(name=project_info.name, sensor_num=20, min_diameter=300)
|
||||
|
||||
print(sensorindex)
|
||||
# 将 sensor_coord 字典转换为 DataFrame,
|
||||
# 使用 orient='index' 表示字典的键作为 DataFrame 的行索引,
|
||||
# 数据中每个键对应的 value 是一个子字典,其键 'x' 和 'y' 成为 DataFrame 的列名
|
||||
# df_sensor_coord = pd.DataFrame.from_dict(sensor_coord, orient='index')
|
||||
#
|
||||
# # 将索引名称设为 'Node'
|
||||
# df_sensor_coord.index.name = 'Node'
|
||||
#
|
||||
# # 保存到 Excel 文件
|
||||
# df_sensor_coord.to_excel("sensor_coord.xlsx", index=True)
|
||||
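A usage sketch for get_ID from sensitivity.py above. It assumes ./db_inp/szh.db.inp exists and that the tjnetwork and project_info modules are importable; the name "szh" and the parameter values mirror those used elsewhere in this commit.

from app.algorithms.sensitivity import get_ID

# Place 20 pressure sensors, considering only pipes with diameter >= 300 mm
sensor_ids = get_ID(name="szh", sensor_num=20, min_diameter=300)
print(len(sensor_ids), sensor_ids[:5])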
557
app/algorithms/sensor_placement.py
Normal file
@@ -0,0 +1,557 @@
|
||||
# 改进灵敏度法
|
||||
import networkx
|
||||
import numpy as np
|
||||
import pandas
|
||||
import wntr
|
||||
import pandas as pd
|
||||
import copy
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from sklearn.cluster import KMeans
|
||||
from wntr.epanet.toolkit import EpanetException
|
||||
from numpy.linalg import slogdet
|
||||
import random
|
||||
from tjnetwork import *
|
||||
import project_info
|
||||
|
||||
# 2025/03/12
|
||||
# Step1: 获取节点坐标
|
||||
def getCoor(wn: wntr.network.WaterNetworkModel) -> pandas.DataFrame:
|
||||
"""
|
||||
获取管网模型的节点坐标
|
||||
:param wn: 由wntr生成的模型
|
||||
:return: 节点坐标
|
||||
"""
|
||||
# site: pandas.Series
|
||||
# index:节点名称(wn.node_name_list)
|
||||
# values:每个节点的坐标,格式为 tuple(如 (x, y) 或 (x, y, z))
|
||||
site = wn.query_node_attribute('coordinates')
|
||||
# Coor: pandas.Series
|
||||
# index:与site相同(节点名称)。
|
||||
# values:坐标转换为numpy.ndarray(如array([10.5, 20.3]))
|
||||
Coor = site.apply(lambda x: np.array(x)) # 将节点坐标转换为numpy数组
|
||||
# x, y: list[float]
|
||||
x = [] # 存储所有节点的 x 坐标
|
||||
y = [] # 存储所有节点的 y 坐标
|
||||
for i in range(0, len(Coor)):
|
||||
x.append(Coor.values[i][0]) # 将 x 坐标存入 x 列表。
|
||||
y.append(Coor.values[i][1]) # 将 y 坐标存入 y 列表
|
||||
# xy: dict[str, list], x、y 坐标的字典
|
||||
xy = {'x': x, 'y': y}
|
||||
# Coor_node: pandas.DataFrame, 存储节点 x, y 坐标的 DataFrame
|
||||
Coor_node = pd.DataFrame(xy, index=wn.node_name_list, columns=['x', 'y'])
|
||||
return Coor_node
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step2: KMeans 聚类
|
||||
# 将节点用kmeans根据坐标分为k组,存入字典g
|
||||
def kgroup(coor: pandas.DataFrame, knum: int) -> dict[int, list[str]]:
|
||||
"""
|
||||
使用KMeans聚类,将节点坐标分组
|
||||
:param coor: 存储所有节点的坐标数据
|
||||
:param knum: 需要分成的聚类数
|
||||
:return: 聚类结果字典
|
||||
"""
|
||||
g = {}
|
||||
# estimator: sklearn.cluster.KMeans,KMeans 聚类模型
|
||||
estimator = KMeans(n_clusters=knum)
|
||||
estimator.fit(coor)
|
||||
# label_pred: numpy.ndarray(int),每个点的类别标签
|
||||
label_pred = estimator.labels_
|
||||
for i in range(0, knum):
|
||||
g[i] = coor[label_pred == i].index.tolist()
|
||||
return g
|
||||
|
||||
|
||||
# 2025/03/12
|
||||
# Step3: wn_func类,水力计算
|
||||
# wn_func 主要用于计算:
|
||||
# 水力距离(hydraulic length):即节点之间的水力阻力。
|
||||
# 灵敏度分析(sensitivity analysis):用于优化测压点的布置。
|
||||
# 一些与水力相关的函数,包括 CtoS:求水力距离,stafun:求状态函数F
|
||||
# # diff:求F对P的导数,返回灵敏度矩阵A
|
||||
# # sensitivity:返回灵敏度和总灵敏度
|
||||
class wn_func(object):
|
||||
|
||||
# Step3.1: 初始化
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel):
|
||||
"""
|
||||
获取管网模型信息
|
||||
:param wn: 由wntr生成的模型
|
||||
"""
|
||||
# self.results: wntr.sim.results.SimulationResults,仿真结果,包含压力、流量、水头等数据
|
||||
self.results = wntr.sim.EpanetSimulator(wn).run_sim() # 存储运行结果
|
||||
self.wn = wn
|
||||
# self.q:pandas.DataFrame,管道流量,索引为时间步长,列为管道名称
|
||||
self.q = self.results.link['flowrate']
|
||||
# ReservoirIndex / Tankindex: list[str],水库 / 水箱节点名称列表
|
||||
ReservoirIndex = wn.reservoir_name_list
|
||||
Tankindex = wn.tank_name_list
|
||||
# 删除水库节点,删除与直接水库相连的虚拟管道
|
||||
# self.pipes: list[str],所有管道的名称
|
||||
self.pipes = wn.pipe_name_list
|
||||
# self.nodes: list[str],所有节点的名称
|
||||
self.nodes = wn.node_name_list
|
||||
# self.coordinates:pandas.Series,节点坐标,索引为节点名,值为 (x, y) 坐标的 tuple
|
||||
self.coordinates = wn.query_node_attribute('coordinates')
|
||||
# allpumps / allvalves: list[str],所有泵/阀门名称列表
|
||||
allpumps = wn.pump_name_list
|
||||
allvalves = wn.valve_name_list
|
||||
# pumpstnode / pumpednode / valvestnode / valveednode: list[str],存储泵和阀门 起终点节点的名称
|
||||
pumpstnode = []
|
||||
pumpednode = []
|
||||
valvestnode = []
|
||||
valveednode = []
|
||||
# Reservoirpipe / Reservoirednode: list[str],记录与水库相关的管道和节点
|
||||
Reservoirpipe = []
|
||||
Reservoirednode = []
|
||||
for pump in allpumps:
|
||||
pumpstnode.append(wn.links[pump].start_node.name)
|
||||
pumpednode.append(wn.links[pump].end_node.name)
|
||||
for valve in allvalves:
|
||||
valvestnode.append(wn.links[valve].start_node.name)
|
||||
valveednode.append(wn.links[valve].end_node.name)
|
||||
for pipe in self.pipes:
|
||||
if wn.links[pipe].start_node.name in ReservoirIndex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].start_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].end_node.name)
|
||||
if wn.links[pipe].end_node.name in Tankindex:
|
||||
Reservoirpipe.append(pipe)
|
||||
Reservoirednode.append(wn.links[pipe].start_node.name)
|
||||
# 泵的起终点、tank、reservoir
|
||||
# self.delnodes: list[str],需要删除的节点(包括水库、泵、阀门连接的节点)
|
||||
self.delnodes = list(
|
||||
set(ReservoirIndex).union(Tankindex, pumpstnode, pumpednode, valvestnode, valveednode, Reservoirednode))
|
||||
# 泵、起终点为tank、reservoir的管道
|
||||
# self.delpipes: list[str],需要删除的管道(包括水库、泵、阀门连接的管道)
|
||||
self.delpipes = list(set(wn.pump_name_list).union(wn.valve_name_list).union(Reservoirpipe))
|
||||
self.pipes = [pipe for pipe in wn.pipe_name_list if pipe not in self.delpipes]
|
||||
# self.L: list[float],所有管道的长度(以米为单位)
|
||||
self.L = wn.query_link_attribute('length')[self.pipes].tolist()
|
||||
self.n = len(self.nodes)
|
||||
self.m = len(self.pipes)
|
||||
# self.unit_headloss: list[float],单位水头损失(headloss 数据的第一行,单位:米/km)
|
||||
self.unit_headloss = self.results.link['headloss'].iloc[0, :].tolist()
|
||||
##
|
||||
self.delnodes1 = list(set(ReservoirIndex).union(Tankindex))
|
||||
|
||||
# Step3.2: 计算水力距离
|
||||
def CtoS(self):
|
||||
"""
|
||||
计算水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# 水力距离:当行索引对应的节点为控制点时,列索引对应的节点距离控制点的(路径*水头损失)的最小值
|
||||
# nodes:list[str](节点名称)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
# pipes:list[str](管道名称)
|
||||
pipes = self.pipes
|
||||
wn = self.wn
|
||||
# n / m:int(节点数 / 管道数)
|
||||
n = self.n
|
||||
m = self.m
|
||||
s1 = [0] * m
|
||||
q = self.q
|
||||
L = self.L
|
||||
# H1:pandas.DataFrame,水头数据,索引为时间步长,列为节点名
|
||||
H1 = self.results.node['head'].T
|
||||
# hh:list[float],计算管道两端水头之差
|
||||
hh = []
|
||||
# 水头损失
|
||||
for p in pipes:
|
||||
h1 = self.wn.links[p].start_node.name
|
||||
h1 = H1.loc[str(h1)]
|
||||
h2 = self.wn.links[p].end_node.name
|
||||
h2 = H1.loc[str(h2)]
|
||||
hh.append(abs(h1 - h2))
|
||||
hh = np.array(hh)
|
||||
# headloss:pandas.DataFrame,管道水头损失矩阵
|
||||
headloss = pd.DataFrame(hh, index=pipes).T
|
||||
# s1:管道阻力系数,s2:将管道阻力系数与管道的起始节点和终止节点对应
|
||||
hf = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
weightL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
# s2为对应管道起始节点与终止节点的粗糙度系数矩阵,index代表起始节点,columns代表终止节点
|
||||
G = nx.DiGraph()
|
||||
for i in range(0, m):
|
||||
pipe = pipes[i]
|
||||
a = wn.links[pipe].start_node.name
|
||||
b = wn.links[pipe].end_node.name
|
||||
if q.loc[0, pipe] > 0:
|
||||
hf.loc[a, b] = headloss.loc[0, pipe]
|
||||
weightL.loc[a, b] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(a, b, weightL.loc[a, b])])
|
||||
|
||||
else:
|
||||
hf.loc[b, a] = headloss.loc[0, pipe]
|
||||
weightL.loc[b, a] = headloss.loc[0, pipe] * L[i]
|
||||
G.add_weighted_edges_from([(b, a, weightL.loc[b, a])])
|
||||
|
||||
hydraulicL = pd.DataFrame(np.array([0] * (n ** 2)).reshape(n, n), index=nodes, columns=nodes, dtype=float)
|
||||
|
||||
for a in nodes:
|
||||
if a in G.nodes:
|
||||
d = nx.shortest_path_length(G, source=a, weight='weight')
|
||||
for b in list(d.keys()):
|
||||
hydraulicL.loc[a, b] = d[b]
|
||||
|
||||
hydraulicL = hydraulicL.drop(self.delnodes)
|
||||
hydraulicL = hydraulicL.drop(self.delnodes, axis=1)
|
||||
|
||||
# 求加权水力距离
|
||||
return hydraulicL, G
|
||||
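# 计算思路示意(假设性示例,非本提交代码):CtoS 的核心是以 headloss*L 作为边权的有向图最短路,
# 下面用一个虚构的三节点小图说明 nx.shortest_path_length 的用法,权重数值均为假设。
#
#   import networkx as nx
#   demo = nx.DiGraph()
#   demo.add_weighted_edges_from([("A", "B", 1.2), ("B", "C", 0.8), ("A", "C", 2.5)])
#   dist = nx.shortest_path_length(demo, source="A", weight="weight")
#   # dist == {"A": 0, "B": 1.2, "C": 2.0},即 A->C 经 A-B-C 的加权距离更小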
|
||||
# Step3.3: 计算灵敏度矩阵
|
||||
# 获取关系矩阵
|
||||
def get_Conn(self):
|
||||
"""
|
||||
计算管网连接关系矩阵
|
||||
:return:
|
||||
"""
|
||||
m = self.wn.num_links
|
||||
n = self.wn.num_nodes
|
||||
p = self.wn.num_pumps
|
||||
v = self.wn.num_valves
|
||||
|
||||
self.nonjunc_index = []
|
||||
self.non_link_index = []
|
||||
for r in self.wn.reservoirs():
|
||||
self.nonjunc_index.append(r[0])
|
||||
for t in self.wn.tanks():
|
||||
self.nonjunc_index.append(t[0])
|
||||
# Conn:numpy.matrix,节点-管道连接矩阵,起点 -1,终点 1
|
||||
Conn = np.mat(np.zeros([n, m - p - v])) # 节点和管道的关系矩阵,行为节点,列为管道,起点为-1,终点为1
|
||||
# NConn:numpy.matrix,节点-节点连接矩阵,有管道相连的地方设为 1
|
||||
NConn = np.mat(np.zeros([n, n])) # 节点之间的关系,之间有管道为1,反之为0
|
||||
# pipes:list[str],去除泵和阀门的管道列表
|
||||
pipes = [pipe for pipe in self.wn.pipes() if pipe not in self.wn.pumps() and pipe not in self.wn.valves()]
|
||||
for pipe_name, pipe in pipes:
|
||||
start = self.wn.node_name_list.index(pipe.start_node_name)
|
||||
end = self.wn.node_name_list.index(pipe.end_node_name)
|
||||
p_index = self.wn.link_name_list.index(pipe_name)
|
||||
Conn[start, p_index] = -1
|
||||
Conn[end, p_index] = 1
|
||||
NConn[start, end] = 1
|
||||
NConn[end, start] = 1
|
||||
self.A = Conn
|
||||
link_name_list = [link for link in self.wn.link_name_list if
|
||||
link not in self.wn.pump_name_list and link not in self.wn.valve_name_list]
|
||||
self.A2 = pd.DataFrame(self.A, index=self.wn.node_name_list, columns=link_name_list)
|
||||
self.A2 = self.A2.drop(self.delnodes)
|
||||
for pipe in self.delpipes:
|
||||
if pipe not in self.wn.pump_name_list and pipe not in self.wn.valve_name_list:
|
||||
self.A2 = self.A2.drop(columns=pipe)
|
||||
self.junc_list = self.A2.index
|
||||
self.A2 = np.mat(self.A2) # 节点管道关系
|
||||
self.A3 = NConn
|
||||
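# 关系矩阵示意(假设性示例,非本提交代码):对一条从节点 N1 流向节点 N2 的管道,
# Conn 在 N1 行记 -1、在 N2 行记 +1,即常见的节点-管段关联(incidence)矩阵。
#
#   import numpy as np
#   # 3 个节点(N1, N2, N3)、2 条管道(P1: N1->N2, P2: N2->N3),数据为虚构
#   Conn_demo = np.array([[-1,  0],
#                         [ 1, -1],
#                         [ 0,  1]])
#   # 每列恰有一个 -1(起点)和一个 +1(终点)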
|
||||
def Jaco(self, hL: pandas.DataFrame):
|
||||
"""
|
||||
计算灵敏度矩阵(节点压力对粗糙度变化的响应)
|
||||
:param hL: 水力距离矩阵
|
||||
:return:
|
||||
"""
|
||||
# global result
|
||||
# A:numpy.matrix, 节点-管道关系矩阵
|
||||
A = self.A2
|
||||
wn = self.wn
|
||||
|
||||
try:
|
||||
result = wntr.sim.EpanetSimulator(wn).run_sim()
|
||||
except EpanetException:
|
||||
raise  # 仿真失败时直接抛出,避免下方引用未定义的 result
|
||||
else:
|
||||
h = result.link['headloss'][self.pipes].values[0]
|
||||
q = result.link['flowrate'][self.pipes].values[0]
|
||||
l = self.wn.query_link_attribute('length')[self.pipes]
|
||||
C = self.wn.query_link_attribute('roughness')[self.pipes]
|
||||
# headloss:numpy.ndarray,水头损失数组
|
||||
headloss = np.array(h)
|
||||
# 调整流量方向
|
||||
for i in range(0, len(q)):
|
||||
if q[i] < 0:
|
||||
A[:, i] = -A[:, i]
|
||||
# q:numpy.ndarray,流量数组
|
||||
q = np.abs(q)
|
||||
# 两个灵敏度矩阵
|
||||
# B / S:numpy.matrix,灵敏度计算的中间矩阵
|
||||
B = np.mat(np.diag(q / ((1.852 * headloss) + 1e-10)))
|
||||
S = np.mat(np.diag(q / C))
|
||||
# X:numpy.matrix, 灵敏度矩阵
|
||||
X = A * B * A.T
|
||||
try:
|
||||
det = np.linalg.det(X)
|
||||
except RuntimeError as e:
|
||||
sign, logdet = np.linalg.slogdet(X)  # 防止溢出
|
||||
det = sign * np.exp(logdet)
|
||||
if det != 0:
|
||||
J_H_Cw = X.I * A * S
|
||||
# J_H_Q = -X.I
|
||||
J_q_Cw = S - B * A.T * X.I * A * S # 去掉了delnodes和delpipes
|
||||
# J_q_Q = B * A.T * X.I
|
||||
else: # 当X不可逆
|
||||
J_H_Cw = np.linalg.pinv(X) @ A @ S
|
||||
# J_H_Q = -np.linalg.pinv(X)
|
||||
J_q_Cw = S - B * A.T * np.linalg.pinv(X) * A * S
|
||||
# J_q_Q = B * A.T * np.linalg.pinv(X)
|
||||
|
||||
Sen_pressure = []
|
||||
S_pressure = np.abs(J_H_Cw).sum(axis=1).tolist() # 修改为绝对值
|
||||
for ss in S_pressure:
|
||||
Sen_pressure.append(ss[0])
|
||||
# 求总灵敏度
|
||||
SS_pressure = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS_pressure.iloc[i, :] = SS_pressure.iloc[i, :] * Sen_pressure[i]
|
||||
SS = copy.deepcopy(hL)
|
||||
for i in range(0, len(Sen_pressure)):
|
||||
SS.iloc[i, :] = SS.iloc[i, :] * Sen_pressure[i]
|
||||
# SS[i,j]:节点nodes[i]的灵敏度*该节点到nodes[j]的水力距离
|
||||
return SS
|
||||
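# 公式示意(按上方代码整理,仅作说明):设 A 为节点-管道关系矩阵,
# B = diag(q / (1.852 * hf)),S = diag(q / C),则
#   J_H_Cw = inv(A B A^T) A S                  # 节点压力对粗糙度 C 的灵敏度
#   J_q_Cw = S - B A^T inv(A B A^T) A S        # 管道流量对粗糙度 C 的灵敏度
# 返回的 SS[i, j] = (节点 i 的总压力灵敏度) * (节点 i 到节点 j 的水力距离)。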
|
||||
|
||||
# 2025/03/12
|
||||
# Step4: 传感器布置优化
|
||||
# Sensorplacement
|
||||
# weight:分配权重
|
||||
# sensor:传感器布置的位置
|
||||
class Sensorplacement(wn_func):
|
||||
"""
|
||||
Sensorplacement 类继承了 wn_func 类,并且用于计算和优化传感器布置的位置。
|
||||
"""
|
||||
def __init__(self, wn: wntr.network.WaterNetworkModel, sensornum: int):
|
||||
"""
|
||||
|
||||
:param wn: 由wntr生成的模型
|
||||
:param sensornum: 传感器的数量
|
||||
"""
|
||||
wn_func.__init__(self, wn)
|
||||
self.sensornum = sensornum
|
||||
|
||||
# 1.某个节点到所有节点的加权距离之和
|
||||
# 2.某个节点到该组内所有节点的加权距离之和
|
||||
def sensor(self, SS: pandas.DataFrame, G: networkx.Graph, group: dict[int, list[str]]):
|
||||
"""
|
||||
sensor 方法是用来根据灵敏度矩阵 SS 和加权图 G 来确定传感器布置位置的
|
||||
:param SS: 灵敏度矩阵,每个节点的行和列代表不同节点,矩阵元素表示节点间的灵敏度。SS.iloc[i, :] 表示第 i 行对应节点 i 到所有其他节点的灵敏度
|
||||
:param G: 加权图,表示管网的拓扑结构,每个节点通过管道连接。图的边的权重通常是根据水力距离或者流量等计算的
|
||||
:param group: 节点分组,字典的键是分组编号,值是该组的节点名称列表
|
||||
:return:
|
||||
"""
|
||||
# 传感器布置个数以及位置
|
||||
# W = self.weight()
|
||||
n = self.n - len(self.delnodes)
|
||||
nodes = copy.deepcopy(self.nodes)
|
||||
for node in self.delnodes:
|
||||
nodes.remove(node)
|
||||
# sumSS:list[float],每个节点到其他节点的灵敏度之和。SS.iloc[i, :] 返回第 i 个节点与所有其他节点的灵敏度值,sum(SS.iloc[i, :]) 计算这些灵敏度值的总和。
|
||||
sumSS = []
|
||||
for i in range(0, n):
|
||||
sumSS.append(sum(SS.iloc[i, :]))
|
||||
# 一个整数范围,表示每个节点的索引,用作sumSS_ DataFrame的索引
|
||||
indices = range(0, n)
|
||||
# sumSS_:pandas.DataFrame,将 sumSS 转换成 DataFrame 格式,并且将节点的总灵敏度保存到 CSV 文件 sumSS_data.csv 中
|
||||
sumSS_ = pd.DataFrame(np.array(sumSS), index=indices)
|
||||
sumSS_.to_csv('sumSS_data.csv') # 存储节点总灵敏度
|
||||
# sumSS:pandas.DataFrame,sumSS 被转换为 DataFrame 类型,并且按总灵敏度(即灵敏度之和)降序排列。此时,sumSS 是按节点的灵敏度之和排序的 DataFrame
|
||||
sumSS = pd.DataFrame(np.array(sumSS), index=nodes)
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
# sensorindex:list[str],用于存储根据灵敏度排序选出的传感器位置的节点名称,存储根据总灵敏度排序的节点列表,用于传感器布置
|
||||
sensorindex = []
|
||||
# sensorindex_2:list[str],用于存储每组内根据灵敏度排序选出的传感器位置的节点名称,存储每个组内根据灵敏度排序选择的传感器节点
|
||||
sensorindex_2 = []
|
||||
# group_S:dict[int, pandas.DataFrame],存储每个组内的灵敏度矩阵
|
||||
group_S = {}
|
||||
# group_sumSS:dict[int, list[float]],存储每个组内节点的总灵敏度,值为每个组内节点灵敏度之和的列表
|
||||
group_sumSS = {}
|
||||
for i in range(0, len(group)):
|
||||
for node in self.delnodes:
|
||||
# 这里的group[i]是每个组的节点列表,代码首先去除已经被标记为删除的节点self.delnodes
|
||||
if node in group[i]:
|
||||
group[i].remove(node)
|
||||
group_S[i] = SS.loc[group[i], group[i]]
|
||||
# 对每个组内的节点,计算组内节点的总灵敏度(group_sumSS[i])。它将每个组内节点的灵敏度值相加,并且按灵敏度降序排序
|
||||
group_sumSS[i] = []
|
||||
for j in range(0, len(group[i])):
|
||||
group_sumSS[i].append(sum(group_S[i].iloc[j, :]))
|
||||
group_sumSS[i] = pd.DataFrame(np.array(group_sumSS[i]), index=group[i])
|
||||
group_sumSS[i] = group_sumSS[i].sort_values(by=[0], ascending=[False])
|
||||
pass
|
||||
|
||||
# 1.选sumSS最大的节点,然后把这个节点所在的那个组删掉,就可以不再从这个组选点。再重新排序选sumSS最大的;
|
||||
# 2.在每组内选group_sumSS最大的节点
|
||||
# 在这个循环中,首先选择灵敏度最高的节点Smaxnode并添加到sensorindex。然后根据灵敏度排序,删除已选的节点并继续选择下一个灵敏度最大的节点。这个过程用于选择传感器的位置
|
||||
sensornum = self.sensornum
|
||||
for i in range(0, sensornum):
|
||||
# Smaxnode:str,最大灵敏度节点,sumSS.index[0] 表示灵敏度最高的节点
|
||||
Smaxnode = sumSS.index[0]
|
||||
sensorindex.append(Smaxnode)
|
||||
sensorindex_2.append(group_sumSS[i].index[0])
|
||||
|
||||
for key, value in group.items():
|
||||
if Smaxnode in value:
|
||||
sumSS = sumSS.drop(index=group[key])
|
||||
break
|
||||
|
||||
sumSS = sumSS.sort_values(by=[0], ascending=[False])
|
||||
|
||||
return sensorindex, sensorindex_2
|
||||
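# 选点逻辑示意(假设性示例,非本提交代码):每轮取剩余节点中总灵敏度最大者,
# 随后剔除其所在分组的全部节点,使 sensornum 个测点分散在不同分组;示例数据均为虚构。
#
#   sumSS_demo = {"J1": 9.0, "J2": 8.5, "J3": 4.0, "J4": 3.0}
#   groups_demo = {0: ["J1", "J2"], 1: ["J3", "J4"]}
#   chosen = []
#   remaining = dict(sumSS_demo)
#   for _ in range(2):
#       best = max(remaining, key=remaining.get)
#       chosen.append(best)
#       gid = next(k for k, v in groups_demo.items() if best in v)
#       for node in groups_demo[gid]:
#           remaining.pop(node, None)
#   # chosen == ["J1", "J3"]:J2 因与 J1 同组被整体剔除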
|
||||
|
||||
# 2025/03/13
|
||||
def get_sensor_coord(name: str, sensor_num: int) -> dict[str, float]:
|
||||
"""
|
||||
获取布置测压点的坐标,初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
|
||||
:param name: 数据库名称
|
||||
:param sensor_num: 测压点数目
|
||||
:return: 测压点坐标字典
|
||||
"""
|
||||
# inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
inp_file_real = f'./db_inp/{name}.db.inp'
|
||||
# sensornum:int,需要布置的传感器数量
|
||||
# sensornum = sensor_num
|
||||
# wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
|
||||
# sim_real:wntr.sim.EpanetSimulator,创建一个水力仿真器对象
|
||||
sim_real = wntr.sim.EpanetSimulator(wn_real)
|
||||
# results_real:wntr.sim.results.SimulationResults,运行仿真并返回结果
|
||||
results_real = sim_real.run_sim()
|
||||
|
||||
# real_C:list[float],包含所有管道粗糙度的列表
|
||||
real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
# wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
wn_fun1 = wn_func(wn_real)
|
||||
# nodes:list[str],管网的节点名称列表
|
||||
nodes = wn_fun1.nodes
|
||||
# delnodes:list[str],被删除的节点(如水库、泵、阀门连接的节点等)
|
||||
delnodes = wn_fun1.delnodes
|
||||
# Coor_node:pandas.DataFrame
|
||||
Coor_node = getCoor(wn_real)
|
||||
Coor_node = Coor_node.drop(wn_fun1.delnodes)
|
||||
nodes = [node for node in wn_fun1.nodes if node not in delnodes]
|
||||
# coordinates:pandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
|
||||
coordinates = wn_fun1.coordinates
|
||||
|
||||
# 随机产生监测点
|
||||
# junctionnum:int,nodes 的长度,表示节点的数量
|
||||
junctionnum = len(nodes)
|
||||
# random_numbers:list[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
|
||||
# random_numbers = random.sample(range(junctionnum), sensor_num)
|
||||
# for i in range(sensor_num):
|
||||
# # print(random_numbers[i])
|
||||
|
||||
wn_fun1.get_Conn()
|
||||
# hL:pandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
|
||||
# G:networkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
|
||||
hL, G = wn_fun1.CtoS()
|
||||
# SS:pandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
|
||||
SS = wn_fun1.Jaco(hL)
|
||||
# group:dict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
|
||||
group = kgroup(Coor_node, sensor_num)
|
||||
# wn_fun:Sensorplacement(继承自wn_func)
|
||||
# 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
|
||||
wn_fun = Sensorplacement(wn_real, sensor_num)
|
||||
wn_fun.__dict__.update(wn_fun1.__dict__)
|
||||
# sensorindex:list[str],初始传感器布置位置的节点名称
|
||||
# sensorindex_2:list[str],根据分组选择的传感器位置
|
||||
sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensor_num), "个测压点,测压点位置:", sensorindex)
|
||||
sensor_coord = {}
|
||||
# 重新打开数据库
|
||||
if is_project_open(name=name):
|
||||
close_project(name=name)
|
||||
open_project(name=name)
|
||||
for node_id in sensorindex:
|
||||
sensor_coord[node_id] = get_node_coord(name=name, node_id=node_id)
|
||||
close_project(name=name)
|
||||
# print(sensor_coord)
|
||||
return sensor_coord
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sensor_coord = get_sensor_coord(name=project_info.name, sensor_num=20)
|
||||
print(sensor_coord)
|
||||
# '''
|
||||
# 初始测压点布置根据灵敏度来布置,计算初始情况下的校准过程的error
|
||||
# '''
|
||||
#
|
||||
# # inp_file_real:str,输入文件名,表示原始水力模型文件的路径,该文件格式为 EPANET 输入文件(.inp),包含管网的结构信息、节点、管道、泵等数据
|
||||
# inp_file_real = './db_inp/bb.db.inp'
|
||||
# # sensornum:int,需要布置的传感器数量
|
||||
# sensornum = 20
|
||||
# # wn_real:wntr.network.WaterNetworkModel,加载 EPANET 水力模型
|
||||
# wn_real = wntr.network.WaterNetworkModel(inp_file_real) # 真实粗糙度的原始管网
|
||||
# # sim_real:wntr.sim.EpanetSimulator,创建一个水力仿真器对象
|
||||
# sim_real = wntr.sim.EpanetSimulator(wn_real)
|
||||
# # results_real:wntr.sim.results.SimulationResults,运行仿真并返回结果
|
||||
# results_real = sim_real.run_sim()
|
||||
#
|
||||
# # real_C:list[float],包含所有管道粗糙度的列表
|
||||
# real_C = wn_real.query_link_attribute('roughness').tolist()
|
||||
# # wn_fun1:wn_func(继承自 object),创建 wn_func 类的实例,传入 wn_real 水力模型对象。wn_func 用于计算管网相关的水力属性,比如水力距离、灵敏度等
|
||||
# wn_fun1 = wn_func(wn_real)
|
||||
# # nodes:list[str],管网的节点名称列表
|
||||
# nodes = wn_fun1.nodes
|
||||
# # delnodes:list[str],被删除的节点(如水库、泵、阀门连接的节点等)
|
||||
# delnodes = wn_fun1.delnodes
|
||||
# # Coor_node:pandas.DataFrame
|
||||
# Coor_node = getCoor(wn_real)
|
||||
# Coor_node = Coor_node.drop(wn_fun1.delnodes)
|
||||
# nodes = [node for node in wn_fun1.nodes if node not in delnodes]
|
||||
# # coordinates:pandas.Series,存储所有节点的坐标,类型为 Series,索引为节点名称,值为 (x, y) 坐标对
|
||||
# coordinates = wn_fun1.coordinates
|
||||
#
|
||||
# # 随机产生监测点
|
||||
# # junctionnum:int,nodes 的长度,表示节点的数量
|
||||
# junctionnum = len(nodes)
|
||||
# # random_numbers:list[int],使用 random.sample 随机选择 sensornum(20)个节点的编号。它返回一个不重复的随机编号列表
|
||||
# random_numbers = random.sample(range(junctionnum), sensornum)
|
||||
# for i in range(sensornum):
|
||||
# print(random_numbers[i])
|
||||
#
|
||||
# wn_fun1.get_Conn()
|
||||
# # hL:pandas.DataFrame,水力距离矩阵,表示每个节点到其他节点的水力阻力
|
||||
# # G:networkx.DiGraph,加权有向图,表示管网的拓扑结构,节点之间的边带有权重
|
||||
# hL, G = wn_fun1.CtoS()
|
||||
# # SS:pandas.DataFrame,灵敏度矩阵,表示每个节点对管网变化(如粗糙度、流量等)的响应
|
||||
# SS = wn_fun1.Jaco(hL)
|
||||
# # group:dict[int, list[str]],使用 kgroup 函数将节点按坐标分成若干组,每组包含的节点数不一定相同。group 是一个字典,键为分组编号,值为节点名列表
|
||||
# group = kgroup(Coor_node, sensornum)
|
||||
# # wn_fun:Sensorplacement(继承自wn_func)
|
||||
# # 创建Sensorplacement类的实例,传入水力网络模型wn_real和传感器数量sensornum。Sensorplacement用于计算和布置传感器
|
||||
# wn_fun = Sensorplacement(wn_real, sensornum)
|
||||
# wn_fun.__dict__.update(wn_fun1.__dict__)
|
||||
# # sensorindex:list[str],初始传感器布置位置的节点名称
|
||||
# # sensorindex_2:list[str],根据分组选择的传感器位置
|
||||
# sensorindex, sensorindex_2 = wn_fun.sensor(SS, G, group) # 初始的sensorindex
|
||||
# print(str(sensornum), "个测压点,测压点位置:", sensorindex)
|
||||
|
||||
# # 分区画图
|
||||
# colorlist = ['lightpink', 'coral', 'rosybrown', 'olive', 'powderblue', 'lightskyblue', 'steelblue', 'peachpuff','brown','silver','indigo','lime','gold','violet','maroon','navy','teal','magenta','cyan',
|
||||
# 'burlywood', 'tan', 'slategrey', 'thistle', 'lightseagreen', 'lightgreen', 'red','blue','yellow','orange','purple','grey','green','pink','lightblue','beige','chartreuse','turquoise','lavender','fuchsia','coral']
|
||||
# G = wn_real.to_graph()
|
||||
# G = G.to_undirected() # 变为无向图
|
||||
# pos = nx.get_node_attributes(G, 'pos')
|
||||
# pass
|
||||
# for i in range(0, sensornum):
|
||||
# ax = plt.gca()
|
||||
# ax.set_title(inp_file_real + str(sensornum))
|
||||
# nodes = nx.draw_networkx_nodes(G, pos, nodelist=group[i], node_color=colorlist[i], node_size=20)
|
||||
# nodes = nx.draw_networkx_nodes(G, pos,
|
||||
# nodelist=sensorindex_2, node_color='black', node_size=70, node_shape='*'
|
||||
# )
|
||||
# edges = nx.draw_networkx_edges(G, pos)
|
||||
# ax.spines['top'].set_visible(False)
|
||||
# ax.spines['right'].set_visible(False)
|
||||
# ax.spines['bottom'].set_visible(False)
|
||||
# ax.spines['left'].set_visible(False)
|
||||
# plt.savefig(inp_file_real + str(sensornum) + ".png")
|
||||
# plt.show()
|
||||
#
|
||||
# wntr.graphics.plot_network(wn_real, node_attribute=sensorindex_2, node_size=50, node_labels=False,
|
||||
# title=inp_file_real + '_Projetion' + str(sensornum))
|
||||
# plt.savefig(inp_file_real + '_S' + str(sensornum) + ".png")
|
||||
# plt.show()
|
||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
0
app/api/v1/__init__.py
Normal file
0
app/api/v1/__init__.py
Normal file
0
app/api/v1/endpoints/__init__.py
Normal file
0
app/api/v1/endpoints/__init__.py
Normal file
0
app/api/v1/endpoints/auth.py
Normal file
0
app/api/v1/endpoints/auth.py
Normal file
0
app/api/v1/endpoints/extension.py
Normal file
0
app/api/v1/endpoints/extension.py
Normal file
0
app/api/v1/endpoints/network_elements.py
Normal file
0
app/api/v1/endpoints/network_elements.py
Normal file
0
app/api/v1/endpoints/project.py
Normal file
0
app/api/v1/endpoints/project.py
Normal file
0
app/api/v1/endpoints/scada.py
Normal file
0
app/api/v1/endpoints/scada.py
Normal file
0
app/api/v1/endpoints/simulation.py
Normal file
0
app/api/v1/endpoints/simulation.py
Normal file
0
app/api/v1/endpoints/snapshots.py
Normal file
0
app/api/v1/endpoints/snapshots.py
Normal file
20
app/api/v1/router.py
Normal file
20
app/api/v1/router.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1.endpoints import (
|
||||
auth,
|
||||
project,
|
||||
network_elements,
|
||||
simulation,
|
||||
scada,
|
||||
extension,
|
||||
snapshots
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(project.router, prefix="/projects", tags=["projects"])
|
||||
api_router.include_router(network_elements.router, prefix="/elements", tags=["network-elements"])
|
||||
api_router.include_router(simulation.router, prefix="/simulation", tags=["simulation"])
|
||||
api_router.include_router(scada.router, prefix="/scada", tags=["scada"])
|
||||
api_router.include_router(extension.router, prefix="/extension", tags=["extension"])
|
||||
api_router.include_router(snapshots.router, prefix="/snapshots", tags=["snapshots"])
|
||||
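# 挂载示意(假设性示例,main.py 不在本节中,以下写法仅为说明):
#
#   from fastapi import FastAPI
#   from app.api.v1.router import api_router
#   from app.core.config import settings
#
#   app = FastAPI(title=settings.PROJECT_NAME)
#   app.include_router(api_router, prefix=settings.API_V1_STR)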
0
app/audit/__init__.py
Normal file
0
app/audit/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
0
app/auth/__init__.py
Normal file
21
app/auth/dependencies.py
Normal file
21
app/auth/dependencies.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from app.core.config import settings
|
||||
from jose import jwt, JWTError
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/login/access-token")
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
return username
|
||||
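# 使用示意(假设性示例,非本提交代码):在需要鉴权的路由中注入 get_current_user。
#
#   from fastapi import APIRouter, Depends
#   from app.auth.dependencies import get_current_user
#
#   router = APIRouter()
#
#   @router.get("/me")
#   async def read_me(username: str = Depends(get_current_user)):
#       return {"username": username}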
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
3
app/core/audit.py
Normal file
3
app/core/audit.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Placeholder for audit logic
|
||||
async def log_audit_event(event_type: str, user_id: str, details: dict):
|
||||
pass
|
||||
30
app/core/config.py
Normal file
30
app/core/config.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "TJWater Server"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
SECRET_KEY: str = "your-secret-key-here" # Change in production
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
# Database Config (PostgreSQL)
|
||||
DB_NAME: str = "tjwater"
|
||||
DB_HOST: str = "localhost"
|
||||
DB_PORT: str = "5432"
|
||||
DB_USER: str = "postgres"
|
||||
DB_PASSWORD: str = "password"
|
||||
|
||||
# InfluxDB
|
||||
INFLUXDB_URL: str = "http://localhost:8086"
|
||||
INFLUXDB_TOKEN: str = "token"
|
||||
INFLUXDB_ORG: str = "org"
|
||||
INFLUXDB_BUCKET: str = "bucket"
|
||||
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
return f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
settings = Settings()
|
||||
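# .env 示意(与上面字段一一对应,取值均为占位示例):
#
#   SECRET_KEY=change-me-in-production
#   DB_NAME=tjwater
#   DB_HOST=localhost
#   DB_PORT=5432
#   DB_USER=postgres
#   DB_PASSWORD=password
#   INFLUXDB_URL=http://localhost:8086
#   INFLUXDB_TOKEN=token
#   INFLUXDB_ORG=org
#   INFLUXDB_BUCKET=bucket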
9
app/core/encryption.py
Normal file
9
app/core/encryption.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Placeholder for encryption logic
|
||||
class Encryptor:
|
||||
def encrypt(self, data: str) -> str:
|
||||
return data # Implement actual encryption
|
||||
|
||||
def decrypt(self, data: str) -> str:
|
||||
return data # Implement actual decryption
|
||||
|
||||
encryptor = Encryptor()
|
||||
23
app/core/security.py
Normal file
23
app/core/security.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, Any
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
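# 使用示意(假设性示例,非本提交代码):注册时保存哈希口令,登录校验通过后签发 access token。
#
#   from datetime import timedelta
#   from app.core.security import create_access_token, get_password_hash, verify_password
#
#   hashed = get_password_hash("plain-password")
#   assert verify_password("plain-password", hashed)
#   token = create_access_token("alice", expires_delta=timedelta(minutes=30))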
0
app/crypto/__init__.py
Normal file
0
app/crypto/__init__.py
Normal file
0
app/domain/__init__.py
Normal file
0
app/domain/__init__.py
Normal file
0
app/domain/models/__init__.py
Normal file
0
app/domain/models/__init__.py
Normal file
0
app/domain/schemas/__init__.py
Normal file
0
app/domain/schemas/__init__.py
Normal file
0
app/infra/__init__.py
Normal file
0
app/infra/__init__.py
Normal file
0
app/infra/audit/__init__.py
Normal file
0
app/infra/audit/__init__.py
Normal file
0
app/infra/cache/__init__.py
vendored
Normal file
0
app/infra/cache/__init__.py
vendored
Normal file
0
app/infra/db/__init__.py
Normal file
0
app/infra/db/__init__.py
Normal file
0
app/infra/db/influxdb/__init__.py
Normal file
0
app/infra/db/influxdb/__init__.py
Normal file
4964
app/infra/db/influxdb/api.py
Normal file
4964
app/infra/db/influxdb/api.py
Normal file
File diff suppressed because it is too large
Load Diff
6
app/infra/db/influxdb/info.py
Normal file
6
app/infra/db/influxdb/info.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# influxdb数据库连接信息
|
||||
url = "http://127.0.0.1:8086" # 替换为你的InfluxDB实例地址
|
||||
token = "kMPX2V5HsbzPpUT2B9HPBu1sTG1Emf-lPlT2UjxYnGAuocpXq_f_0lK4HHs-TbbKyjsZpICkMsyXG_V2D7P7yQ==" # 替换为你的InfluxDB Token
|
||||
# _ENCODED_TOKEN = "eEdETTVSWnFSSkF1ekFHUy1vdFhVZEMyTkZkWTc1cUpBalJMcUFCNHA1V2NJSUFsSVVwT3BUOF95QTE2QU9IbUpXZXJ3UV8wOGd3Yjg0c3k0MmpuWlE9PQ=="
|
||||
# token = base64.b64decode(_ENCODED_TOKEN).decode("utf-8")
|
||||
org = "TJWATERORG" # 替换为你的Organization名称
|
||||
33
app/infra/db/influxdb/query.py
Normal file
33
app/infra/db/influxdb/query.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from influxdb_client import InfluxDBClient, Point, WriteOptions
|
||||
from influxdb_client.client.query_api import QueryApi
|
||||
import app.infra.db.influxdb.info as influxdb_info
|
||||
|
||||
# 配置 InfluxDB 连接
|
||||
url = influxdb_info.url
|
||||
token = influxdb_info.token
|
||||
org = influxdb_info.org
|
||||
bucket = "SCADA_data"
|
||||
|
||||
# 创建 InfluxDB 客户端
|
||||
client = InfluxDBClient(url=url, token=token, org=org)
|
||||
|
||||
# 创建查询 API 对象
|
||||
query_api = client.query_api()
|
||||
|
||||
# 构建查询语句
|
||||
query = f'''
|
||||
from(bucket: "{bucket}")
|
||||
|> range(start: -1h)
|
||||
'''
|
||||
|
||||
# 执行查询
|
||||
result = query_api.query(query)
|
||||
print(result)
|
||||
|
||||
# 处理查询结果
|
||||
for table in result:
|
||||
for record in table.records:
|
||||
print(f"Time: {record.get_time()}, Value: {record.get_value()}, Measurement: {record.get_measurement()}, Field: {record.get_field()}")
|
||||
|
||||
# 关闭客户端连接
|
||||
client.close()
|
||||
1
app/infra/db/postgresql/__init__.py
Normal file
1
app/infra/db/postgresql/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .router import router
|
||||
108
app/infra/db/postgresql/database.py
Normal file
108
app/infra/db/postgresql/database.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
import psycopg_pool
|
||||
from psycopg.rows import dict_row
|
||||
import app.native.api.postgresql_info as postgresql_info
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name=None):
|
||||
self.pool = None
|
||||
self.db_name = db_name
|
||||
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
target_db = db_name or self.db_name
|
||||
conn_string = postgresql_info.get_pgconn_string(db_name=target_db) if target_db else postgresql_info.get_pgconn_string()
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=5,
|
||||
max_size=20,
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(f"PostgreSQL connection pool initialized for database: {db_name or self.db_name or 'default'}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize postgresql connection pool: {e}")
|
||||
raise
|
||||
|
||||
async def open(self):
|
||||
if self.pool:
|
||||
await self.pool.open()
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection pool."""
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
logger.info("PostgreSQL connection pool closed.")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
"""Get a connection from the pool."""
|
||||
if not self.pool:
|
||||
raise Exception("Database pool is not initialized.")
|
||||
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# 默认数据库实例
|
||||
db = Database()
|
||||
|
||||
# 缓存不同数据库的实例 - 避免重复创建连接池
|
||||
_database_instances: Dict[str, Database] = {}
|
||||
|
||||
|
||||
def create_database_instance(db_name):
|
||||
"""Create a new Database instance for a specific database."""
|
||||
return Database(db_name=db_name)
|
||||
|
||||
|
||||
async def get_database_instance(db_name: Optional[str] = None) -> Database:
|
||||
"""Get or create a database instance for the specified database name."""
|
||||
if not db_name:
|
||||
return db # 返回默认数据库实例
|
||||
|
||||
if db_name not in _database_instances:
|
||||
# 创建新的数据库实例
|
||||
instance = create_database_instance(db_name)
|
||||
instance.init_pool()
|
||||
await instance.open()
|
||||
_database_instances[db_name] = instance
|
||||
logger.info(f"Created new database instance for: {db_name}")
|
||||
|
||||
return _database_instances[db_name]
|
||||
|
||||
|
||||
async def get_db_connection():
|
||||
"""Dependency for FastAPI to get a database connection."""
|
||||
async with db.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_database_connection(db_name: Optional[str] = None):
|
||||
"""
|
||||
FastAPI dependency to get database connection with optional database name.
|
||||
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
|
||||
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
|
||||
"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def cleanup_database_instances():
|
||||
"""Clean up all database instances (call this on application shutdown)."""
|
||||
for db_name, instance in _database_instances.items():
|
||||
await instance.close()
|
||||
logger.info(f"Closed database instance for: {db_name}")
|
||||
_database_instances.clear()
|
||||
|
||||
# 关闭默认数据库
|
||||
await db.close()
|
||||
logger.info("All database instances cleaned up.")
|
||||
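# 使用示意(假设性示例,非本提交代码):在路由中通过依赖按需获取指定数据库的连接;
# 示例中的路由路径与 SQL 仅作演示,连接池以 dict_row 返回字典行。
#
#   from fastapi import APIRouter, Depends
#   from psycopg import AsyncConnection
#
#   router = APIRouter()
#
#   @router.get("/pipes-count")
#   async def pipes_count(conn: AsyncConnection = Depends(get_database_connection)):
#       async with conn.cursor() as cur:
#           await cur.execute("SELECT count(*) FROM public.pipes")
#           row = await cur.fetchone()
#       return {"count": row["count"]}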
83
app/infra/db/postgresql/internal_queries.py
Normal file
83
app/infra/db/postgresql/internal_queries.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.logger import logger
|
||||
import app.native.api.postgresql_info as postgresql_info
|
||||
import psycopg
|
||||
|
||||
|
||||
class InternalQueries:
|
||||
@staticmethod
|
||||
def get_links_by_property(
|
||||
fields: Optional[List[str]] = None,
|
||||
property_conditions: Optional[dict] = None,
|
||||
db_name: str = None,
|
||||
max_retries: int = 3,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中,pipes 的指定字段记录或根据属性筛选
|
||||
:param fields: 要查询的字段列表,如 ["id", "diameter", "status"],默认查询所有字段
|
||||
:param property_conditions: 可选的筛选条件字典,如 {"status": "Open"} 或 {"diameter": 300}
|
||||
:param db_name: 数据库名称
|
||||
:param max_retries: 最大重试次数
|
||||
:return: 包含所有记录的列表,每条记录为一个字典
|
||||
"""
|
||||
# 如果未指定字段,查询所有字段
|
||||
if not fields:
|
||||
fields = [
|
||||
"id",
|
||||
"node1",
|
||||
"node2",
|
||||
"length",
|
||||
"diameter",
|
||||
"roughness",
|
||||
"minor_loss",
|
||||
"status",
|
||||
]
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn_string = (
|
||||
postgresql_info.get_pgconn_string(db_name=db_name)
|
||||
if db_name
|
||||
else postgresql_info.get_pgconn_string()
|
||||
)
|
||||
with psycopg.Connection.connect(conn_string) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 构建SELECT子句
|
||||
select_fields = ", ".join(fields)
|
||||
base_query = f"""
|
||||
SELECT {select_fields}
|
||||
FROM public.pipes
|
||||
"""
|
||||
|
||||
# 如果提供了筛选条件,构建WHERE子句
|
||||
if property_conditions:
|
||||
conditions = []
|
||||
params = []
|
||||
for key, value in property_conditions.items():
|
||||
conditions.append(f"{key} = %s")
|
||||
params.append(value)
|
||||
|
||||
query = base_query + " WHERE " + " AND ".join(conditions)
|
||||
cur.execute(query, params)
|
||||
else:
|
||||
cur.execute(base_query)
|
||||
|
||||
records = cur.fetchall()
|
||||
# 将查询结果转换为字典列表
|
||||
pipes = []
|
||||
for record in records:
|
||||
pipe_dict = {}
|
||||
for idx, field in enumerate(fields):
|
||||
pipe_dict[field] = record[idx]
|
||||
pipes.append(pipe_dict)
|
||||
|
||||
return pipes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise
|
||||
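# 使用示意(假设性示例,非本提交代码):按属性筛选管道记录;db_name 为假设的数据库名。
#
#   open_pipes = InternalQueries.get_links_by_property(
#       fields=["id", "diameter", "status"],
#       property_conditions={"status": "Open"},
#       db_name="demo_network",
#   )
#   # 返回形如 [{"id": "...", "diameter": ..., "status": "Open"}, ...]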
90
app/infra/db/postgresql/router.py
Normal file
90
app/infra/db/postgresql/router.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import Optional
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .scada_info import ScadaRepository
|
||||
from .scheme import SchemeRepository
|
||||
|
||||
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
@router.get("/scada-info")
|
||||
async def get_scada_info_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有SCADA信息
|
||||
"""
|
||||
try:
|
||||
# 使用ScadaRepository查询SCADA信息
|
||||
scada_data = await ScadaRepository.get_scadas(conn)
|
||||
return {"success": True, "data": scada_data, "count": len(scada_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scheme-list")
|
||||
async def get_scheme_list_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有方案信息
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询方案信息
|
||||
scheme_data = await SchemeRepository.get_schemes(conn)
|
||||
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/burst-locate-result")
|
||||
async def get_burst_locate_result_with_connection(
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
使用连接池查询所有爆管定位结果
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
burst_data = await SchemeRepository.get_burst_locate_results(conn)
|
||||
return {"success": True, "data": burst_data, "count": len(burst_data)}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/burst-locate-result/{burst_incident}")
|
||||
async def get_burst_locate_result_by_incident(
|
||||
burst_incident: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
根据 burst_incident 查询爆管定位结果
|
||||
"""
|
||||
try:
|
||||
# 使用SchemeRepository查询爆管定位结果
|
||||
return await SchemeRepository.get_burst_locate_result_by_incident(
|
||||
conn, burst_incident
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}",
|
||||
)
|
||||
36
app/infra/db/postgresql/scada_info.py
Normal file
36
app/infra/db/postgresql/scada_info.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import List, Optional, Any
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
|
||||
class ScadaRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_scadas(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中,scada_info 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表,每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, associated_element_id, transmission_mode, transmission_frequency, reliability
|
||||
FROM public.scada_info
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
# 将查询结果转换为字典列表(假设 record 是字典)
|
||||
scada_infos = []
|
||||
for record in records:
|
||||
scada_infos.append(
|
||||
{
|
||||
"id": record["id"], # 使用字典键
|
||||
"type": record["type"],
|
||||
"associated_element_id": record["associated_element_id"],
|
||||
"transmission_mode": record["transmission_mode"],
|
||||
"transmission_frequency": record["transmission_frequency"],
|
||||
"reliability": record["reliability"],
|
||||
}
|
||||
)
|
||||
|
||||
return scada_infos
|
||||
104
app/infra/db/postgresql/scheme.py
Normal file
104
app/infra/db/postgresql/scheme.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import List, Optional, Any
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
|
||||
class SchemeRepository:
|
||||
|
||||
@staticmethod
|
||||
async def get_schemes(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中, scheme_list 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表, 每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail
|
||||
FROM public.scheme_list
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
scheme_list = []
|
||||
for record in records:
|
||||
scheme_list.append(
|
||||
{
|
||||
"scheme_id": record["scheme_id"],
|
||||
"scheme_name": record["scheme_name"],
|
||||
"scheme_type": record["scheme_type"],
|
||||
"username": record["username"],
|
||||
"create_time": record["create_time"],
|
||||
"scheme_start_time": record["scheme_start_time"],
|
||||
"scheme_detail": record["scheme_detail"],
|
||||
}
|
||||
)
|
||||
|
||||
return scheme_list
|
||||
|
||||
@staticmethod
|
||||
async def get_burst_locate_results(conn: AsyncConnection) -> List[dict]:
|
||||
"""
|
||||
查询pg数据库中, burst_locate_result 的所有记录
|
||||
:param conn: 异步数据库连接
|
||||
:return: 包含所有记录的列表, 每条记录为一个字典
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, burst_incident, leakage, detect_time, locate_result
|
||||
FROM public.burst_locate_result
|
||||
"""
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
results = []
|
||||
for record in records:
|
||||
results.append(
|
||||
{
|
||||
"id": record["id"],
|
||||
"type": record["type"],
|
||||
"burst_incident": record["burst_incident"],
|
||||
"leakage": record["leakage"],
|
||||
"detect_time": record["detect_time"],
|
||||
"locate_result": record["locate_result"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
async def get_burst_locate_result_by_incident(
|
||||
conn: AsyncConnection, burst_incident: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
根据 burst_incident 查询爆管定位结果
|
||||
:param conn: 异步数据库连接
|
||||
:param burst_incident: 爆管事件标识
|
||||
:return: 包含匹配记录的列表
|
||||
"""
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"""
|
||||
SELECT id, type, burst_incident, leakage, detect_time, locate_result
|
||||
FROM public.burst_locate_result
|
||||
WHERE burst_incident = %s
|
||||
""",
|
||||
(burst_incident,),
|
||||
)
|
||||
records = await cur.fetchall()
|
||||
|
||||
results = []
|
||||
for record in records:
|
||||
results.append(
|
||||
{
|
||||
"id": record["id"],
|
||||
"type": record["type"],
|
||||
"burst_incident": record["burst_incident"],
|
||||
"leakage": record["leakage"],
|
||||
"detect_time": record["detect_time"],
|
||||
"locate_result": record["locate_result"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
4
app/infra/db/timescaledb/__init__.py
Normal file
4
app/infra/db/timescaledb/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .router import router
|
||||
from .database import *
|
||||
from .timescaledb_info import *
|
||||
from .composite_queries import CompositeQueries
|
||||
606
app/infra/db/timescaledb/composite_queries.py
Normal file
606
app/infra/db/timescaledb/composite_queries.py
Normal file
@@ -0,0 +1,606 @@
|
||||
import time
|
||||
from typing import List, Optional, Any, Dict, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from psycopg import AsyncConnection
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from api_ex.Fdataclean import clean_flow_data_df_kf
|
||||
from api_ex.Pdataclean import clean_pressure_data_df_km
|
||||
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
from postgresql.internal_queries import InternalQueries
|
||||
from postgresql.scada_info import ScadaRepository as PostgreScadaRepository
|
||||
from timescaledb.schemas.realtime import RealtimeRepository
|
||||
from timescaledb.schemas.scheme import SchemeRepository
|
||||
from timescaledb.schemas.scada import ScadaRepository
|
||||
|
||||
|
||||
class CompositeQueries:
|
||||
"""
|
||||
复合查询类,提供跨表查询功能
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_scada_associated_realtime_simulation_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
postgres_conn: AsyncConnection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 SCADA 关联的 link/node 模拟值
|
||||
|
||||
根据传入的 SCADA device_ids,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
postgres_conn: PostgreSQL 异步连接
|
||||
device_ids: SCADA 设备ID列表
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
|
||||
Returns:
|
||||
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||
"""
|
||||
result = {}
|
||||
# 1. 查询所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
|
||||
for device_id in device_ids:
|
||||
# 2. 根据 device_id 找到对应的 SCADA 信息
|
||||
target_scada = None
|
||||
for scada in scada_infos:
|
||||
if scada["id"] == device_id:
|
||||
target_scada = scada
|
||||
break
|
||||
|
||||
if not target_scada:
|
||||
raise ValueError(f"SCADA device {device_id} not found")
|
||||
|
||||
# 3. 根据 type 和 associated_element_id 查询对应的模拟数据
|
||||
element_id = target_scada["associated_element_id"]
|
||||
scada_type = target_scada["type"]
|
||||
|
||||
if scada_type.lower() == "pipe_flow":
|
||||
# 查询 link 模拟数据
|
||||
res = await RealtimeRepository.get_link_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, element_id, "flow"
|
||||
)
|
||||
elif scada_type.lower() == "pressure":
|
||||
# 查询 node 模拟数据
|
||||
res = await RealtimeRepository.get_node_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, element_id, "pressure"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||
# 添加 scada_id 到每个数据项
|
||||
for item in res:
|
||||
item["scada_id"] = device_id
|
||||
result[device_id] = res
|
||||
return result
|
||||
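# 调用示意(假设性示例,非本提交代码):查询两个 SCADA 设备最近一天的关联模拟值;
# 设备 ID 与时间均为虚构,timescale_conn / postgres_conn 假定为已打开的异步连接。
#
#   from datetime import datetime, timedelta
#   end = datetime(2025, 3, 12, 12, 0, 0)
#   start = end - timedelta(days=1)
#   data = await CompositeQueries.get_scada_associated_realtime_simulation_data(
#       timescale_conn, postgres_conn, ["SCADA_01", "SCADA_02"], start, end
#   )
#   # data 形如 {"SCADA_01": [{"time": ..., "value": ..., "scada_id": "SCADA_01"}, ...], ...}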
|
||||
@staticmethod
|
||||
async def get_scada_associated_scheme_simulation_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
postgres_conn: AsyncConnection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 SCADA 关联的 link/node scheme 模拟值
|
||||
|
||||
根据传入的 SCADA device_ids,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
postgres_conn: PostgreSQL 异步连接
|
||||
device_ids: SCADA 设备ID列表
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
scheme_type: 工况类型
|
||||
scheme_name: 工况名称
|
||||
|
||||
Returns:
|
||||
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||
"""
|
||||
result = {}
|
||||
# 1. 查询所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
|
||||
for device_id in device_ids:
|
||||
# 2. 根据 device_id 找到对应的 SCADA 信息
|
||||
target_scada = None
|
||||
for scada in scada_infos:
|
||||
if scada["id"] == device_id:
|
||||
target_scada = scada
|
||||
break
|
||||
|
||||
if not target_scada:
|
||||
raise ValueError(f"SCADA device {device_id} not found")
|
||||
|
||||
# 3. 根据 type 和 associated_element_id 查询对应的模拟数据
|
||||
element_id = target_scada["associated_element_id"]
|
||||
scada_type = target_scada["type"]
|
||||
|
||||
if scada_type.lower() == "pipe_flow":
|
||||
# 查询 link 模拟数据
|
||||
res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
|
||||
timescale_conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
start_time,
|
||||
end_time,
|
||||
element_id,
|
||||
"flow",
|
||||
)
|
||||
elif scada_type.lower() == "pressure":
|
||||
# 查询 node 模拟数据
|
||||
res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
|
||||
timescale_conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
start_time,
|
||||
end_time,
|
||||
element_id,
|
||||
"pressure",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||
# 添加 scada_id 到每个数据项
|
||||
for item in res:
|
||||
item["scada_id"] = device_id
|
||||
result[device_id] = res
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_realtime_simulation_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
featureInfos: List[Tuple[str, str]],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 link/node 模拟值
|
||||
|
||||
根据传入的 featureInfos,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
featureInfos: 传入的 feature 信息列表,包含 (element_id, type)
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
|
||||
Returns:
|
||||
模拟数据字典,以 feature_id 为键,值为数据列表,每个数据包含 time, value 和 feature_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当传入的 type 无效时
|
||||
"""
|
||||
result = {}
|
||||
for feature_id, type in featureInfos:
|
||||
|
||||
if type.lower() == "pipe":
|
||||
# 查询 link 模拟数据
|
||||
res = await RealtimeRepository.get_link_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, feature_id, "flow"
|
||||
)
|
||||
elif type.lower() == "junction":
|
||||
# 查询 node 模拟数据
|
||||
res = await RealtimeRepository.get_node_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, feature_id, "pressure"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type}")
|
||||
# 添加 scada_id 到每个数据项
|
||||
for item in res:
|
||||
item["feature_id"] = feature_id
|
||||
result[feature_id] = res
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_scheme_simulation_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
featureInfos: List[Tuple[str, str]],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 link/node scheme 模拟值
|
||||
|
||||
根据传入的 featureInfos,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
featureInfos: 传入的 feature 信息列表,包含 (element_id, type)
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
scheme_type: 工况类型
|
||||
scheme_name: 工况名称
|
||||
|
||||
Returns:
|
||||
模拟数据字典,以 feature_id 为键,值为数据列表,每个数据包含 time, value 和 feature_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当类型无效时
|
||||
"""
|
||||
result = {}
|
||||
for feature_id, type in featureInfos:
|
||||
if type.lower() == "pipe":
|
||||
# 查询 link 模拟数据
|
||||
res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
|
||||
timescale_conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
start_time,
|
||||
end_time,
|
||||
feature_id,
|
||||
"flow",
|
||||
)
|
||||
elif type.lower() == "junction":
|
||||
# 查询 node 模拟数据
|
||||
res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
|
||||
timescale_conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
start_time,
|
||||
end_time,
|
||||
feature_id,
|
||||
"pressure",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type}")
|
||||
# 添加 feature_id 到每个数据项
|
||||
for item in res:
|
||||
item["feature_id"] = feature_id
|
||||
result[feature_id] = res
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_element_associated_scada_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
postgres_conn: AsyncConnection,
|
||||
element_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
use_cleaned: bool = False,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
获取 link/node 关联的 SCADA 监测值
|
||||
|
||||
根据传入的 link/node id,匹配 SCADA 信息,
|
||||
如果存在关联的 SCADA device_id,获取实际的监测数据
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
postgres_conn: PostgreSQL 异步连接
|
||||
element_id: link 或 node 的 ID
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
use_cleaned: 是否使用清洗后的数据 (True: "cleaned_value", False: "monitored_value")
|
||||
|
||||
Returns:
|
||||
以 element_id 为键的 SCADA 监测数据字典;若未找到关联的 SCADA 设备则返回 None
|
||||
|
||||
Raises:
|
||||
ValueError: 当元素类型无效时
|
||||
"""
|
||||
|
||||
# 1. 查询所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
|
||||
# 2. 根据 element_type 和 element_id 找到关联的 SCADA 设备
|
||||
associated_scada = None
|
||||
for scada in scada_infos:
|
||||
if scada["associated_element_id"] == element_id:
|
||||
associated_scada = scada
|
||||
break
|
||||
|
||||
if not associated_scada:
|
||||
# 没有找到关联的 SCADA 设备
|
||||
return None
|
||||
|
||||
# 3. 通过 SCADA device_id 获取监测数据
|
||||
device_id = associated_scada["id"]
|
||||
|
||||
# 根据 use_cleaned 参数选择字段
|
||||
data_field = "cleaned_value" if use_cleaned else "monitored_value"
|
||||
|
||||
# 保证 device_id 以列表形式传递
|
||||
res = await ScadaRepository.get_scada_field_by_id_time_range(
|
||||
timescale_conn, [device_id], start_time, end_time, data_field
|
||||
)
|
||||
|
||||
# 将 device_id 替换为 element_id 返回
|
||||
return {element_id: res.get(device_id, [])}
|
||||
|
||||
@staticmethod
|
||||
async def clean_scada_data(
|
||||
timescale_conn: AsyncConnection,
|
||||
postgres_conn: AsyncConnection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> str:
|
||||
"""
|
||||
清洗 SCADA 数据
|
||||
|
||||
根据 device_ids 查询 monitored_value,清洗后更新 cleaned_value
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 连接
|
||||
postgres_conn: PostgreSQL 连接
|
||||
device_ids: 设备 ID 列表
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
|
||||
Returns:
|
||||
"success" 或错误信息
|
||||
"""
|
||||
try:
|
||||
# 获取所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
# 将列表转换为字典,以 device_id 为键
|
||||
scada_device_info_dict = {info["id"]: info for info in scada_infos}
|
||||
|
||||
# 如果 device_ids 为空,则处理所有 SCADA 设备
|
||||
if not device_ids:
|
||||
device_ids = list(scada_device_info_dict.keys())
|
||||
|
||||
# 批量查询所有设备的数据
|
||||
data = await ScadaRepository.get_scada_field_by_id_time_range(
|
||||
timescale_conn, device_ids, start_time, end_time, "monitored_value"
|
||||
)
|
||||
|
||||
if not data:
|
||||
return "error: fetch none scada data" # 没有数据,直接返回
|
||||
|
||||
# 将嵌套字典转换为 DataFrame,使用 time 作为索引
|
||||
# data 格式: {device_id: [{"time": "...", "value": ...}, ...]}
|
||||
all_records = []
|
||||
for device_id, records in data.items():
|
||||
for record in records:
|
||||
all_records.append(
|
||||
{
|
||||
"time": record["time"],
|
||||
"device_id": device_id,
|
||||
"value": record["value"],
|
||||
}
|
||||
)
|
||||
|
||||
if not all_records:
|
||||
return "error: fetch none scada data" # 没有数据,直接返回
|
||||
|
||||
# 创建 DataFrame 并透视,使 device_id 成为列
|
||||
df_long = pd.DataFrame(all_records)
|
||||
df = df_long.pivot(index="time", columns="device_id", values="value")
|
||||
|
||||
# 根据type分类设备
|
||||
pressure_ids = [
|
||||
id
|
||||
for id in df.columns
|
||||
if scada_device_info_dict.get(id, {}).get("type") == "pressure"
|
||||
]
|
||||
flow_ids = [
|
||||
id
|
||||
for id in df.columns
|
||||
if scada_device_info_dict.get(id, {}).get("type") == "pipe_flow"
|
||||
]
|
||||
|
||||
# 处理pressure数据
|
||||
if pressure_ids:
|
||||
pressure_df = df[pressure_ids]
|
||||
# 重置索引,将 time 变为普通列
|
||||
pressure_df = pressure_df.reset_index()
|
||||
# 移除 time 列,准备输入给清洗方法
|
||||
value_df = pressure_df.drop(columns=["time"])
|
||||
# 调用清洗方法
|
||||
cleaned_value_df = clean_pressure_data_df_km(value_df)
|
||||
# 添加 time 列到首列
|
||||
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
||||
# 将清洗后的数据写回数据库
|
||||
for device_id in pressure_ids:
|
||||
if device_id in cleaned_df.columns:
|
||||
cleaned_values = cleaned_df[device_id].tolist()
|
||||
time_values = cleaned_df["time"].tolist()
|
||||
for i, time_str in enumerate(time_values):
|
||||
time_dt = datetime.fromisoformat(time_str)
|
||||
value = cleaned_values[i]
|
||||
await ScadaRepository.update_scada_field(
|
||||
timescale_conn,
|
||||
time_dt,
|
||||
device_id,
|
||||
"cleaned_value",
|
||||
value,
|
||||
)
|
||||
|
||||
# 处理flow数据
|
||||
if flow_ids:
|
||||
flow_df = df[flow_ids]
|
||||
# 重置索引,将 time 变为普通列
|
||||
flow_df = flow_df.reset_index()
|
||||
# 移除 time 列,准备输入给清洗方法
|
||||
value_df = flow_df.drop(columns=["time"])
|
||||
# 调用清洗方法
|
||||
cleaned_value_df = clean_flow_data_df_kf(value_df)
|
||||
# 添加 time 列到首列
|
||||
cleaned_df = pd.concat([flow_df["time"], cleaned_value_df], axis=1)
|
||||
# 将清洗后的数据写回数据库
|
||||
for device_id in flow_ids:
|
||||
if device_id in cleaned_df.columns:
|
||||
cleaned_values = cleaned_df[device_id].tolist()
|
||||
time_values = cleaned_df["time"].tolist()
|
||||
for i, time_str in enumerate(time_values):
|
||||
time_dt = datetime.fromisoformat(time_str)
|
||||
value = cleaned_values[i]
|
||||
await ScadaRepository.update_scada_field(
|
||||
timescale_conn,
|
||||
time_dt,
|
||||
device_id,
|
||||
"cleaned_value",
|
||||
value,
|
||||
)
|
||||
|
||||
return "success"
|
||||
except Exception as e:
|
||||
return f"error: {str(e)}"
|
||||
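# 透视示意(假设性示例,非本提交代码):clean_scada_data 先把长表透视成
# "行=time、列=device_id" 的宽表,再分别送入压力/流量清洗函数;示例数据均为虚构。
#
#   import pandas as pd
#   long_demo = pd.DataFrame([
#       {"time": "2025-03-12T00:00:00", "device_id": "S1", "value": 1.0},
#       {"time": "2025-03-12T00:00:00", "device_id": "S2", "value": 2.0},
#       {"time": "2025-03-12T01:00:00", "device_id": "S1", "value": 1.1},
#       {"time": "2025-03-12T01:00:00", "device_id": "S2", "value": 2.2},
#   ])
#   wide_demo = long_demo.pivot(index="time", columns="device_id", values="value")
#   # wide_demo 的列为 ["S1", "S2"],索引为 time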
|
||||
@staticmethod
|
||||
async def predict_pipeline_health(
|
||||
timescale_conn: AsyncConnection,
|
||||
network_name: str,
|
||||
query_time: datetime,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
预测管道健康状况
|
||||
|
||||
根据管网名称和当前时间,查询管道信息和实时数据,
|
||||
使用随机生存森林模型预测管道的生存概率
|
||||
|
||||
Args:
|
||||
timescale_conn: TimescaleDB 异步连接
|
||||
network_name: 管网数据库名称
|
||||
query_time: 查询时间
|
||||
|
||||
Returns:
|
||||
预测结果列表,每个元素包含 link_id 和对应的生存函数
|
||||
|
||||
Raises:
|
||||
ValueError: 当参数无效或数据不足时
|
||||
FileNotFoundError: 当模型文件未找到时
|
||||
"""
|
||||
try:
|
||||
# 1. 准备时间范围(查询时间前后1秒)
|
||||
start_time = query_time - timedelta(seconds=1)
|
||||
end_time = query_time + timedelta(seconds=1)
|
||||
|
||||
# 2. 先查询流速数据(velocity),获取有数据的管道ID列表
|
||||
velocity_data = await RealtimeRepository.get_links_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, "velocity"
|
||||
)
|
||||
|
||||
if not velocity_data:
|
||||
raise ValueError("未找到流速数据")
|
||||
|
||||
# 3. 只查询有流速数据的管道的基本信息
|
||||
valid_link_ids = list(velocity_data.keys())
|
||||
|
||||
# 批量查询这些管道的详细信息
|
||||
fields = ["id", "diameter", "node1", "node2"]
|
||||
all_links = InternalQueries.get_links_by_property(
|
||||
fields=fields,
|
||||
db_name=network_name,
|
||||
)
|
||||
|
||||
# 转换为字典以快速查找
|
||||
links_dict = {link["id"]: link for link in all_links}
|
||||
|
||||
# 获取所有需要查询的节点ID
|
||||
node_ids = set()
|
||||
for link_id in valid_link_ids:
|
||||
if link_id in links_dict:
|
||||
link = links_dict[link_id]
|
||||
node_ids.add(link["node1"])
|
||||
node_ids.add(link["node2"])
|
||||
|
||||
# 4. 批量查询压力数据(pressure)
|
||||
pressure_data = await RealtimeRepository.get_nodes_field_by_time_range(
|
||||
timescale_conn, start_time, end_time, "pressure"
|
||||
)
|
||||
|
||||
# 5. 组合数据结构
|
||||
materials = []
|
||||
diameters = []
|
||||
velocities = []
|
||||
pressures = []
|
||||
link_ids = []
|
||||
|
||||
for link_id in valid_link_ids:
|
||||
# 跳过不在管道字典中的ID(如泵等其他元素)
|
||||
if link_id not in links_dict:
|
||||
continue
|
||||
|
||||
link = links_dict[link_id]
|
||||
diameter = link["diameter"]
|
||||
node1 = link["node1"]
|
||||
node2 = link["node2"]
|
||||
|
||||
# 获取流速数据
|
||||
velocity_values = velocity_data[link_id]
|
||||
velocity = velocity_values[-1]["value"] if velocity_values else 0
|
||||
|
||||
# 获取node1和node2的压力数据,计算平均值
|
||||
node1_pressure = 0
|
||||
node2_pressure = 0
|
||||
|
||||
if node1 in pressure_data and pressure_data[node1]:
|
||||
pressure_values = pressure_data[node1]
|
||||
node1_pressure = (
|
||||
pressure_values[-1]["value"] if pressure_values else 0
|
||||
)
|
||||
|
||||
if node2 in pressure_data and pressure_data[node2]:
|
||||
pressure_values = pressure_data[node2]
|
||||
node2_pressure = (
|
||||
pressure_values[-1]["value"] if pressure_values else 0
|
||||
)
|
||||
|
||||
# 计算平均压力
|
||||
avg_pressure = (node1_pressure + node2_pressure) / 2
|
||||
|
||||
# 添加到列表
|
||||
link_ids.append(link_id)
|
||||
materials.append(7) # 默认材料类型为7,可根据实际情况调整
|
||||
diameters.append(diameter)
|
||||
velocities.append(velocity)
|
||||
pressures.append(avg_pressure)
|
||||
|
||||
if not link_ids:
|
||||
raise ValueError("没有找到有效的管道数据用于预测")
|
||||
|
||||
# 6. 创建DataFrame
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"Material": materials,
|
||||
"Diameter": diameters,
|
||||
"Flow Velocity": velocities,
|
||||
"Pressure": pressures,
|
||||
}
|
||||
)
|
||||
|
||||
# 7. 使用PipelineHealthAnalyzer进行预测
|
||||
analyzer = PipelineHealthAnalyzer(
|
||||
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
|
||||
)
|
||||
survival_functions = analyzer.predict_survival(data)
|
||||
# 8. 组合结果
|
||||
results = []
|
||||
for i, link_id in enumerate(link_ids):
|
||||
sf = survival_functions[i]
|
||||
results.append(
|
||||
{
|
||||
"link_id": link_id,
|
||||
"survival_function": {
|
||||
"x": sf.x.tolist(), # 时间点(年)
|
||||
"y": sf.y.tolist(), # 生存概率
|
||||
},
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"管道健康预测失败: {str(e)}")
|
||||
115
app/infra/db/timescaledb/database.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
import psycopg_pool
|
||||
from psycopg.rows import dict_row
|
||||
import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name=None):
|
||||
self.pool = None
|
||||
self.db_name = db_name
|
||||
|
||||
def init_pool(self, db_name=None):
|
||||
"""Initialize the connection pool."""
|
||||
# Use provided db_name, or the one from constructor, or default from config
|
||||
conn_string = timescaledb_info.get_pgconn_string(db_name=db_name or self.db_name)
|
||||
try:
|
||||
self.pool = psycopg_pool.AsyncConnectionPool(
|
||||
conninfo=conn_string,
|
||||
min_size=5,
|
||||
max_size=20,
|
||||
open=False, # Don't open immediately, wait for startup
|
||||
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
|
||||
)
|
||||
logger.info(
|
||||
f"TimescaleDB connection pool initialized for database: default"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
|
||||
raise
|
||||
|
||||
async def open(self):
|
||||
if self.pool:
|
||||
await self.pool.open()
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection pool."""
|
||||
if self.pool:
|
||||
await self.pool.close()
|
||||
logger.info("TimescaleDB connection pool closed.")
|
||||
|
||||
def get_pgconn_string(self, db_name=None):
|
||||
"""Get the TimescaleDB connection string."""
|
||||
target_db_name = db_name or self.db_name
|
||||
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection(self) -> AsyncGenerator:
|
||||
"""Get a connection from the pool."""
|
||||
if not self.pool:
|
||||
raise Exception("Database pool is not initialized.")
|
||||
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# 默认数据库实例
|
||||
db = Database()
|
||||
|
||||
# 缓存不同数据库的实例 - 避免重复创建连接池
|
||||
_database_instances: Dict[str, Database] = {}
|
||||
|
||||
|
||||
def create_database_instance(db_name):
|
||||
"""Create a new Database instance for a specific database."""
|
||||
return Database(db_name=db_name)
|
||||
|
||||
|
||||
async def get_database_instance(db_name: Optional[str] = None) -> Database:
|
||||
"""Get or create a database instance for the specified database name."""
|
||||
if not db_name:
|
||||
return db # 返回默认数据库实例
|
||||
|
||||
if db_name not in _database_instances:
|
||||
# 创建新的数据库实例
|
||||
instance = create_database_instance(db_name)
|
||||
instance.init_pool()
|
||||
await instance.open()
|
||||
_database_instances[db_name] = instance
|
||||
logger.info(f"Created new database instance for: {db_name}")
|
||||
|
||||
return _database_instances[db_name]
|
||||
|
||||
|
||||
async def get_db_connection():
|
||||
"""Dependency for FastAPI to get a database connection."""
|
||||
async with db.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def get_database_connection(db_name: Optional[str] = None):
|
||||
"""
|
||||
FastAPI dependency to get database connection with optional database name.
|
||||
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
|
||||
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
|
||||
"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
async def cleanup_database_instances():
|
||||
"""Clean up all database instances (call this on application shutdown)."""
|
||||
for db_name, instance in _database_instances.items():
|
||||
await instance.close()
|
||||
logger.info(f"Closed database instance for: {db_name}")
|
||||
_database_instances.clear()
|
||||
|
||||
# 关闭默认数据库
|
||||
await db.close()
|
||||
logger.info("All database instances cleaned up.")
|
||||
122
app/infra/db/timescaledb/internal_queries.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.logger import logger
|
||||
from timescaledb.schemas.scheme import SchemeRepository
|
||||
from timescaledb.schemas.realtime import RealtimeRepository
|
||||
import timescaledb.timescaledb_info as timescaledb_info
|
||||
from datetime import datetime, timedelta
|
||||
from timescaledb.schemas.scada import ScadaRepository
|
||||
import psycopg
|
||||
import time
|
||||
|
||||
|
||||
class InternalStorage:
|
||||
@staticmethod
|
||||
def store_realtime_simulation(
|
||||
node_result_list: List[dict],
|
||||
link_result_list: List[dict],
|
||||
result_start_time: str,
|
||||
db_name: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
"""存储实时模拟结果"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn_string = (
|
||||
timescaledb_info.get_pgconn_string(db_name=db_name)
|
||||
if db_name
|
||||
else timescaledb_info.get_pgconn_string()
|
||||
)
|
||||
with psycopg.Connection.connect(conn_string) as conn:
|
||||
RealtimeRepository.store_realtime_simulation_result_sync(
|
||||
conn, node_result_list, link_result_list, result_start_time
|
||||
)
|
||||
break # 成功
|
||||
except Exception as e:
|
||||
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1) # 重试前等待
|
||||
else:
|
||||
raise # 达到最大重试次数后抛出异常
|
||||
|
||||
@staticmethod
|
||||
def store_scheme_simulation(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
node_result_list: List[dict],
|
||||
link_result_list: List[dict],
|
||||
result_start_time: str,
|
||||
num_periods: int = 1,
|
||||
db_name: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
"""存储方案模拟结果"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn_string = (
|
||||
timescaledb_info.get_pgconn_string(db_name=db_name)
|
||||
if db_name
|
||||
else timescaledb_info.get_pgconn_string()
|
||||
)
|
||||
with psycopg.Connection.connect(conn_string) as conn:
|
||||
SchemeRepository.store_scheme_simulation_result_sync(
|
||||
conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
node_result_list,
|
||||
link_result_list,
|
||||
result_start_time,
|
||||
num_periods,
|
||||
)
|
||||
break # 成功
|
||||
except Exception as e:
|
||||
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1) # 重试前等待
|
||||
else:
|
||||
raise # 达到最大重试次数后抛出异常
|
||||
|
||||
|
||||
class InternalQueries:
|
||||
@staticmethod
|
||||
def query_scada_by_ids_time(
|
||||
device_ids: List[str],
|
||||
query_time: str,
|
||||
db_name: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
) -> dict:
|
||||
"""查询指定时间点的 SCADA 数据"""
|
||||
|
||||
# 解析时间,假设是北京时间
|
||||
beijing_time = datetime.fromisoformat(query_time)
|
||||
start_time = beijing_time - timedelta(seconds=1)
|
||||
end_time = beijing_time + timedelta(seconds=1)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn_string = (
|
||||
timescaledb_info.get_pgconn_string(db_name=db_name)
|
||||
if db_name
|
||||
else timescaledb_info.get_pgconn_string()
|
||||
)
|
||||
with psycopg.Connection.connect(conn_string) as conn:
|
||||
rows = ScadaRepository.get_scada_by_ids_time_range_sync(
|
||||
conn, device_ids, start_time, end_time
|
||||
)
|
||||
# 处理结果,返回每个 device_id 的第一个值
|
||||
result = {}
|
||||
for device_id in device_ids:
|
||||
device_rows = [
|
||||
row for row in rows if row["device_id"] == device_id
|
||||
]
|
||||
if device_rows:
|
||||
result[device_id] = device_rows[0]["monitored_value"]
|
||||
else:
|
||||
result[device_id] = None
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise
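# Usage sketch (device ids, time and database name below are placeholders): returns a dict
# mapping each device id to its first monitored_value in the ±1 s window, or None.
#
# values = InternalQueries.query_scada_by_ids_time(
#     device_ids=["dev-001", "dev-002"],
#     query_time="2024-01-01T08:00:00",
#     db_name="demo_network",
# )
# # values -> {"dev-001": 1.23, "dev-002": None}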
|
||||
627
app/infra/db/timescaledb/router.py
Normal file
@@ -0,0 +1,627 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
from .database import get_database_instance
|
||||
from .schemas.realtime import RealtimeRepository
|
||||
from .schemas.scheme import SchemeRepository
|
||||
from .schemas.scada import ScadaRepository
|
||||
from .composite_queries import CompositeQueries
|
||||
from postgresql.database import get_database_instance as get_postgres_database_instance
|
||||
|
||||
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
|
||||
|
||||
|
||||
# 创建支持数据库选择的连接依赖函数
|
||||
async def get_database_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
):
|
||||
"""获取数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
|
||||
|
||||
|
||||
# PostgreSQL 数据库连接依赖函数
|
||||
async def get_postgres_connection(
|
||||
db_name: Optional[str] = Query(
|
||||
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
|
||||
)
|
||||
):
|
||||
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
|
||||
instance = await get_postgres_database_instance(db_name)
|
||||
async with instance.get_connection() as conn:
|
||||
yield conn
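# Note: every endpoint below accepts an optional ?db_name=... query parameter through these
# dependencies; when omitted, the default pool is used. Example request (illustrative URL):
# GET /timescaledb/realtime/links?start_time=2024-01-01T08:00:00&end_time=2024-01-01T09:00:00&db_name=demo_network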
|
||||
|
||||
|
||||
# --- Realtime Endpoints ---
|
||||
|
||||
|
||||
@router.post("/realtime/links/batch", status_code=201)
|
||||
async def insert_realtime_links(
|
||||
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
|
||||
):
|
||||
await RealtimeRepository.insert_links_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/realtime/links")
|
||||
async def get_realtime_links(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
return await RealtimeRepository.get_links_by_time_range(conn, start_time, end_time)
|
||||
|
||||
|
||||
@router.delete("/realtime/links")
|
||||
async def delete_realtime_links(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
await RealtimeRepository.delete_links_by_time_range(conn, start_time, end_time)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.patch("/realtime/links/{link_id}/field")
|
||||
async def update_realtime_link_field(
|
||||
link_id: str,
|
||||
time: datetime,
|
||||
field: str,
|
||||
value: float, # Assuming float for now, could be Any but FastAPI needs type
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/realtime/nodes/batch", status_code=201)
|
||||
async def insert_realtime_nodes(
|
||||
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
|
||||
):
|
||||
await RealtimeRepository.insert_nodes_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/realtime/nodes")
|
||||
async def get_realtime_nodes(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
return await RealtimeRepository.get_nodes_by_time_range(conn, start_time, end_time)
|
||||
|
||||
|
||||
@router.delete("/realtime/nodes")
|
||||
async def delete_realtime_nodes(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/realtime/simulation/store", status_code=201)
|
||||
async def store_realtime_simulation_result(
|
||||
node_result_list: List[dict],
|
||||
link_result_list: List[dict],
|
||||
result_start_time: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Store realtime simulation results to TimescaleDB"""
|
||||
await RealtimeRepository.store_realtime_simulation_result(
|
||||
conn, node_result_list, link_result_list, result_start_time
|
||||
)
|
||||
return {"message": "Simulation results stored successfully"}
|
||||
|
||||
|
||||
@router.get("/realtime/query/by-time-property")
|
||||
async def query_realtime_records_by_time_property(
|
||||
query_time: str,
|
||||
type: str,
|
||||
property: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Query all realtime records by time and property"""
|
||||
try:
|
||||
results = await RealtimeRepository.query_all_record_by_time_property(
|
||||
conn, query_time, type, property
|
||||
)
|
||||
return {"results": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/realtime/query/by-id-time")
|
||||
async def query_realtime_simulation_by_id_time(
|
||||
id: str,
|
||||
type: str,
|
||||
query_time: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Query realtime simulation results by id and time"""
|
||||
try:
|
||||
results = await RealtimeRepository.query_simulation_result_by_id_time(
|
||||
conn, id, type, query_time
|
||||
)
|
||||
return {"results": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# --- Scheme Endpoints ---
|
||||
|
||||
|
||||
@router.post("/scheme/links/batch", status_code=201)
|
||||
async def insert_scheme_links(
|
||||
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
|
||||
):
|
||||
await SchemeRepository.insert_links_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scheme/links")
|
||||
async def get_scheme_links(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
return await SchemeRepository.get_links_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scheme/links/{link_id}/field")
|
||||
async def get_scheme_link_field(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
link_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
field: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
return await SchemeRepository.get_link_field_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time, link_id, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scheme/links/{link_id}/field")
|
||||
async def update_scheme_link_field(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
link_id: str,
|
||||
time: datetime,
|
||||
field: str,
|
||||
value: float,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
await SchemeRepository.update_link_field(
|
||||
conn, time, scheme_type, scheme_name, link_id, field, value
|
||||
)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scheme/links")
|
||||
async def delete_scheme_links(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
await SchemeRepository.delete_links_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/scheme/nodes/batch", status_code=201)
|
||||
async def insert_scheme_nodes(
|
||||
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
|
||||
):
|
||||
await SchemeRepository.insert_nodes_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scheme/nodes/{node_id}/field")
|
||||
async def get_scheme_node_field(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
node_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
field: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
return await SchemeRepository.get_node_field_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time, node_id, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scheme/nodes/{node_id}/field")
|
||||
async def update_scheme_node_field(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
node_id: str,
|
||||
time: datetime,
|
||||
field: str,
|
||||
value: float,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
await SchemeRepository.update_node_field(
|
||||
conn, time, scheme_type, scheme_name, node_id, field, value
|
||||
)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scheme/nodes")
|
||||
async def delete_scheme_nodes(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
await SchemeRepository.delete_nodes_by_scheme_and_time_range(
|
||||
conn, scheme_type, scheme_name, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/scheme/simulation/store", status_code=201)
|
||||
async def store_scheme_simulation_result(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
node_result_list: List[dict],
|
||||
link_result_list: List[dict],
|
||||
result_start_time: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Store scheme simulation results to TimescaleDB"""
|
||||
await SchemeRepository.store_scheme_simulation_result(
|
||||
conn,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
node_result_list,
|
||||
link_result_list,
|
||||
result_start_time,
|
||||
)
|
||||
return {"message": "Scheme simulation results stored successfully"}
|
||||
|
||||
|
||||
@router.get("/scheme/query/by-scheme-time-property")
|
||||
async def query_scheme_records_by_scheme_time_property(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
query_time: str,
|
||||
type: str,
|
||||
property: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Query all scheme records by scheme, time and property"""
|
||||
try:
|
||||
results = await SchemeRepository.query_all_record_by_scheme_time_property(
|
||||
conn, scheme_type, scheme_name, query_time, type, property
|
||||
)
|
||||
return {"results": results}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/scheme/query/by-id-time")
|
||||
async def query_scheme_simulation_by_id_time(
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
id: str,
|
||||
type: str,
|
||||
query_time: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""Query scheme simulation results by id and time"""
|
||||
try:
|
||||
result = await SchemeRepository.query_scheme_simulation_result_by_id_time(
|
||||
conn, scheme_type, scheme_name, id, type, query_time
|
||||
)
|
||||
return {"result": result}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# --- SCADA Endpoints ---
|
||||
|
||||
|
||||
@router.post("/scada/batch", status_code=201)
|
||||
async def insert_scada_data(
|
||||
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
|
||||
):
|
||||
await ScadaRepository.insert_scada_batch(conn, data)
|
||||
return {"message": f"Inserted {len(data)} records"}
|
||||
|
||||
|
||||
@router.get("/scada/by-ids-time-range")
|
||||
async def get_scada_by_ids_time_range(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
device_ids: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()] if device_ids else []
|
||||
)
|
||||
return await ScadaRepository.get_scada_by_ids_time_range(
|
||||
conn, device_ids_list, start_time, end_time
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scada/by-ids-field-time-range")
|
||||
async def get_scada_field_by_ids_time_range(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
field: str,
|
||||
device_ids: str,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
return await ScadaRepository.get_scada_field_by_id_time_range(
|
||||
conn, device_ids_list, start_time, end_time, field
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/scada/{device_id}/field")
|
||||
async def update_scada_field(
|
||||
device_id: str,
|
||||
time: datetime,
|
||||
field: str,
|
||||
value: float,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
try:
|
||||
await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
|
||||
return {"message": "Updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/scada/by-id-time-range")
|
||||
async def delete_scada_data(
|
||||
device_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
await ScadaRepository.delete_scada_by_id_time_range(
|
||||
conn, device_id, start_time, end_time
|
||||
)
|
||||
return {"message": "Deleted successfully"}
|
||||
|
||||
|
||||
# --- Composite Query Endpoints ---
|
||||
|
||||
|
||||
@router.get("/composite/scada-simulation")
|
||||
async def get_scada_associated_simulation_data(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
device_ids: str,
|
||||
scheme_type: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
|
||||
scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_database_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
获取 SCADA 关联的 link/node 模拟值
|
||||
|
||||
根据传入的 SCADA device_ids,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
"""
|
||||
try:
|
||||
# 手动解析 device_ids 为 List[str],去除空格
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
|
||||
if scheme_type and scheme_name:
|
||||
result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
|
||||
timescale_conn,
|
||||
postgres_conn,
|
||||
device_ids_list, # 使用解析后的列表
|
||||
start_time,
|
||||
end_time,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
await CompositeQueries.get_scada_associated_realtime_simulation_data(
|
||||
timescale_conn,
|
||||
postgres_conn,
|
||||
device_ids_list, # 使用解析后的列表
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="No simulation data found")
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/element-simulation")
|
||||
async def get_feature_simulation_data(
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
feature_infos: str = Query(
|
||||
..., description="特征信息,格式: id1:type1,id2:type2,type为pipe或junction"
|
||||
),
|
||||
scheme_type: str = Query(None, description="指定方案类型,若为空则查询实时数据"),
|
||||
scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
获取 link/node 模拟值
|
||||
|
||||
根据传入的 featureInfos,找到关联的 link/node,
|
||||
并根据对应的 type,查询对应的模拟数据
|
||||
|
||||
Args:
|
||||
feature_infos: 格式为 "element_id1:type1,element_id2:type2"
|
||||
例如: "P1:pipe,J1:junction"
|
||||
"""
|
||||
try:
|
||||
# 解析 feature_infos 为 List[Tuple[str, str]]
|
||||
feature_infos_list = []
|
||||
if feature_infos:
|
||||
for item in feature_infos.split(","):
|
||||
item = item.strip()
|
||||
if ":" in item:
|
||||
element_id, element_type = item.split(":", 1)
|
||||
feature_infos_list.append(
|
||||
(element_id.strip(), element_type.strip())
|
||||
)
|
||||
|
||||
if not feature_infos_list:
|
||||
raise HTTPException(status_code=400, detail="feature_infos cannot be empty")
|
||||
|
||||
if scheme_type and scheme_name:
|
||||
result = await CompositeQueries.get_scheme_simulation_data(
|
||||
timescale_conn,
|
||||
feature_infos_list,
|
||||
start_time,
|
||||
end_time,
|
||||
scheme_type,
|
||||
scheme_name,
|
||||
)
|
||||
else:
|
||||
result = await CompositeQueries.get_realtime_simulation_data(
|
||||
timescale_conn,
|
||||
feature_infos_list,
|
||||
start_time,
|
||||
end_time,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="No simulation data found")
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/element-scada")
|
||||
async def get_element_associated_scada_data(
|
||||
element_id: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
|
||||
timescale_conn: AsyncConnection = Depends(get_database_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
获取 link/node 关联的 SCADA 监测值
|
||||
|
||||
根据传入的 link/node id,匹配 SCADA 信息,
|
||||
如果存在关联的 SCADA device_id,获取实际的监测数据
|
||||
"""
|
||||
try:
|
||||
result = await CompositeQueries.get_element_associated_scada_data(
|
||||
timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No associated SCADA data found"
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/composite/clean-scada")
|
||||
async def clean_scada_data(
|
||||
device_ids: str,
|
||||
start_time: datetime = Query(...),
|
||||
end_time: datetime = Query(...),
|
||||
timescale_conn: AsyncConnection = Depends(get_database_connection),
|
||||
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
|
||||
):
|
||||
"""
|
||||
清洗 SCADA 数据
|
||||
|
||||
根据 device_ids 查询 monitored_value,清洗后更新 cleaned_value
|
||||
"""
|
||||
try:
|
||||
if device_ids == "all":
|
||||
device_ids_list = []
|
||||
else:
|
||||
device_ids_list = (
|
||||
[id.strip() for id in device_ids.split(",") if id.strip()]
|
||||
if device_ids
|
||||
else []
|
||||
)
|
||||
return await CompositeQueries.clean_scada_data(
|
||||
timescale_conn, postgres_conn, device_ids_list, start_time, end_time
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/composite/pipeline-health-prediction")
|
||||
async def predict_pipeline_health(
|
||||
query_time: datetime = Query(..., description="查询时间"),
|
||||
network_name: str = Query(..., description="管网数据库名称"),
|
||||
timescale_conn: AsyncConnection = Depends(get_database_connection),
|
||||
):
|
||||
"""
|
||||
预测管道健康状况
|
||||
|
||||
根据管网名称和当前时间,查询管道信息和实时数据,
|
||||
使用随机生存森林模型预测管道的生存概率
|
||||
|
||||
Args:
|
||||
query_time: 查询时间
|
||||
network_name: 管网数据库名称
|
||||
|
||||
Returns:
|
||||
预测结果列表,每个元素包含 link_id 和对应的生存函数
|
||||
"""
|
||||
try:
|
||||
return await CompositeQueries.predict_pipeline_health(
|
||||
timescale_conn, network_name, query_time
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
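# Example request/response shape (illustrative values):
# GET /timescaledb/composite/pipeline-health-prediction?query_time=2024-01-01T08:00:00&network_name=demo_network
# -> [{"link_id": "P1", "survival_function": {"x": [1, 2, 5, 10], "y": [0.99, 0.97, 0.9, 0.8]}}, ...]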
|
||||
647
app/infra/db/timescaledb/schemas/realtime.py
Normal file
@@ -0,0 +1,647 @@
|
||||
from typing import List, Any, Dict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import defaultdict
|
||||
from psycopg import AsyncConnection, Connection, sql
|
||||
|
||||
# 定义UTC+8时区
|
||||
UTC_8 = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
class RealtimeRepository:
|
||||
|
||||
# --- Link Simulation ---
|
||||
|
||||
@staticmethod
|
||||
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
|
||||
"""Batch insert for realtime.link_simulation using DELETE then COPY for performance."""
|
||||
if not data:
|
||||
return
|
||||
|
||||
# 假设同一批次的数据时间是相同的
|
||||
target_time = data[0]["time"]
|
||||
|
||||
# 使用事务确保原子性
|
||||
async with conn.transaction():
|
||||
async with conn.cursor() as cur:
|
||||
# 1. 先删除该时间点的旧数据
|
||||
await cur.execute(
|
||||
"DELETE FROM realtime.link_simulation WHERE time = %s",
|
||||
(target_time,),
|
||||
)
|
||||
|
||||
# 2. 使用 COPY 快速写入新数据
|
||||
async with cur.copy(
|
||||
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
await copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["id"],
|
||||
item.get("flow"),
|
||||
item.get("friction"),
|
||||
item.get("headloss"),
|
||||
item.get("quality"),
|
||||
item.get("reaction"),
|
||||
item.get("setting"),
|
||||
item.get("status"),
|
||||
item.get("velocity"),
|
||||
)
|
||||
)
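# Example payload for the batch insert above (illustrative; keys match the COPY column list,
# metrics that are not available may simply be omitted):
#
# data = [
#     {"time": datetime(2024, 1, 1, 8, 0, tzinfo=UTC_8), "id": "P1", "flow": 12.3, "velocity": 0.8},
#     {"time": datetime(2024, 1, 1, 8, 0, tzinfo=UTC_8), "id": "P2", "flow": 4.5, "velocity": 0.3},
# ]
# await RealtimeRepository.insert_links_batch(conn, data)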
|
||||
|
||||
@staticmethod
|
||||
def insert_links_batch_sync(conn: Connection, data: List[dict]):
|
||||
"""Batch insert for realtime.link_simulation using DELETE then COPY for performance (sync version)."""
|
||||
if not data:
|
||||
return
|
||||
|
||||
# 假设同一批次的数据时间是相同的
|
||||
target_time = data[0]["time"]
|
||||
|
||||
# 使用事务确保原子性
|
||||
with conn.transaction():
|
||||
with conn.cursor() as cur:
|
||||
# 1. 先删除该时间点的旧数据
|
||||
cur.execute(
|
||||
"DELETE FROM realtime.link_simulation WHERE time = %s",
|
||||
(target_time,),
|
||||
)
|
||||
|
||||
# 2. 使用 COPY 快速写入新数据
|
||||
with cur.copy(
|
||||
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["id"],
|
||||
item.get("flow"),
|
||||
item.get("friction"),
|
||||
item.get("headloss"),
|
||||
item.get("quality"),
|
||||
item.get("reaction"),
|
||||
item.get("setting"),
|
||||
item.get("status"),
|
||||
item.get("velocity"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_link_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
|
||||
) -> List[dict]:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
|
||||
(start_time, end_time, link_id),
|
||||
)
|
||||
return await cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
async def get_links_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime
|
||||
) -> List[dict]:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
|
||||
(start_time, end_time),
|
||||
)
|
||||
return await cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
async def get_link_field_by_time_range(
|
||||
conn: AsyncConnection,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
link_id: str,
|
||||
field: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
# Validate field name to prevent SQL injection
|
||||
valid_fields = {
|
||||
"flow",
|
||||
"friction",
|
||||
"headloss",
|
||||
"quality",
|
||||
"reaction",
|
||||
"setting",
|
||||
"status",
|
||||
"velocity",
|
||||
}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (start_time, end_time, link_id))
|
||||
rows = await cur.fetchall()
|
||||
return [
|
||||
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_links_field_by_time_range(
|
||||
conn: AsyncConnection,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
field: str,
|
||||
) -> dict:
|
||||
# Validate field name to prevent SQL injection
|
||||
valid_fields = {
|
||||
"flow",
|
||||
"friction",
|
||||
"headloss",
|
||||
"quality",
|
||||
"reaction",
|
||||
"setting",
|
||||
"status",
|
||||
"velocity",
|
||||
}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (start_time, end_time))
|
||||
rows = await cur.fetchall()
|
||||
result = defaultdict(list)
|
||||
for row in rows:
|
||||
result[row["id"]].append(
|
||||
{"time": row["time"].isoformat(), "value": row[field]}
|
||||
)
|
||||
return dict(result)
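# Example return value (illustrative): results are grouped by link id:
# {"P1": [{"time": "2024-01-01T08:00:00+08:00", "value": 0.8}],
#  "P2": [{"time": "2024-01-01T08:00:00+08:00", "value": 0.3}]}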
|
||||
|
||||
@staticmethod
|
||||
async def update_link_field(
|
||||
conn: AsyncConnection,
|
||||
time: datetime,
|
||||
link_id: str,
|
||||
field: str,
|
||||
value: Any,
|
||||
):
|
||||
valid_fields = {
|
||||
"flow",
|
||||
"friction",
|
||||
"headloss",
|
||||
"quality",
|
||||
"reaction",
|
||||
"setting",
|
||||
"status",
|
||||
"velocity",
|
||||
}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (value, time, link_id))
|
||||
|
||||
@staticmethod
|
||||
async def delete_links_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime
|
||||
):
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
|
||||
(start_time, end_time),
|
||||
)
|
||||
|
||||
# --- Node Simulation ---
|
||||
|
||||
@staticmethod
|
||||
async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
# 假设同一批次的数据时间是相同的
|
||||
target_time = data[0]["time"]
|
||||
|
||||
# 使用事务确保原子性
|
||||
async with conn.transaction():
|
||||
async with conn.cursor() as cur:
|
||||
# 1. 先删除该时间点的旧数据
|
||||
await cur.execute(
|
||||
"DELETE FROM realtime.node_simulation WHERE time = %s",
|
||||
(target_time,),
|
||||
)
|
||||
|
||||
# 2. 使用 COPY 快速写入新数据
|
||||
async with cur.copy(
|
||||
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
await copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["id"],
|
||||
item.get("actual_demand"),
|
||||
item.get("total_head"),
|
||||
item.get("pressure"),
|
||||
item.get("quality"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
# 假设同一批次的数据时间是相同的
|
||||
target_time = data[0]["time"]
|
||||
|
||||
# 使用事务确保原子性
|
||||
with conn.transaction():
|
||||
with conn.cursor() as cur:
|
||||
# 1. 先删除该时间点的旧数据
|
||||
cur.execute(
|
||||
"DELETE FROM realtime.node_simulation WHERE time = %s",
|
||||
(target_time,),
|
||||
)
|
||||
|
||||
# 2. 使用 COPY 快速写入新数据
|
||||
with cur.copy(
|
||||
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["id"],
|
||||
item.get("actual_demand"),
|
||||
item.get("total_head"),
|
||||
item.get("pressure"),
|
||||
item.get("quality"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_node_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
|
||||
) -> List[dict]:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
|
||||
(start_time, end_time, node_id),
|
||||
)
|
||||
return await cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
async def get_nodes_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime
|
||||
) -> List[dict]:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
|
||||
(start_time, end_time),
|
||||
)
|
||||
return await cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
async def get_node_field_by_time_range(
|
||||
conn: AsyncConnection,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
node_id: str,
|
||||
field: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (start_time, end_time, node_id))
|
||||
rows = await cur.fetchall()
|
||||
return [
|
||||
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
async def get_nodes_field_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
|
||||
) -> dict:
|
||||
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (start_time, end_time))
|
||||
rows = await cur.fetchall()
|
||||
result = defaultdict(list)
|
||||
for row in rows:
|
||||
result[row["id"]].append(
|
||||
{"time": row["time"].isoformat(), "value": row[field]}
|
||||
)
|
||||
return dict(result)
|
||||
|
||||
@staticmethod
|
||||
async def update_node_field(
|
||||
conn: AsyncConnection,
|
||||
time: datetime,
|
||||
node_id: str,
|
||||
field: str,
|
||||
value: Any,
|
||||
):
|
||||
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (value, time, node_id))
|
||||
|
||||
@staticmethod
|
||||
async def delete_nodes_by_time_range(
|
||||
conn: AsyncConnection, start_time: datetime, end_time: datetime
|
||||
):
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
|
||||
(start_time, end_time),
|
||||
)
|
||||
|
||||
# --- 复合查询 ---
|
||||
|
||||
@staticmethod
|
||||
async def store_realtime_simulation_result(
|
||||
conn: AsyncConnection,
|
||||
node_result_list: List[Dict[str, Any]],
|
||||
link_result_list: List[Dict[str, Any]],
|
||||
result_start_time: str,
|
||||
):
|
||||
"""
|
||||
Store realtime simulation results to TimescaleDB.
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
node_result_list: List of node simulation results
|
||||
link_result_list: List of link simulation results
|
||||
result_start_time: Start time for the results (ISO format string)
|
||||
"""
|
||||
# Convert result_start_time string to datetime if needed
|
||||
if isinstance(result_start_time, str):
|
||||
# 如果是ISO格式字符串,解析并转换为UTC+8
|
||||
if result_start_time.endswith("Z"):
|
||||
# UTC时间,转换为UTC+8
|
||||
utc_time = datetime.fromisoformat(
|
||||
result_start_time.replace("Z", "+00:00")
|
||||
)
|
||||
simulation_time = utc_time.astimezone(UTC_8)
|
||||
else:
|
||||
# 假设已经是UTC+8时间
|
||||
simulation_time = datetime.fromisoformat(result_start_time)
|
||||
if simulation_time.tzinfo is None:
|
||||
simulation_time = simulation_time.replace(tzinfo=UTC_8)
|
||||
else:
|
||||
simulation_time = result_start_time
|
||||
if simulation_time.tzinfo is None:
|
||||
simulation_time = simulation_time.replace(tzinfo=UTC_8)
|
||||
|
||||
# Prepare node data for batch insert
|
||||
node_data = []
|
||||
for node_result in node_result_list:
|
||||
node_id = node_result.get("node")
|
||||
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
|
||||
node_data.append(
|
||||
{
|
||||
"time": simulation_time,
|
||||
"id": node_id,
|
||||
"actual_demand": data.get("demand"),
|
||||
"total_head": data.get("head"),
|
||||
"pressure": data.get("pressure"),
|
||||
"quality": data.get("quality"),
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare link data for batch insert
|
||||
link_data = []
|
||||
for link_result in link_result_list:
|
||||
link_id = link_result.get("link")
|
||||
data = link_result.get("result", [])[0]
|
||||
link_data.append(
|
||||
{
|
||||
"time": simulation_time,
|
||||
"id": link_id,
|
||||
"flow": data.get("flow"),
|
||||
"friction": data.get("friction"),
|
||||
"headloss": data.get("headloss"),
|
||||
"quality": data.get("quality"),
|
||||
"reaction": data.get("reaction"),
|
||||
"setting": data.get("setting"),
|
||||
"status": data.get("status"),
|
||||
"velocity": data.get("velocity"),
|
||||
}
|
||||
)
|
||||
|
||||
# Insert data using batch methods
|
||||
if node_data:
|
||||
await RealtimeRepository.insert_nodes_batch(conn, node_data)
|
||||
|
||||
if link_data:
|
||||
await RealtimeRepository.insert_links_batch(conn, link_data)
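# Example inputs (illustrative): realtime simulation carries exactly one period per element,
# hence result[0] above.
#
# node_result_list = [
#     {"node": "J1", "result": [{"demand": 1.2, "head": 35.0, "pressure": 28.4, "quality": 0.5}]},
# ]
# link_result_list = [
#     {"link": "P1", "result": [{"flow": 12.3, "velocity": 0.8, "status": 1}]},
# ]
# await RealtimeRepository.store_realtime_simulation_result(
#     conn, node_result_list, link_result_list, "2024-01-01T00:00:00Z"
# )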
|
||||
|
||||
@staticmethod
|
||||
def store_realtime_simulation_result_sync(
|
||||
conn: Connection,
|
||||
node_result_list: List[Dict[str, Any]],
|
||||
link_result_list: List[Dict[str, Any]],
|
||||
result_start_time: str,
|
||||
):
|
||||
"""
|
||||
Store realtime simulation results to TimescaleDB (sync version).
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
node_result_list: List of node simulation results
|
||||
link_result_list: List of link simulation results
|
||||
result_start_time: Start time for the results (ISO format string)
|
||||
"""
|
||||
# Convert result_start_time string to datetime if needed
|
||||
if isinstance(result_start_time, str):
|
||||
# 如果是ISO格式字符串,解析并转换为UTC+8
|
||||
if result_start_time.endswith("Z"):
|
||||
# UTC时间,转换为UTC+8
|
||||
utc_time = datetime.fromisoformat(
|
||||
result_start_time.replace("Z", "+00:00")
|
||||
)
|
||||
simulation_time = utc_time.astimezone(UTC_8)
|
||||
else:
|
||||
# 假设已经是UTC+8时间
|
||||
simulation_time = datetime.fromisoformat(result_start_time)
|
||||
if simulation_time.tzinfo is None:
|
||||
simulation_time = simulation_time.replace(tzinfo=UTC_8)
|
||||
else:
|
||||
simulation_time = result_start_time
|
||||
if simulation_time.tzinfo is None:
|
||||
simulation_time = simulation_time.replace(tzinfo=UTC_8)
|
||||
|
||||
# Prepare node data for batch insert
|
||||
node_data = []
|
||||
for node_result in node_result_list:
|
||||
node_id = node_result.get("node")
|
||||
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
|
||||
node_data.append(
|
||||
{
|
||||
"time": simulation_time,
|
||||
"id": node_id,
|
||||
"actual_demand": data.get("demand"),
|
||||
"total_head": data.get("head"),
|
||||
"pressure": data.get("pressure"),
|
||||
"quality": data.get("quality"),
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare link data for batch insert
|
||||
link_data = []
|
||||
for link_result in link_result_list:
|
||||
link_id = link_result.get("link")
|
||||
data = link_result.get("result", [])[0]
|
||||
link_data.append(
|
||||
{
|
||||
"time": simulation_time,
|
||||
"id": link_id,
|
||||
"flow": data.get("flow"),
|
||||
"friction": data.get("friction"),
|
||||
"headloss": data.get("headloss"),
|
||||
"quality": data.get("quality"),
|
||||
"reaction": data.get("reaction"),
|
||||
"setting": data.get("setting"),
|
||||
"status": data.get("status"),
|
||||
"velocity": data.get("velocity"),
|
||||
}
|
||||
)
|
||||
|
||||
# Insert data using batch methods
|
||||
if node_data:
|
||||
RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
|
||||
|
||||
if link_data:
|
||||
RealtimeRepository.insert_links_batch_sync(conn, link_data)
|
||||
|
||||
@staticmethod
|
||||
async def query_all_record_by_time_property(
|
||||
conn: AsyncConnection,
|
||||
query_time: str,
|
||||
type: str,
|
||||
property: str,
|
||||
) -> list:
|
||||
"""
|
||||
Query all records by time and property from TimescaleDB.
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
query_time: Time to query (ISO format string)
|
||||
type: Type of data ("node" or "link")
|
||||
property: Property/field to query
|
||||
|
||||
Returns:
|
||||
List of records matching the criteria
|
||||
"""
|
||||
# Convert query_time string to datetime
|
||||
if isinstance(query_time, str):
|
||||
if query_time.endswith("Z"):
|
||||
# UTC时间,转换为UTC+8
|
||||
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
|
||||
target_time = utc_time.astimezone(UTC_8)
|
||||
else:
|
||||
# 假设已经是UTC+8时间
|
||||
target_time = datetime.fromisoformat(query_time)
|
||||
if target_time.tzinfo is None:
|
||||
target_time = target_time.replace(tzinfo=UTC_8)
|
||||
else:
|
||||
target_time = query_time
|
||||
if target_time.tzinfo is None:
|
||||
target_time = target_time.replace(tzinfo=UTC_8)
|
||||
|
||||
# Create time range: query_time ± 1 second
|
||||
start_time = target_time - timedelta(seconds=1)
|
||||
end_time = target_time + timedelta(seconds=1)
|
||||
|
||||
# Query based on type
|
||||
if type.lower() == "node":
|
||||
data = await RealtimeRepository.get_nodes_field_by_time_range(
|
||||
conn, start_time, end_time, property
|
||||
)
|
||||
elif type.lower() == "link":
|
||||
data = await RealtimeRepository.get_links_field_by_time_range(
|
||||
conn, start_time, end_time, property
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
|
||||
|
||||
# Format the results
|
||||
result = []
|
||||
for id, items in data.items():
|
||||
for item in items:
|
||||
result.append({"ID": id, "value": item["value"]})
|
||||
return result
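# Example return value (illustrative):
# [{"ID": "J1", "value": 28.4}, {"ID": "J2", "value": 27.9}]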
|
||||
|
||||
@staticmethod
|
||||
async def query_simulation_result_by_id_time(
|
||||
conn: AsyncConnection,
|
||||
id: str,
|
||||
type: str,
|
||||
query_time: str,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Query simulation results by id and time from TimescaleDB.
|
||||
|
||||
Args:
|
||||
conn: Database connection
|
||||
id: The id of the node or link
|
||||
type: Type of data ("node" or "link")
|
||||
query_time: Time to query (ISO format string)
|
||||
|
||||
Returns:
|
||||
List of records matching the criteria
|
||||
"""
|
||||
# Convert query_time string to datetime
|
||||
if isinstance(query_time, str):
|
||||
if query_time.endswith("Z"):
|
||||
# UTC时间,转换为UTC+8
|
||||
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
|
||||
target_time = utc_time.astimezone(UTC_8)
|
||||
else:
|
||||
# 假设已经是UTC+8时间
|
||||
target_time = datetime.fromisoformat(query_time)
|
||||
if target_time.tzinfo is None:
|
||||
target_time = target_time.replace(tzinfo=UTC_8)
|
||||
else:
|
||||
target_time = query_time
|
||||
if target_time.tzinfo is None:
|
||||
target_time = target_time.replace(tzinfo=UTC_8)
|
||||
|
||||
# Create time range: query_time ± 1 second
|
||||
start_time = target_time - timedelta(seconds=1)
|
||||
end_time = target_time + timedelta(seconds=1)
|
||||
|
||||
# Query based on type
|
||||
if type.lower() == "node":
|
||||
return await RealtimeRepository.get_node_by_time_range(
|
||||
conn, start_time, end_time, id
|
||||
)
|
||||
elif type.lower() == "link":
|
||||
return await RealtimeRepository.get_link_by_time_range(
|
||||
conn, start_time, end_time, id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
|
||||
106
app/infra/db/timescaledb/schemas/scada.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from typing import List, Any
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from psycopg import AsyncConnection, Connection, sql
|
||||
|
||||
|
||||
class ScadaRepository:
|
||||
|
||||
@staticmethod
|
||||
async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
|
||||
if not data:
|
||||
return
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
async with cur.copy(
|
||||
"COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
await copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["device_id"],
|
||||
item.get("monitored_value"),
|
||||
item.get("cleaned_value"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_scada_by_ids_time_range(
|
||||
conn: AsyncConnection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> List[dict]:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
|
||||
(device_ids, start_time, end_time),
|
||||
)
|
||||
return await cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
def get_scada_by_ids_time_range_sync(
|
||||
conn: Connection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> List[dict]:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
|
||||
(device_ids, start_time, end_time),
|
||||
)
|
||||
return cur.fetchall()
|
||||
|
||||
@staticmethod
|
||||
async def get_scada_field_by_id_time_range(
|
||||
conn: AsyncConnection,
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
field: str,
|
||||
) -> dict:
|
||||
valid_fields = {"monitored_value", "cleaned_value"}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"SELECT device_id, time, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (start_time, end_time, device_ids))
|
||||
rows = await cur.fetchall()
|
||||
result = defaultdict(list)
|
||||
for row in rows:
|
||||
result[row["device_id"]].append({
|
||||
"time": row["time"].isoformat(),
|
||||
"value": row[field]
|
||||
})
|
||||
return dict(result)
|
||||
|
||||
@staticmethod
|
||||
async def update_scada_field(
|
||||
conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
|
||||
):
|
||||
valid_fields = {"monitored_value", "cleaned_value"}
|
||||
if field not in valid_fields:
|
||||
raise ValueError(f"Invalid field: {field}")
|
||||
|
||||
query = sql.SQL(
|
||||
"UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
|
||||
).format(sql.Identifier(field))
|
||||
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, (value, time, device_id))
|
||||
|
||||
@staticmethod
|
||||
async def delete_scada_by_id_time_range(
|
||||
conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
|
||||
):
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(
|
||||
"DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
|
||||
(device_id, start_time, end_time),
|
||||
)
|
||||
760
app/infra/db/timescaledb/schemas/scheme.py
Normal file
@@ -0,0 +1,760 @@
|
||||
from typing import List, Any, Dict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from collections import defaultdict
|
||||
from psycopg import AsyncConnection, Connection, sql
|
||||
import globals
|
||||
|
||||
# 定义UTC+8时区
|
||||
UTC_8 = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
class SchemeRepository:
|
||||
|
||||
# --- Link Simulation ---
|
||||
|
||||
@staticmethod
|
||||
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
|
||||
"""Batch insert for scheme.link_simulation using DELETE then COPY for performance."""
|
||||
if not data:
|
||||
return
|
||||
|
||||
# 获取批次中所有不同的时间点
|
||||
all_times = list(set(item["time"] for item in data))
|
||||
target_scheme_type = data[0]["scheme_type"]
|
||||
target_scheme_name = data[0]["scheme_name"]
|
||||
|
||||
# 使用事务确保原子性
|
||||
async with conn.transaction():
|
||||
async with conn.cursor() as cur:
|
||||
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
|
||||
await cur.execute(
|
||||
"DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
|
||||
(all_times, target_scheme_type, target_scheme_name),
|
||||
)
|
||||
|
||||
# 2. 使用 COPY 快速写入新数据
|
||||
async with cur.copy(
|
||||
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
|
||||
) as copy:
|
||||
for item in data:
|
||||
await copy.write_row(
|
||||
(
|
||||
item["time"],
|
||||
item["scheme_type"],
|
||||
item["scheme_name"],
|
||||
item["id"],
|
||||
item.get("flow"),
|
||||
item.get("friction"),
|
||||
item.get("headloss"),
|
||||
item.get("quality"),
|
||||
item.get("reaction"),
|
||||
item.get("setting"),
|
||||
item.get("status"),
|
||||
item.get("velocity"),
|
||||
)
|
||||
)
|
||||
|
||||

    @staticmethod
    def insert_links_batch_sync(conn: Connection, data: List[dict]):
        """Batch insert for scheme.link_simulation using DELETE then COPY for performance (sync version)."""
        if not data:
            return

        # Collect all distinct timestamps in this batch
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]

        # Use a transaction to keep the delete + copy atomic
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Delete existing rows for every timestamp, scheme_type and scheme_name in this batch
                cur.execute(
                    "DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )

                # 2. Bulk-load the new rows with COPY
                with cur.copy(
                    "COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("flow"),
                                item.get("friction"),
                                item.get("headloss"),
                                item.get("quality"),
                                item.get("reaction"),
                                item.get("setting"),
                                item.get("status"),
                                item.get("velocity"),
                            )
                        )
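
    # Illustrative sketch (not part of this module): how a caller might feed the
    # DELETE-then-COPY batch writer. The producer of these dicts and the sample
    # names/values ("daily", "baseline", "P-101") are assumptions; only the keys
    # below are read by the method, and missing optional columns become NULL.
    #
    #   import psycopg
    #   from app.infra.db.timescaledb.timescaledb_info import get_pgconn_string
    #
    #   rows = [
    #       {
    #           "time": datetime(2024, 1, 1, 8, 0, tzinfo=UTC_8),
    #           "scheme_type": "daily",
    #           "scheme_name": "baseline",
    #           "id": "P-101",
    #           "flow": 12.3,
    #           "velocity": 0.8,
    #       },
    #   ]
    #   with psycopg.connect(get_pgconn_string()) as conn:
    #       SchemeRepository.insert_links_batch_sync(conn, rows)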

    @staticmethod
    async def get_link_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
    ) -> List[dict]:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, link_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_link_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        link_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        # Validate field name to prevent SQL injection
        valid_fields = {
            "flow",
            "friction",
            "headloss",
            "quality",
            "reaction",
            "setting",
            "status",
            "velocity",
        }
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, link_id)
            )
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_links_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        # Validate field name to prevent SQL injection
        valid_fields = {
            "flow",
            "friction",
            "headloss",
            "quality",
            "reaction",
            "setting",
            "status",
            "velocity",
        }
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)
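
    # Note (an observation, not part of the diff): these readers index rows by
    # column name (row["time"], row["id"]), so the connection is assumed to use a
    # mapping row factory. A minimal sketch of such a connection with psycopg 3,
    # with hypothetical scheme/link names:
    #
    #   import psycopg
    #   from psycopg.rows import dict_row
    #
    #   conn = await psycopg.AsyncConnection.connect(conninfo, row_factory=dict_row)
    #   data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
    #       conn, "daily", "baseline", start, end, "flow"
    #   )
    #   # -> {"P-101": [{"time": "2024-01-01T08:00:00+08:00", "value": 12.3}, ...]}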

    @staticmethod
    async def update_link_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        link_id: str,
        field: str,
        value: Any,
    ):
        valid_fields = {
            "flow",
            "friction",
            "headloss",
            "quality",
            "reaction",
            "setting",
            "status",
            "velocity",
        }
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))

    @staticmethod
    async def delete_links_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Node Simulation ---

    @staticmethod
    async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
        if not data:
            return

        # Collect all distinct timestamps in this batch
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]

        # Use a transaction to keep the delete + copy atomic
        async with conn.transaction():
            async with conn.cursor() as cur:
                # 1. Delete existing rows for every timestamp, scheme_type and scheme_name in this batch
                await cur.execute(
                    "DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )

                # 2. Bulk-load the new rows with COPY
                async with cur.copy(
                    "COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        await copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("actual_demand"),
                                item.get("total_head"),
                                item.get("pressure"),
                                item.get("quality"),
                            )
                        )

    @staticmethod
    def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
        if not data:
            return

        # Collect all distinct timestamps in this batch
        all_times = list(set(item["time"] for item in data))
        target_scheme_type = data[0]["scheme_type"]
        target_scheme_name = data[0]["scheme_name"]

        # Use a transaction to keep the delete + copy atomic
        with conn.transaction():
            with conn.cursor() as cur:
                # 1. Delete existing rows for every timestamp, scheme_type and scheme_name in this batch
                cur.execute(
                    "DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
                    (all_times, target_scheme_type, target_scheme_name),
                )

                # 2. Bulk-load the new rows with COPY
                with cur.copy(
                    "COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
                ) as copy:
                    for item in data:
                        copy.write_row(
                            (
                                item["time"],
                                item["scheme_type"],
                                item["scheme_name"],
                                item["id"],
                                item.get("actual_demand"),
                                item.get("total_head"),
                                item.get("pressure"),
                                item.get("quality"),
                            )
                        )

    @staticmethod
    async def get_node_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
    ) -> List[dict]:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
                (scheme_type, scheme_name, start_time, end_time, node_id),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        async with conn.cursor() as cur:
            await cur.execute(
                "SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    async def get_node_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        node_id: str,
        field: str,
    ) -> List[Dict[str, Any]]:
        # Validate field name to prevent SQL injection
        valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(
                query, (scheme_type, scheme_name, start_time, end_time, node_id)
            )
            rows = await cur.fetchall()
            return [
                {"time": row["time"].isoformat(), "value": row[field]} for row in rows
            ]

    @staticmethod
    async def get_nodes_field_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        # Validate field name to prevent SQL injection
        valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
            rows = await cur.fetchall()
            result = defaultdict(list)
            for row in rows:
                result[row["id"]].append(
                    {"time": row["time"].isoformat(), "value": row[field]}
                )
            return dict(result)

    @staticmethod
    async def update_node_field(
        conn: AsyncConnection,
        time: datetime,
        scheme_type: str,
        scheme_name: str,
        node_id: str,
        field: str,
        value: Any,
    ):
        valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
        if field not in valid_fields:
            raise ValueError(f"Invalid field: {field}")

        query = sql.SQL(
            "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
        ).format(sql.Identifier(field))

        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))

    @staticmethod
    async def delete_nodes_by_scheme_and_time_range(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        start_time: datetime,
        end_time: datetime,
    ):
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
                (scheme_type, scheme_name, start_time, end_time),
            )

    # --- Composite Queries ---

    @staticmethod
    async def store_scheme_simulation_result(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
            num_periods: Number of simulation periods to store
        """
        # Convert result_start_time string to datetime if needed
        if isinstance(result_start_time, str):
            # Parse the ISO-format string and convert it to UTC+8
            if result_start_time.endswith("Z"):
                # UTC time: convert to UTC+8
                utc_time = datetime.fromisoformat(
                    result_start_time.replace("Z", "+00:00")
                )
                simulation_time = utc_time.astimezone(UTC_8)
            else:
                # Assume the value is already in UTC+8
                simulation_time = datetime.fromisoformat(result_start_time)
                if simulation_time.tzinfo is None:
                    simulation_time = simulation_time.replace(tzinfo=UTC_8)
        else:
            simulation_time = result_start_time
            if simulation_time.tzinfo is None:
                simulation_time = simulation_time.replace(tzinfo=UTC_8)

        timestep_parts = globals.hydraulic_timestep.split(":")
        timestep = timedelta(
            hours=int(timestep_parts[0]),
            minutes=int(timestep_parts[1]),
            seconds=int(timestep_parts[2]),
        )

        # Prepare node data for batch insert
        node_data = []
        for node_result in node_result_list:
            node_id = node_result.get("node")
            for period_index in range(num_periods):
                current_time = simulation_time + (timestep * period_index)
                data = node_result.get("result", [])[period_index]
                node_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": node_id,
                        "actual_demand": data.get("demand"),
                        "total_head": data.get("head"),
                        "pressure": data.get("pressure"),
                        "quality": data.get("quality"),
                    }
                )

        # Prepare link data for batch insert
        link_data = []
        for link_result in link_result_list:
            link_id = link_result.get("link")
            for period_index in range(num_periods):
                current_time = simulation_time + (timestep * period_index)
                data = link_result.get("result", [])[period_index]
                link_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": link_id,
                        "flow": data.get("flow"),
                        "friction": data.get("friction"),
                        "headloss": data.get("headloss"),
                        "quality": data.get("quality"),
                        "reaction": data.get("reaction"),
                        "setting": data.get("setting"),
                        "status": data.get("status"),
                        "velocity": data.get("velocity"),
                    }
                )

        # Insert data using batch methods
        if node_data:
            await SchemeRepository.insert_nodes_batch(conn, node_data)

        if link_data:
            await SchemeRepository.insert_links_batch(conn, link_data)
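
    # Illustrative sketch (not part of this module): the shape of the result
    # lists that store_scheme_simulation_result reads. Only the keys accessed
    # above are shown; the simulation code that produces them, and the sample
    # names/values, are assumptions.
    #
    #   node_result_list = [
    #       {"node": "N-1", "result": [{"demand": 1.2, "head": 35.0, "pressure": 22.5, "quality": 0.4}]},
    #   ]
    #   link_result_list = [
    #       {"link": "P-101", "result": [{"flow": 12.3, "velocity": 0.8, "status": 1}]},
    #   ]
    #   await SchemeRepository.store_scheme_simulation_result(
    #       conn, "daily", "baseline", node_result_list, link_result_list,
    #       "2024-01-01T00:00:00Z", num_periods=1,
    #   )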

    @staticmethod
    def store_scheme_simulation_result_sync(
        conn: Connection,
        scheme_type: str,
        scheme_name: str,
        node_result_list: List[Dict[str, Any]],
        link_result_list: List[Dict[str, Any]],
        result_start_time: str,
        num_periods: int = 1,
    ):
        """
        Store scheme simulation results to TimescaleDB (sync version).

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            node_result_list: List of node simulation results
            link_result_list: List of link simulation results
            result_start_time: Start time for the results (ISO format string)
            num_periods: Number of simulation periods to store
        """
        # Convert result_start_time string to datetime if needed
        if isinstance(result_start_time, str):
            # Parse the ISO-format string and convert it to UTC+8
            if result_start_time.endswith("Z"):
                # UTC time: convert to UTC+8
                utc_time = datetime.fromisoformat(
                    result_start_time.replace("Z", "+00:00")
                )
                simulation_time = utc_time.astimezone(UTC_8)
            else:
                # Assume the value is already in UTC+8
                simulation_time = datetime.fromisoformat(result_start_time)
                if simulation_time.tzinfo is None:
                    simulation_time = simulation_time.replace(tzinfo=UTC_8)
        else:
            simulation_time = result_start_time
            if simulation_time.tzinfo is None:
                simulation_time = simulation_time.replace(tzinfo=UTC_8)

        timestep_parts = globals.hydraulic_timestep.split(":")
        timestep = timedelta(
            hours=int(timestep_parts[0]),
            minutes=int(timestep_parts[1]),
            seconds=int(timestep_parts[2]),
        )

        # Prepare node data for batch insert
        node_data = []
        for node_result in node_result_list:
            node_id = node_result.get("node")
            for period_index in range(num_periods):
                current_time = simulation_time + (timestep * period_index)
                data = node_result.get("result", [])[period_index]
                node_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": node_id,
                        "actual_demand": data.get("demand"),
                        "total_head": data.get("head"),
                        "pressure": data.get("pressure"),
                        "quality": data.get("quality"),
                    }
                )

        # Prepare link data for batch insert
        link_data = []
        for link_result in link_result_list:
            link_id = link_result.get("link")
            for period_index in range(num_periods):
                current_time = simulation_time + (timestep * period_index)
                data = link_result.get("result", [])[period_index]
                link_data.append(
                    {
                        "time": current_time,
                        "scheme_type": scheme_type,
                        "scheme_name": scheme_name,
                        "id": link_id,
                        "flow": data.get("flow"),
                        "friction": data.get("friction"),
                        "headloss": data.get("headloss"),
                        "quality": data.get("quality"),
                        "reaction": data.get("reaction"),
                        "setting": data.get("setting"),
                        "status": data.get("status"),
                        "velocity": data.get("velocity"),
                    }
                )

        # Insert data using batch methods
        if node_data:
            SchemeRepository.insert_nodes_batch_sync(conn, node_data)

        if link_data:
            SchemeRepository.insert_links_batch_sync(conn, link_data)

    @staticmethod
    async def query_all_record_by_scheme_time_property(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        query_time: str,
        type: str,
        property: str,
    ) -> list:
        """
        Query all records by scheme, time and property from TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            query_time: Time to query (ISO format string)
            type: Type of data ("node" or "link")
            property: Property/field to query

        Returns:
            List of records matching the criteria
        """
        # Convert query_time string to datetime
        if isinstance(query_time, str):
            if query_time.endswith("Z"):
                # UTC time: convert to UTC+8
                utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
                target_time = utc_time.astimezone(UTC_8)
            else:
                # Assume the value is already in UTC+8
                target_time = datetime.fromisoformat(query_time)
                if target_time.tzinfo is None:
                    target_time = target_time.replace(tzinfo=UTC_8)
        else:
            target_time = query_time
            if target_time.tzinfo is None:
                target_time = target_time.replace(tzinfo=UTC_8)

        # Create time range: query_time ± 1 second
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)

        # Query based on type
        if type.lower() == "node":
            data = await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        elif type.lower() == "link":
            data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, property
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")

        # Format the results
        result = []
        for id, items in data.items():
            for item in items:
                result.append({"ID": id, "value": item["value"]})
        return result
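
    # Illustrative sketch (not part of this module): snapshotting one property
    # for every node at a single timestamp; the scheme names and values are
    # assumptions.
    #
    #   records = await SchemeRepository.query_all_record_by_scheme_time_property(
    #       conn, "daily", "baseline", "2024-01-01T00:00:00Z", "node", "pressure"
    #   )
    #   # -> [{"ID": "N-1", "value": 22.5}, {"ID": "N-2", "value": 21.8}, ...]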

    @staticmethod
    async def query_scheme_simulation_result_by_id_time(
        conn: AsyncConnection,
        scheme_type: str,
        scheme_name: str,
        id: str,
        type: str,
        query_time: str,
    ) -> list[dict]:
        """
        Query scheme simulation results by id and time from TimescaleDB.

        Args:
            conn: Database connection
            scheme_type: Scheme type
            scheme_name: Scheme name
            id: The id of the node or link
            type: Type of data ("node" or "link")
            query_time: Time to query (ISO format string)

        Returns:
            List of records matching the criteria
        """
        # Convert query_time string to datetime
        if isinstance(query_time, str):
            if query_time.endswith("Z"):
                # UTC time: convert to UTC+8
                utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
                target_time = utc_time.astimezone(UTC_8)
            else:
                # Assume the value is already in UTC+8
                target_time = datetime.fromisoformat(query_time)
                if target_time.tzinfo is None:
                    target_time = target_time.replace(tzinfo=UTC_8)
        else:
            target_time = query_time
            if target_time.tzinfo is None:
                target_time = target_time.replace(tzinfo=UTC_8)

        # Create time range: query_time ± 1 second
        start_time = target_time - timedelta(seconds=1)
        end_time = target_time + timedelta(seconds=1)

        # Query based on type
        if type.lower() == "node":
            return await SchemeRepository.get_node_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        elif type.lower() == "link":
            return await SchemeRepository.get_link_by_scheme_and_time_range(
                conn, scheme_type, scheme_name, start_time, end_time, id
            )
        else:
            raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
36
app/infra/db/timescaledb/timescaledb_info.py
Normal file
@@ -0,0 +1,36 @@
from dotenv import load_dotenv
import os

load_dotenv()

pg_name = os.getenv("TIMESCALEDB_DB_NAME")
pg_host = os.getenv("TIMESCALEDB_DB_HOST")
pg_port = os.getenv("TIMESCALEDB_DB_PORT")
pg_user = os.getenv("TIMESCALEDB_DB_USER")
pg_password = os.getenv("TIMESCALEDB_DB_PASSWORD")


def get_pgconn_string(
    db_name=pg_name,
    db_host=pg_host,
    db_port=pg_port,
    db_user=pg_user,
    db_password=pg_password,
):
    """Return the PostgreSQL connection string."""
    return f"dbname={db_name} host={db_host} port={db_port} user={db_user} password={db_password}"


def get_pg_config():
    """Return a dict of the PostgreSQL configuration variables."""
    return {
        "name": pg_name,
        "host": pg_host,
        "port": pg_port,
        "user": pg_user,
    }


def get_pg_password():
    """Return the password (use with care)."""
    return pg_password
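A minimal usage sketch (illustrative, not part of the commit), assuming the TIMESCALEDB_* variables are set in the environment or a .env file, showing how the connection string feeds psycopg:

    import psycopg
    from app.infra.db.timescaledb.timescaledb_info import get_pgconn_string

    # Open a short-lived connection and verify the database is reachable
    with psycopg.connect(get_pgconn_string()) as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            print(cur.fetchone())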
0
app/infra/repositories/__init__.py
Normal file
4244
app/main.py
Normal file
File diff suppressed because it is too large
0
app/native/__init__.py
Normal file
BIN
app/native/api/__init__.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/__init__.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/batch_api.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/batch_api.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/batch_api_cs.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/batch_api_cs.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/batch_exe.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/batch_exe.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/clean_api.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/clean_api.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/connection.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/connection.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/database.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/database.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/extension_data.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/extension_data.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/inp_in.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/inp_in.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/inp_out.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/inp_out.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/postgresql_info.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/postgresql_info.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/project.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/project.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s0_base.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s0_base.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s10_status.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s10_status.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s11_patterns.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s11_patterns.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s12_curves.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s12_curves.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s13_controls.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s13_controls.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s14_rules.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s14_rules.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s15_energy.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s15_energy.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
BIN
app/native/api/s16_emitters.cp312-win_amd64.pyd
Normal file
Binary file not shown.
BIN
app/native/api/s16_emitters.cpython-312-x86_64-linux-gnu.so
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff