import os import joblib import pandas as pd import matplotlib.pyplot as plt class PipelineHealthAnalyzer: """ 管道健康分析器类,使用随机生存森林模型预测管道的生存概率。 该类封装了模型加载和预测功能,便于在其他项目中复用。 模型基于4个特征进行生存分析预测:材料、直径、流速、压力。 使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。 """ def __init__(self, model_path: str = None): """ 初始化分析器,加载预训练的随机生存森林模型。 :param model_path: 模型文件的路径(默认为相对路径 './model/my_survival_forest_model_quxi.joblib')。 :raises FileNotFoundError: 如果模型文件不存在。 :raises Exception: 如果模型加载失败。 """ if model_path is None: model_path = os.path.join( os.path.dirname(__file__), "model", "my_survival_forest_model_quxi.joblib", ) # 确保 model 目录存在 model_dir = os.path.dirname(model_path) if model_dir and not os.path.exists(model_dir): os.makedirs(model_dir, exist_ok=True) if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件未找到: {model_path}") try: self.rsf = joblib.load(model_path) self.features = [ "Material", "Diameter", "Flow Velocity", "Pressure", # 'Temperature', 'Precipitation', # 'Location', 'Structural Defects', 'Functional Defects' ] except Exception as e: raise Exception(f"加载模型时出错: {str(e)}") def predict_survival(self, data: pd.DataFrame) -> list: """ 基于输入数据预测生存函数。 :param data: pandas DataFrame,包含4个必需特征列。数据应为数值型或可转换为数值型。 :return: 生存函数列表,每个元素为一个生存函数对象(包含时间点x和生存概率y)。 :raises ValueError: 如果数据缺少必需特征或格式不正确。 """ # 检查必需特征是否存在 missing_features = [feat for feat in self.features if feat not in data.columns] if missing_features: raise ValueError(f"数据缺少必需特征: {missing_features}") # 提取特征数据 try: x_test = data[self.features].astype(float) # 确保数值型 except ValueError as e: raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}") # 进行预测 survival_functions = self.rsf.predict_survival_function(x_test) return list(survival_functions) def plot_survival( self, survival_functions: list, save_path: str = None, show_plot: bool = True ): """ 可视化生存函数,生成生存概率图表。 :param survival_functions: predict_survival返回的生存函数列表。 :param save_path: 可选,保存图表的路径(.png格式)。如果为None,则不保存。 :param show_plot: 是否显示图表(在交互环境中)。 """ plt.figure(figsize=(10, 6)) for i, sf in enumerate(survival_functions): plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}") plt.xlabel("时间(年)") plt.ylabel("生存概率") plt.title("管道生存概率预测") plt.legend() plt.grid(True, alpha=0.3) if save_path: plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"图表已保存到: {save_path}") if show_plot: plt.show() else: plt.close() # 调用说明示例 """ 在其他项目中使用PipelineHealthAnalyzer类的步骤: 1. 安装依赖(在requirements.txt中添加): joblib==1.5.0 pandas==2.2.3 numpy==2.0.2 scikit-survival==0.23.1 matplotlib==3.9.4 2. 导入类: from pipeline_health_analyzer import PipelineHealthAnalyzer 3. 初始化分析器(替换为实际模型路径): analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib') 4. 准备数据(pandas DataFrame,包含9个特征列): import pandas as pd data = pd.DataFrame({ 'Material': [1, 2], # 示例数据 'Diameter': [100, 150], 'Flow Velocity': [1.5, 2.0], 'Pressure': [50, 60], 'Temperature': [20, 25], 'Precipitation': [0.1, 0.2], 'Location': [1, 2], 'Structural Defects': [0, 1], 'Functional Defects': [0, 0] }) 5. 进行预测: survival_funcs = analyzer.predict_survival(data) 6. 查看结果(每个样本的生存概率随时间变化): for i, sf in enumerate(survival_funcs): print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...") 7. 可视化(可选): analyzer.plot_survival(survival_funcs, save_path='survival_plot.png') 注意: - 数据格式必须匹配特征列表,特征值为数值型。 - 模型文件需从原项目复制或重新训练。 - 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。 """