删除管网健康风险分析多余的包引入

This commit is contained in:
JIANG
2025-12-30 19:09:12 +08:00
parent 79c2bf811e
commit ee642be6b8

View File

@@ -1,8 +1,6 @@
import os import os
import joblib import joblib
import pandas as pd import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -16,7 +14,7 @@ class PipelineHealthAnalyzer:
使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。 使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。
""" """
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'): def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
""" """
初始化分析器,加载预训练的随机生存森林模型。 初始化分析器,加载预训练的随机生存森林模型。
@@ -35,8 +33,10 @@ class PipelineHealthAnalyzer:
try: try:
self.rsf = joblib.load(model_path) self.rsf = joblib.load(model_path)
self.features = [ self.features = [
'Material', 'Diameter', 'Flow Velocity', "Material",
'Pressure', # 'Temperature', 'Precipitation', "Diameter",
"Flow Velocity",
"Pressure", # 'Temperature', 'Precipitation',
# 'Location', 'Structural Defects', 'Functional Defects' # 'Location', 'Structural Defects', 'Functional Defects'
] ]
except Exception as e: except Exception as e:
@@ -65,7 +65,9 @@ class PipelineHealthAnalyzer:
survival_functions = self.rsf.predict_survival_function(x_test) survival_functions = self.rsf.predict_survival_function(x_test)
return list(survival_functions) return list(survival_functions)
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True): def plot_survival(
self, survival_functions: list, save_path: str = None, show_plot: bool = True
):
""" """
可视化生存函数,生成生存概率图表。 可视化生存函数,生成生存概率图表。
@@ -75,7 +77,7 @@ class PipelineHealthAnalyzer:
""" """
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
for i, sf in enumerate(survival_functions): for i, sf in enumerate(survival_functions):
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}') plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
plt.xlabel("时间(年)") plt.xlabel("时间(年)")
plt.ylabel("生存概率") plt.ylabel("生存概率")
plt.title("管道生存概率预测") plt.title("管道生存概率预测")
@@ -83,7 +85,7 @@ class PipelineHealthAnalyzer:
plt.grid(True, alpha=0.3) plt.grid(True, alpha=0.3)
if save_path: if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"图表已保存到: {save_path}") print(f"图表已保存到: {save_path}")
if show_plot: if show_plot: