From ee642be6b808753a4d957fcb6793b07be932b203 Mon Sep 17 00:00:00 2001 From: JIANG Date: Tue, 30 Dec 2025 19:09:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E7=AE=A1=E7=BD=91=E5=81=A5?= =?UTF-8?q?=E5=BA=B7=E9=A3=8E=E9=99=A9=E5=88=86=E6=9E=90=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=9A=84=E5=8C=85=E5=BC=95=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_ex/pipeline_health_analyzer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/api_ex/pipeline_health_analyzer.py b/api_ex/pipeline_health_analyzer.py index 7cdb699..0edb27a 100644 --- a/api_ex/pipeline_health_analyzer.py +++ b/api_ex/pipeline_health_analyzer.py @@ -1,8 +1,6 @@ import os import joblib import pandas as pd -import numpy as np -from sksurv.ensemble import RandomSurvivalForest import matplotlib.pyplot as plt @@ -16,7 +14,7 @@ class PipelineHealthAnalyzer: 使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。 """ - def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'): + def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"): """ 初始化分析器,加载预训练的随机生存森林模型。 @@ -35,8 +33,10 @@ class PipelineHealthAnalyzer: try: self.rsf = joblib.load(model_path) self.features = [ - 'Material', 'Diameter', 'Flow Velocity', - 'Pressure', # 'Temperature', 'Precipitation', + "Material", + "Diameter", + "Flow Velocity", + "Pressure", # 'Temperature', 'Precipitation', # 'Location', 'Structural Defects', 'Functional Defects' ] except Exception as e: @@ -65,7 +65,9 @@ class PipelineHealthAnalyzer: survival_functions = self.rsf.predict_survival_function(x_test) return list(survival_functions) - def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True): + def plot_survival( + self, survival_functions: list, save_path: str = None, show_plot: bool = True + ): """ 可视化生存函数,生成生存概率图表。 @@ -75,7 +77,7 @@ class PipelineHealthAnalyzer: """ plt.figure(figsize=(10, 6)) for i, sf in enumerate(survival_functions): - plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}') + plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}") plt.xlabel("时间(年)") plt.ylabel("生存概率") plt.title("管道生存概率预测") @@ -83,7 +85,7 @@ class PipelineHealthAnalyzer: plt.grid(True, alpha=0.3) if save_path: - plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.savefig(save_path, dpi=300, bbox_inches="tight") print(f"图表已保存到: {save_path}") if show_plot: @@ -137,4 +139,4 @@ class PipelineHealthAnalyzer: - 数据格式必须匹配特征列表,特征值为数值型。 - 模型文件需从原项目复制或重新训练。 - 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。 -""" \ No newline at end of file +"""