149 lines
5.3 KiB
Python
149 lines
5.3 KiB
Python
import os
|
||
import joblib
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
class PipelineHealthAnalyzer:
|
||
"""
|
||
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
|
||
|
||
该类封装了模型加载和预测功能,便于在其他项目中复用。
|
||
模型基于4个特征进行生存分析预测:材料、直径、流速、压力。
|
||
|
||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||
"""
|
||
|
||
def __init__(self, model_path: str = None):
|
||
"""
|
||
初始化分析器,加载预训练的随机生存森林模型。
|
||
|
||
:param model_path: 模型文件的路径(默认为相对路径 './model/my_survival_forest_model_quxi.joblib')。
|
||
:raises FileNotFoundError: 如果模型文件不存在。
|
||
:raises Exception: 如果模型加载失败。
|
||
"""
|
||
if model_path is None:
|
||
model_path = os.path.join(
|
||
os.path.dirname(__file__),
|
||
"model",
|
||
"my_survival_forest_model_quxi.joblib",
|
||
)
|
||
# 确保 model 目录存在
|
||
model_dir = os.path.dirname(model_path)
|
||
if model_dir and not os.path.exists(model_dir):
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
if not os.path.exists(model_path):
|
||
raise FileNotFoundError(f"模型文件未找到: {model_path}")
|
||
|
||
try:
|
||
self.rsf = joblib.load(model_path)
|
||
self.features = [
|
||
"Material",
|
||
"Diameter",
|
||
"Flow Velocity",
|
||
"Pressure", # 'Temperature', 'Precipitation',
|
||
# 'Location', 'Structural Defects', 'Functional Defects'
|
||
]
|
||
except Exception as e:
|
||
raise Exception(f"加载模型时出错: {str(e)}")
|
||
|
||
def predict_survival(self, data: pd.DataFrame) -> list:
|
||
"""
|
||
基于输入数据预测生存函数。
|
||
|
||
:param data: pandas DataFrame,包含4个必需特征列。数据应为数值型或可转换为数值型。
|
||
:return: 生存函数列表,每个元素为一个生存函数对象(包含时间点x和生存概率y)。
|
||
:raises ValueError: 如果数据缺少必需特征或格式不正确。
|
||
"""
|
||
# 检查必需特征是否存在
|
||
missing_features = [feat for feat in self.features if feat not in data.columns]
|
||
if missing_features:
|
||
raise ValueError(f"数据缺少必需特征: {missing_features}")
|
||
|
||
# 提取特征数据
|
||
try:
|
||
x_test = data[self.features].astype(float) # 确保数值型
|
||
except ValueError as e:
|
||
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
|
||
|
||
# 进行预测
|
||
survival_functions = self.rsf.predict_survival_function(x_test)
|
||
return list(survival_functions)
|
||
|
||
def plot_survival(
|
||
self, survival_functions: list, save_path: str = None, show_plot: bool = True
|
||
):
|
||
"""
|
||
可视化生存函数,生成生存概率图表。
|
||
|
||
:param survival_functions: predict_survival返回的生存函数列表。
|
||
:param save_path: 可选,保存图表的路径(.png格式)。如果为None,则不保存。
|
||
:param show_plot: 是否显示图表(在交互环境中)。
|
||
"""
|
||
plt.figure(figsize=(10, 6))
|
||
for i, sf in enumerate(survival_functions):
|
||
plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
|
||
plt.xlabel("时间(年)")
|
||
plt.ylabel("生存概率")
|
||
plt.title("管道生存概率预测")
|
||
plt.legend()
|
||
plt.grid(True, alpha=0.3)
|
||
|
||
if save_path:
|
||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||
print(f"图表已保存到: {save_path}")
|
||
|
||
if show_plot:
|
||
plt.show()
|
||
else:
|
||
plt.close()
|
||
|
||
|
||
# 调用说明示例
|
||
"""
|
||
在其他项目中使用PipelineHealthAnalyzer类的步骤:
|
||
|
||
1. 安装依赖(在requirements.txt中添加):
|
||
joblib==1.5.0
|
||
pandas==2.2.3
|
||
numpy==2.0.2
|
||
scikit-survival==0.23.1
|
||
matplotlib==3.9.4
|
||
|
||
2. 导入类:
|
||
from pipeline_health_analyzer import PipelineHealthAnalyzer
|
||
|
||
3. 初始化分析器(替换为实际模型路径):
|
||
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
|
||
|
||
4. 准备数据(pandas DataFrame,包含9个特征列):
|
||
import pandas as pd
|
||
data = pd.DataFrame({
|
||
'Material': [1, 2], # 示例数据
|
||
'Diameter': [100, 150],
|
||
'Flow Velocity': [1.5, 2.0],
|
||
'Pressure': [50, 60],
|
||
'Temperature': [20, 25],
|
||
'Precipitation': [0.1, 0.2],
|
||
'Location': [1, 2],
|
||
'Structural Defects': [0, 1],
|
||
'Functional Defects': [0, 0]
|
||
})
|
||
|
||
5. 进行预测:
|
||
survival_funcs = analyzer.predict_survival(data)
|
||
|
||
6. 查看结果(每个样本的生存概率随时间变化):
|
||
for i, sf in enumerate(survival_funcs):
|
||
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
|
||
|
||
7. 可视化(可选):
|
||
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
|
||
|
||
注意:
|
||
- 数据格式必须匹配特征列表,特征值为数值型。
|
||
- 模型文件需从原项目复制或重新训练。
|
||
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||
"""
|