删除管网健康风险分析多余的包引入

This commit is contained in:
JIANG
2025-12-30 19:09:12 +08:00
parent 79c2bf811e
commit ee642be6b8

View File

@@ -1,8 +1,6 @@
import os
import joblib
import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
import matplotlib.pyplot as plt
@@ -16,7 +14,7 @@ class PipelineHealthAnalyzer:
使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。
"""
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
"""
初始化分析器,加载预训练的随机生存森林模型。
@@ -35,8 +33,10 @@ class PipelineHealthAnalyzer:
try:
self.rsf = joblib.load(model_path)
self.features = [
'Material', 'Diameter', 'Flow Velocity',
'Pressure', # 'Temperature', 'Precipitation',
"Material",
"Diameter",
"Flow Velocity",
"Pressure", # 'Temperature', 'Precipitation',
# 'Location', 'Structural Defects', 'Functional Defects'
]
except Exception as e:
@@ -65,7 +65,9 @@ class PipelineHealthAnalyzer:
survival_functions = self.rsf.predict_survival_function(x_test)
return list(survival_functions)
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
def plot_survival(
self, survival_functions: list, save_path: str = None, show_plot: bool = True
):
"""
可视化生存函数,生成生存概率图表。
@@ -75,7 +77,7 @@ class PipelineHealthAnalyzer:
"""
plt.figure(figsize=(10, 6))
for i, sf in enumerate(survival_functions):
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
plt.xlabel("时间(年)")
plt.ylabel("生存概率")
plt.title("管道生存概率预测")
@@ -83,7 +85,7 @@ class PipelineHealthAnalyzer:
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.savefig(save_path, dpi=300, bbox_inches="tight")
print(f"图表已保存到: {save_path}")
if show_plot:
@@ -137,4 +139,4 @@ class PipelineHealthAnalyzer:
- 数据格式必须匹配特征列表,特征值为数值型。
- 模型文件需从原项目复制或重新训练。
- 如果需要自定义特征或模型参数可修改类中的features列表或继承此类。
"""
"""