Files
TJWaterServerBinary/app/algorithms/api_ex/pipeline_health_analyzer.py

149 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import joblib
import pandas as pd
import matplotlib.pyplot as plt
class PipelineHealthAnalyzer:
"""
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
该类封装了模型加载和预测功能,便于在其他项目中复用。
模型基于4个特征进行生存分析预测材料、直径、流速、压力。
使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。
"""
def __init__(self, model_path: str = None):
"""
初始化分析器,加载预训练的随机生存森林模型。
:param model_path: 模型文件的路径(默认为相对路径 './model/my_survival_forest_model_quxi.joblib')。
:raises FileNotFoundError: 如果模型文件不存在。
:raises Exception: 如果模型加载失败。
"""
if model_path is None:
model_path = os.path.join(
os.path.dirname(__file__),
"model",
"my_survival_forest_model_quxi.joblib",
)
# 确保 model 目录存在
model_dir = os.path.dirname(model_path)
if model_dir and not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件未找到: {model_path}")
try:
self.rsf = joblib.load(model_path)
self.features = [
"Material",
"Diameter",
"Flow Velocity",
"Pressure", # 'Temperature', 'Precipitation',
# 'Location', 'Structural Defects', 'Functional Defects'
]
except Exception as e:
raise Exception(f"加载模型时出错: {str(e)}")
def predict_survival(self, data: pd.DataFrame) -> list:
"""
基于输入数据预测生存函数。
:param data: pandas DataFrame包含4个必需特征列。数据应为数值型或可转换为数值型。
:return: 生存函数列表每个元素为一个生存函数对象包含时间点x和生存概率y
:raises ValueError: 如果数据缺少必需特征或格式不正确。
"""
# 检查必需特征是否存在
missing_features = [feat for feat in self.features if feat not in data.columns]
if missing_features:
raise ValueError(f"数据缺少必需特征: {missing_features}")
# 提取特征数据
try:
x_test = data[self.features].astype(float) # 确保数值型
except ValueError as e:
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
# 进行预测
survival_functions = self.rsf.predict_survival_function(x_test)
return list(survival_functions)
def plot_survival(
self, survival_functions: list, save_path: str = None, show_plot: bool = True
):
"""
可视化生存函数,生成生存概率图表。
:param survival_functions: predict_survival返回的生存函数列表。
:param save_path: 可选,保存图表的路径(.png格式。如果为None则不保存。
:param show_plot: 是否显示图表(在交互环境中)。
"""
plt.figure(figsize=(10, 6))
for i, sf in enumerate(survival_functions):
plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
plt.xlabel("时间(年)")
plt.ylabel("生存概率")
plt.title("管道生存概率预测")
plt.legend()
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"图表已保存到: {save_path}")
if show_plot:
plt.show()
else:
plt.close()
# 调用说明示例
"""
在其他项目中使用PipelineHealthAnalyzer类的步骤
1. 安装依赖在requirements.txt中添加
joblib==1.5.0
pandas==2.2.3
numpy==2.0.2
scikit-survival==0.23.1
matplotlib==3.9.4
2. 导入类:
from pipeline_health_analyzer import PipelineHealthAnalyzer
3. 初始化分析器(替换为实际模型路径):
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
4. 准备数据pandas DataFrame包含9个特征列
import pandas as pd
data = pd.DataFrame({
'Material': [1, 2], # 示例数据
'Diameter': [100, 150],
'Flow Velocity': [1.5, 2.0],
'Pressure': [50, 60],
'Temperature': [20, 25],
'Precipitation': [0.1, 0.2],
'Location': [1, 2],
'Structural Defects': [0, 1],
'Functional Defects': [0, 0]
})
5. 进行预测:
survival_funcs = analyzer.predict_survival(data)
6. 查看结果(每个样本的生存概率随时间变化):
for i, sf in enumerate(survival_funcs):
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
7. 可视化(可选):
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
注意:
- 数据格式必须匹配特征列表,特征值为数值型。
- 模型文件需从原项目复制或重新训练。
- 如果需要自定义特征或模型参数可修改类中的features列表或继承此类。
"""