新增tests;新增管道健康预测方法;更新复合查询函数输出格式
This commit is contained in:
3
api_ex/__init__.py
Normal file
3
api_ex/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .Fdataclean import *
|
||||
from .Pdataclean import *
|
||||
from .pipeline_health_analyzer import *
|
||||
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
Binary file not shown.
140
api_ex/pipeline_health_analyzer.py
Normal file
140
api_ex/pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sksurv.ensemble import RandomSurvivalForest
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class PipelineHealthAnalyzer:
|
||||
"""
|
||||
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
|
||||
|
||||
该类封装了模型加载和预测功能,便于在其他项目中复用。
|
||||
模型基于4个特征进行生存分析预测:材料、直径、流速、压力。
|
||||
|
||||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
|
||||
"""
|
||||
初始化分析器,加载预训练的随机生存森林模型。
|
||||
|
||||
:param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。
|
||||
:raises FileNotFoundError: 如果模型文件不存在。
|
||||
:raises Exception: 如果模型加载失败。
|
||||
"""
|
||||
# 确保 model 目录存在
|
||||
model_dir = os.path.dirname(model_path)
|
||||
if model_dir and not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件未找到: {model_path}")
|
||||
|
||||
try:
|
||||
self.rsf = joblib.load(model_path)
|
||||
self.features = [
|
||||
'Material', 'Diameter', 'Flow Velocity',
|
||||
'Pressure', # 'Temperature', 'Precipitation',
|
||||
# 'Location', 'Structural Defects', 'Functional Defects'
|
||||
]
|
||||
except Exception as e:
|
||||
raise Exception(f"加载模型时出错: {str(e)}")
|
||||
|
||||
def predict_survival(self, data: pd.DataFrame) -> list:
|
||||
"""
|
||||
基于输入数据预测生存函数。
|
||||
|
||||
:param data: pandas DataFrame,包含4个必需特征列。数据应为数值型或可转换为数值型。
|
||||
:return: 生存函数列表,每个元素为一个生存函数对象(包含时间点x和生存概率y)。
|
||||
:raises ValueError: 如果数据缺少必需特征或格式不正确。
|
||||
"""
|
||||
# 检查必需特征是否存在
|
||||
missing_features = [feat for feat in self.features if feat not in data.columns]
|
||||
if missing_features:
|
||||
raise ValueError(f"数据缺少必需特征: {missing_features}")
|
||||
|
||||
# 提取特征数据
|
||||
try:
|
||||
x_test = data[self.features].astype(float) # 确保数值型
|
||||
except ValueError as e:
|
||||
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
|
||||
|
||||
# 进行预测
|
||||
survival_functions = self.rsf.predict_survival_function(x_test)
|
||||
return list(survival_functions)
|
||||
|
||||
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
|
||||
"""
|
||||
可视化生存函数,生成生存概率图表。
|
||||
|
||||
:param survival_functions: predict_survival返回的生存函数列表。
|
||||
:param save_path: 可选,保存图表的路径(.png格式)。如果为None,则不保存。
|
||||
:param show_plot: 是否显示图表(在交互环境中)。
|
||||
"""
|
||||
plt.figure(figsize=(10, 6))
|
||||
for i, sf in enumerate(survival_functions):
|
||||
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
|
||||
plt.xlabel("时间(年)")
|
||||
plt.ylabel("生存概率")
|
||||
plt.title("管道生存概率预测")
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"图表已保存到: {save_path}")
|
||||
|
||||
if show_plot:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close()
|
||||
|
||||
|
||||
# 调用说明示例
|
||||
"""
|
||||
在其他项目中使用PipelineHealthAnalyzer类的步骤:
|
||||
|
||||
1. 安装依赖(在requirements.txt中添加):
|
||||
joblib==1.5.0
|
||||
pandas==2.2.3
|
||||
numpy==2.0.2
|
||||
scikit-survival==0.23.1
|
||||
matplotlib==3.9.4
|
||||
|
||||
2. 导入类:
|
||||
from pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
3. 初始化分析器(替换为实际模型路径):
|
||||
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
|
||||
|
||||
4. 准备数据(pandas DataFrame,包含9个特征列):
|
||||
import pandas as pd
|
||||
data = pd.DataFrame({
|
||||
'Material': [1, 2], # 示例数据
|
||||
'Diameter': [100, 150],
|
||||
'Flow Velocity': [1.5, 2.0],
|
||||
'Pressure': [50, 60],
|
||||
'Temperature': [20, 25],
|
||||
'Precipitation': [0.1, 0.2],
|
||||
'Location': [1, 2],
|
||||
'Structural Defects': [0, 1],
|
||||
'Functional Defects': [0, 0]
|
||||
})
|
||||
|
||||
5. 进行预测:
|
||||
survival_funcs = analyzer.predict_survival(data)
|
||||
|
||||
6. 查看结果(每个样本的生存概率随时间变化):
|
||||
for i, sf in enumerate(survival_funcs):
|
||||
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
|
||||
|
||||
7. 可视化(可选):
|
||||
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
|
||||
|
||||
注意:
|
||||
- 数据格式必须匹配特征列表,特征值为数值型。
|
||||
- 模型文件需从原项目复制或重新训练。
|
||||
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||||
"""
|
||||
Reference in New Issue
Block a user