删除管网健康风险分析多余的包引入
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sksurv.ensemble import RandomSurvivalForest
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
@@ -16,7 +14,7 @@ class PipelineHealthAnalyzer:
|
||||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
|
||||
def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
|
||||
"""
|
||||
初始化分析器,加载预训练的随机生存森林模型。
|
||||
|
||||
@@ -35,8 +33,10 @@ class PipelineHealthAnalyzer:
|
||||
try:
|
||||
self.rsf = joblib.load(model_path)
|
||||
self.features = [
|
||||
'Material', 'Diameter', 'Flow Velocity',
|
||||
'Pressure', # 'Temperature', 'Precipitation',
|
||||
"Material",
|
||||
"Diameter",
|
||||
"Flow Velocity",
|
||||
"Pressure", # 'Temperature', 'Precipitation',
|
||||
# 'Location', 'Structural Defects', 'Functional Defects'
|
||||
]
|
||||
except Exception as e:
|
||||
@@ -65,7 +65,9 @@ class PipelineHealthAnalyzer:
|
||||
survival_functions = self.rsf.predict_survival_function(x_test)
|
||||
return list(survival_functions)
|
||||
|
||||
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
|
||||
def plot_survival(
|
||||
self, survival_functions: list, save_path: str = None, show_plot: bool = True
|
||||
):
|
||||
"""
|
||||
可视化生存函数,生成生存概率图表。
|
||||
|
||||
@@ -75,7 +77,7 @@ class PipelineHealthAnalyzer:
|
||||
"""
|
||||
plt.figure(figsize=(10, 6))
|
||||
for i, sf in enumerate(survival_functions):
|
||||
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
|
||||
plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
|
||||
plt.xlabel("时间(年)")
|
||||
plt.ylabel("生存概率")
|
||||
plt.title("管道生存概率预测")
|
||||
@@ -83,7 +85,7 @@ class PipelineHealthAnalyzer:
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"图表已保存到: {save_path}")
|
||||
|
||||
if show_plot:
|
||||
@@ -137,4 +139,4 @@ class PipelineHealthAnalyzer:
|
||||
- 数据格式必须匹配特征列表,特征值为数值型。
|
||||
- 模型文件需从原项目复制或重新训练。
|
||||
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||||
"""
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user