删除管网健康风险分析多余的包引入
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import joblib
|
import joblib
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
|
||||||
from sksurv.ensemble import RandomSurvivalForest
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +14,7 @@ class PipelineHealthAnalyzer:
|
|||||||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
|
def __init__(self, model_path: str = "model/my_survival_forest_model_quxi.joblib"):
|
||||||
"""
|
"""
|
||||||
初始化分析器,加载预训练的随机生存森林模型。
|
初始化分析器,加载预训练的随机生存森林模型。
|
||||||
|
|
||||||
@@ -35,8 +33,10 @@ class PipelineHealthAnalyzer:
|
|||||||
try:
|
try:
|
||||||
self.rsf = joblib.load(model_path)
|
self.rsf = joblib.load(model_path)
|
||||||
self.features = [
|
self.features = [
|
||||||
'Material', 'Diameter', 'Flow Velocity',
|
"Material",
|
||||||
'Pressure', # 'Temperature', 'Precipitation',
|
"Diameter",
|
||||||
|
"Flow Velocity",
|
||||||
|
"Pressure", # 'Temperature', 'Precipitation',
|
||||||
# 'Location', 'Structural Defects', 'Functional Defects'
|
# 'Location', 'Structural Defects', 'Functional Defects'
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -65,7 +65,9 @@ class PipelineHealthAnalyzer:
|
|||||||
survival_functions = self.rsf.predict_survival_function(x_test)
|
survival_functions = self.rsf.predict_survival_function(x_test)
|
||||||
return list(survival_functions)
|
return list(survival_functions)
|
||||||
|
|
||||||
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
|
def plot_survival(
|
||||||
|
self, survival_functions: list, save_path: str = None, show_plot: bool = True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
可视化生存函数,生成生存概率图表。
|
可视化生存函数,生成生存概率图表。
|
||||||
|
|
||||||
@@ -75,7 +77,7 @@ class PipelineHealthAnalyzer:
|
|||||||
"""
|
"""
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
for i, sf in enumerate(survival_functions):
|
for i, sf in enumerate(survival_functions):
|
||||||
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
|
plt.step(sf.x, sf.y, where="post", label=f"样本 {i + 1}")
|
||||||
plt.xlabel("时间(年)")
|
plt.xlabel("时间(年)")
|
||||||
plt.ylabel("生存概率")
|
plt.ylabel("生存概率")
|
||||||
plt.title("管道生存概率预测")
|
plt.title("管道生存概率预测")
|
||||||
@@ -83,7 +85,7 @@ class PipelineHealthAnalyzer:
|
|||||||
plt.grid(True, alpha=0.3)
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
if save_path:
|
if save_path:
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||||
print(f"图表已保存到: {save_path}")
|
print(f"图表已保存到: {save_path}")
|
||||||
|
|
||||||
if show_plot:
|
if show_plot:
|
||||||
@@ -137,4 +139,4 @@ class PipelineHealthAnalyzer:
|
|||||||
- 数据格式必须匹配特征列表,特征值为数值型。
|
- 数据格式必须匹配特征列表,特征值为数值型。
|
||||||
- 模型文件需从原项目复制或重新训练。
|
- 模型文件需从原项目复制或重新训练。
|
||||||
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user