新增tests;新增管道健康预测方法;更新复合查询函数输出格式
This commit is contained in:
3
api_ex/__init__.py
Normal file
3
api_ex/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .Fdataclean import *
|
||||||
|
from .Pdataclean import *
|
||||||
|
from .pipeline_health_analyzer import *
|
||||||
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
Binary file not shown.
140
api_ex/pipeline_health_analyzer.py
Normal file
140
api_ex/pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import os
|
||||||
|
import joblib
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sksurv.ensemble import RandomSurvivalForest
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineHealthAnalyzer:
    """Predict pipeline survival probabilities with a random survival forest.

    Wraps loading of a pre-trained ``sksurv`` ``RandomSurvivalForest`` model
    and exposes prediction / plotting helpers so the model can be reused from
    other projects.

    The model predicts from 4 features: Material, Diameter, Flow Velocity,
    Pressure.

    Required dependencies: joblib, pandas, numpy, scikit-survival, matplotlib.
    """

    def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
        """Load the pre-trained random survival forest model.

        :param model_path: Path to the serialized model file (defaults to the
            relative path ``model/my_survival_forest_model_quxi.joblib``).
        :raises FileNotFoundError: If the model file does not exist.
        :raises RuntimeError: If the model file exists but fails to load.
        """
        # NOTE(review): creating the directory cannot make a missing model
        # file appear; kept only to preserve the original side effect.
        model_dir = os.path.dirname(model_path)
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件未找到: {model_path}")

        try:
            self.rsf = joblib.load(model_path)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"加载模型时出错: {str(e)}") from e

        # Feature columns the model was trained on, in training order.
        self.features = [
            'Material', 'Diameter', 'Flow Velocity',
            'Pressure',  # 'Temperature', 'Precipitation',
            # 'Location', 'Structural Defects', 'Functional Defects'
        ]

    def predict_survival(self, data: pd.DataFrame) -> list:
        """Predict a survival function for each row of *data*.

        :param data: DataFrame containing the 4 required feature columns with
            numeric (or float-convertible) values; extra columns are ignored.
        :return: List of survival-function objects (each exposing time points
            ``x`` and survival probabilities ``y``), one per input row.
        :raises ValueError: If required features are missing or the feature
            values cannot be converted to floats.
        """
        missing = [feat for feat in self.features if feat not in data.columns]
        if missing:
            raise ValueError(f"数据缺少必需特征: {missing}")

        try:
            # astype(float) raises ValueError for bad strings and TypeError
            # for non-convertible objects; normalize both to ValueError.
            x_test = data[self.features].astype(float)
        except (ValueError, TypeError) as e:
            raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}") from e

        return list(self.rsf.predict_survival_function(x_test))

    def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
        """Plot the predicted survival curves as step functions.

        :param survival_functions: List returned by :meth:`predict_survival`.
        :param save_path: Optional path to save the figure (PNG); not saved
            when ``None``.
        :param show_plot: Whether to display the figure interactively; when
            ``False`` the figure is closed to free resources.
        """
        plt.figure(figsize=(10, 6))
        for i, sf in enumerate(survival_functions):
            plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
        plt.xlabel("时间(年)")
        plt.ylabel("生存概率")
        plt.title("管道生存概率预测")
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"图表已保存到: {save_path}")

        if show_plot:
            plt.show()
        else:
            plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 调用说明示例
|
||||||
|
"""
|
||||||
|
在其他项目中使用PipelineHealthAnalyzer类的步骤:
|
||||||
|
|
||||||
|
1. 安装依赖(在requirements.txt中添加):
|
||||||
|
joblib==1.5.0
|
||||||
|
pandas==2.2.3
|
||||||
|
numpy==2.0.2
|
||||||
|
scikit-survival==0.23.1
|
||||||
|
matplotlib==3.9.4
|
||||||
|
|
||||||
|
2. 导入类:
|
||||||
|
from pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||||
|
|
||||||
|
3. 初始化分析器(替换为实际模型路径):
|
||||||
|
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
|
||||||
|
|
||||||
|
4. 准备数据(pandas DataFrame,包含4个必需特征列,多余的列会被忽略):
|
||||||
|
import pandas as pd
|
||||||
|
data = pd.DataFrame({
|
||||||
|
'Material': [1, 2], # 示例数据
|
||||||
|
'Diameter': [100, 150],
|
||||||
|
'Flow Velocity': [1.5, 2.0],
|
||||||
|
'Pressure': [50, 60],
|
||||||
|
'Temperature': [20, 25],
|
||||||
|
'Precipitation': [0.1, 0.2],
|
||||||
|
'Location': [1, 2],
|
||||||
|
'Structural Defects': [0, 1],
|
||||||
|
'Functional Defects': [0, 0]
|
||||||
|
})
|
||||||
|
|
||||||
|
5. 进行预测:
|
||||||
|
survival_funcs = analyzer.predict_survival(data)
|
||||||
|
|
||||||
|
6. 查看结果(每个样本的生存概率随时间变化):
|
||||||
|
for i, sf in enumerate(survival_funcs):
|
||||||
|
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
|
||||||
|
|
||||||
|
7. 可视化(可选):
|
||||||
|
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- 数据格式必须匹配特征列表,特征值为数值型。
|
||||||
|
- 模型文件需从原项目复制或重新训练。
|
||||||
|
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||||||
|
"""
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
14
tests/conftest.py
Normal file
14
tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 自动添加项目根目录到路径(处理项目结构)
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
def run_this_test(test_file):
    """Run a single test file through pytest in verbose mode.

    :param test_file: Path to the test module to execute.
    """
    # Delegate to pytest so imports and fixtures are resolved exactly as in
    # a normal ``pytest`` invocation. (The original computed an unused
    # ``test_name`` local; removed.)
    pytest.main([test_file, "-v"])
|
||||||
62
tests/test_pipeline_health_analyzer.py
Normal file
62
tests/test_pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_health_analyzer():
    """End-to-end check: load the bundled model, predict on replicated
    samples, validate the return contract, and exercise plotting."""
    import time

    import pandas as pd

    # Initialize the analyzer with the pre-trained model shipped in api_ex.
    analyzer = PipelineHealthAnalyzer(
        model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
    )

    # Nine base samples covering the four required features.
    base_data = pd.DataFrame(
        {
            "Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
            "Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
            "Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
            "Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
        }
    )

    # Replicate the base rows up to num_samples rows. (The original comment
    # claimed 100k samples; it is actually 100.)
    num_samples = 100
    repetitions = num_samples // len(base_data) + 1
    sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
        num_samples
    )

    print(f"生成了 {len(sample_data)} 个样本")

    # Time the prediction step.
    start_time = time.time()
    survival_functions = analyzer.predict_survival(sample_data)
    elapsed_time = time.time() - start_time
    print(f"预测耗时: {elapsed_time:.2f} 秒")

    # Print a short summary of each predicted survival function.
    print(f"预测了 {len(survival_functions)} 个生存函数")
    for i, sf in enumerate(survival_functions):
        print(
            f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
        )

    # Validate the return contract.
    assert isinstance(survival_functions, list), "返回值应为列表"
    assert all(
        hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
    ), "每个生存函数应包含x和y属性"

    # Exercise plotting without opening a window: the original passed
    # show_plot=True, contradicting its own "do not display" comment and
    # blocking in headless CI.
    analyzer.plot_survival(survival_functions, show_plot=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import conftest

    conftest.run_this_test(__file__)  # Run this file through the conftest pytest helper
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Any, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from psycopg import AsyncConnection
|
from psycopg import AsyncConnection
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -23,7 +23,7 @@ class CompositeQueries:
|
|||||||
device_ids: List[str],
|
device_ids: List[str],
|
||||||
start_time: datetime,
|
start_time: datetime,
|
||||||
end_time: datetime,
|
end_time: datetime,
|
||||||
) -> List[Optional[Any]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
获取 SCADA 关联的 link/node 模拟值
|
获取 SCADA 关联的 link/node 模拟值
|
||||||
|
|
||||||
@@ -39,12 +39,12 @@ class CompositeQueries:
|
|||||||
field: 要查询的字段名
|
field: 要查询的字段名
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模拟数据值列表,如果没有找到则对应位置返回 None
|
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||||
"""
|
"""
|
||||||
results = []
|
result = {}
|
||||||
# 1. 查询所有 SCADA 信息
|
# 1. 查询所有 SCADA 信息
|
||||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||||
|
|
||||||
@@ -75,8 +75,11 @@ class CompositeQueries:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||||
results.append(res)
|
# 添加 scada_id 到每个数据项
|
||||||
return results
|
for item in res:
|
||||||
|
item["scada_id"] = device_id
|
||||||
|
result[device_id] = res
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_scada_associated_scheme_simulation_data(
|
async def get_scada_associated_scheme_simulation_data(
|
||||||
@@ -87,7 +90,7 @@ class CompositeQueries:
|
|||||||
end_time: datetime,
|
end_time: datetime,
|
||||||
scheme_type: str,
|
scheme_type: str,
|
||||||
scheme_name: str,
|
scheme_name: str,
|
||||||
) -> List[Optional[Any]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
获取 SCADA 关联的 link/node 模拟值
|
获取 SCADA 关联的 link/node 模拟值
|
||||||
|
|
||||||
@@ -103,12 +106,12 @@ class CompositeQueries:
|
|||||||
field: 要查询的字段名
|
field: 要查询的字段名
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模拟数据值列表,如果没有找到则对应位置返回 None
|
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||||
"""
|
"""
|
||||||
results = []
|
result = {}
|
||||||
# 1. 查询所有 SCADA 信息
|
# 1. 查询所有 SCADA 信息
|
||||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||||
|
|
||||||
@@ -151,8 +154,11 @@ class CompositeQueries:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||||
results.append(res)
|
# 添加 scada_id 到每个数据项
|
||||||
return results
|
for item in res:
|
||||||
|
item["scada_id"] = device_id
|
||||||
|
result[device_id] = res
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_element_associated_scada_data(
|
async def get_element_associated_scada_data(
|
||||||
@@ -282,31 +288,31 @@ class CompositeQueries:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 处理pressure数据
|
# 处理pressure数据
|
||||||
# if pressure_ids:
|
if pressure_ids:
|
||||||
# pressure_df = df[pressure_ids]
|
pressure_df = df[pressure_ids]
|
||||||
# # 重置索引,将 time 变为普通列
|
# 重置索引,将 time 变为普通列
|
||||||
# pressure_df = pressure_df.reset_index()
|
pressure_df = pressure_df.reset_index()
|
||||||
# # 移除 time 列,准备输入给清洗方法
|
# 移除 time 列,准备输入给清洗方法
|
||||||
# value_df = pressure_df.drop(columns=["time"])
|
value_df = pressure_df.drop(columns=["time"])
|
||||||
# # 调用清洗方法
|
# 调用清洗方法
|
||||||
# cleaned_value_df = clean_pressure_data_df_km(value_df)
|
cleaned_value_df = clean_pressure_data_df_km(value_df)
|
||||||
# # 添加 time 列到首列
|
# 添加 time 列到首列
|
||||||
# cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
||||||
# # 将清洗后的数据写回数据库
|
# 将清洗后的数据写回数据库
|
||||||
# for device_id in pressure_ids:
|
for device_id in pressure_ids:
|
||||||
# if device_id in cleaned_df.columns:
|
if device_id in cleaned_df.columns:
|
||||||
# cleaned_values = cleaned_df[device_id].tolist()
|
cleaned_values = cleaned_df[device_id].tolist()
|
||||||
# time_values = cleaned_df["time"].tolist()
|
time_values = cleaned_df["time"].tolist()
|
||||||
# for i, time_str in enumerate(time_values):
|
for i, time_str in enumerate(time_values):
|
||||||
# time_dt = datetime.fromisoformat(time_str)
|
time_dt = datetime.fromisoformat(time_str)
|
||||||
# value = cleaned_values[i]
|
value = cleaned_values[i]
|
||||||
# await ScadaRepository.update_scada_field(
|
await ScadaRepository.update_scada_field(
|
||||||
# timescale_conn,
|
timescale_conn,
|
||||||
# time_dt,
|
time_dt,
|
||||||
# device_id,
|
device_id,
|
||||||
# "cleaned_value",
|
"cleaned_value",
|
||||||
# value,
|
value,
|
||||||
# )
|
)
|
||||||
|
|
||||||
# 处理flow数据
|
# 处理flow数据
|
||||||
if flow_ids:
|
if flow_ids:
|
||||||
|
|||||||
Reference in New Issue
Block a user