新增tests;新增管道健康预测方法;更新复合查询函数输出格式
This commit is contained in:
3
api_ex/__init__.py
Normal file
3
api_ex/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .Fdataclean import *
|
||||
from .Pdataclean import *
|
||||
from .pipeline_health_analyzer import *
|
||||
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
BIN
api_ex/model/my_survival_forest_model_quxi.zip
Normal file
Binary file not shown.
140
api_ex/pipeline_health_analyzer.py
Normal file
140
api_ex/pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sksurv.ensemble import RandomSurvivalForest
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class PipelineHealthAnalyzer:
|
||||
"""
|
||||
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
|
||||
|
||||
该类封装了模型加载和预测功能,便于在其他项目中复用。
|
||||
模型基于4个特征进行生存分析预测:材料、直径、流速、压力。
|
||||
|
||||
使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
|
||||
"""
|
||||
初始化分析器,加载预训练的随机生存森林模型。
|
||||
|
||||
:param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。
|
||||
:raises FileNotFoundError: 如果模型文件不存在。
|
||||
:raises Exception: 如果模型加载失败。
|
||||
"""
|
||||
# 确保 model 目录存在
|
||||
model_dir = os.path.dirname(model_path)
|
||||
if model_dir and not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件未找到: {model_path}")
|
||||
|
||||
try:
|
||||
self.rsf = joblib.load(model_path)
|
||||
self.features = [
|
||||
'Material', 'Diameter', 'Flow Velocity',
|
||||
'Pressure', # 'Temperature', 'Precipitation',
|
||||
# 'Location', 'Structural Defects', 'Functional Defects'
|
||||
]
|
||||
except Exception as e:
|
||||
raise Exception(f"加载模型时出错: {str(e)}")
|
||||
|
||||
def predict_survival(self, data: pd.DataFrame) -> list:
|
||||
"""
|
||||
基于输入数据预测生存函数。
|
||||
|
||||
:param data: pandas DataFrame,包含4个必需特征列。数据应为数值型或可转换为数值型。
|
||||
:return: 生存函数列表,每个元素为一个生存函数对象(包含时间点x和生存概率y)。
|
||||
:raises ValueError: 如果数据缺少必需特征或格式不正确。
|
||||
"""
|
||||
# 检查必需特征是否存在
|
||||
missing_features = [feat for feat in self.features if feat not in data.columns]
|
||||
if missing_features:
|
||||
raise ValueError(f"数据缺少必需特征: {missing_features}")
|
||||
|
||||
# 提取特征数据
|
||||
try:
|
||||
x_test = data[self.features].astype(float) # 确保数值型
|
||||
except ValueError as e:
|
||||
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
|
||||
|
||||
# 进行预测
|
||||
survival_functions = self.rsf.predict_survival_function(x_test)
|
||||
return list(survival_functions)
|
||||
|
||||
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
|
||||
"""
|
||||
可视化生存函数,生成生存概率图表。
|
||||
|
||||
:param survival_functions: predict_survival返回的生存函数列表。
|
||||
:param save_path: 可选,保存图表的路径(.png格式)。如果为None,则不保存。
|
||||
:param show_plot: 是否显示图表(在交互环境中)。
|
||||
"""
|
||||
plt.figure(figsize=(10, 6))
|
||||
for i, sf in enumerate(survival_functions):
|
||||
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
|
||||
plt.xlabel("时间(年)")
|
||||
plt.ylabel("生存概率")
|
||||
plt.title("管道生存概率预测")
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"图表已保存到: {save_path}")
|
||||
|
||||
if show_plot:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close()
|
||||
|
||||
|
||||
# 调用说明示例
|
||||
"""
|
||||
在其他项目中使用PipelineHealthAnalyzer类的步骤:
|
||||
|
||||
1. 安装依赖(在requirements.txt中添加):
|
||||
joblib==1.5.0
|
||||
pandas==2.2.3
|
||||
numpy==2.0.2
|
||||
scikit-survival==0.23.1
|
||||
matplotlib==3.9.4
|
||||
|
||||
2. 导入类:
|
||||
from pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
3. 初始化分析器(替换为实际模型路径):
|
||||
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
|
||||
|
||||
4. 准备数据(pandas DataFrame,包含9个特征列):
|
||||
import pandas as pd
|
||||
data = pd.DataFrame({
|
||||
'Material': [1, 2], # 示例数据
|
||||
'Diameter': [100, 150],
|
||||
'Flow Velocity': [1.5, 2.0],
|
||||
'Pressure': [50, 60],
|
||||
'Temperature': [20, 25],
|
||||
'Precipitation': [0.1, 0.2],
|
||||
'Location': [1, 2],
|
||||
'Structural Defects': [0, 1],
|
||||
'Functional Defects': [0, 0]
|
||||
})
|
||||
|
||||
5. 进行预测:
|
||||
survival_funcs = analyzer.predict_survival(data)
|
||||
|
||||
6. 查看结果(每个样本的生存概率随时间变化):
|
||||
for i, sf in enumerate(survival_funcs):
|
||||
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
|
||||
|
||||
7. 可视化(可选):
|
||||
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
|
||||
|
||||
注意:
|
||||
- 数据格式必须匹配特征列表,特征值为数值型。
|
||||
- 模型文件需从原项目复制或重新训练。
|
||||
- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。
|
||||
"""
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
14
tests/conftest.py
Normal file
14
tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 自动添加项目根目录到路径(处理项目结构)
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
|
||||
def run_this_test(test_file):
|
||||
"""自定义函数:运行单个测试文件(类似pytest)"""
|
||||
# 提取测试文件名(无扩展名)
|
||||
test_name = os.path.splitext(os.path.basename(test_file))[0]
|
||||
# 使用pytest运行(自动处理导入)
|
||||
pytest.main([test_file, "-v"])
|
||||
62
tests/test_pipeline_health_analyzer.py
Normal file
62
tests/test_pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
|
||||
def test_pipeline_health_analyzer():
|
||||
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
|
||||
analyzer = PipelineHealthAnalyzer(
|
||||
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
|
||||
)
|
||||
# 创建示例输入数据(9个样本)
|
||||
import pandas as pd
|
||||
import time
|
||||
|
||||
base_data = pd.DataFrame(
|
||||
{
|
||||
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
|
||||
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
|
||||
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
|
||||
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
|
||||
}
|
||||
)
|
||||
|
||||
# 复制生成10万样本
|
||||
num_samples = 100
|
||||
repetitions = num_samples // len(base_data) + 1
|
||||
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
|
||||
num_samples
|
||||
)
|
||||
|
||||
print(f"生成了 {len(sample_data)} 个样本")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 进行生存预测
|
||||
survival_functions = analyzer.predict_survival(sample_data)
|
||||
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"预测耗时: {elapsed_time:.2f} 秒")
|
||||
|
||||
# 打印预测结果示例
|
||||
print(f"预测了 {len(survival_functions)} 个生存函数")
|
||||
for i, sf in enumerate(survival_functions):
|
||||
print(
|
||||
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
|
||||
)
|
||||
|
||||
# 验证返回结果类型
|
||||
assert isinstance(survival_functions, list), "返回值应为列表"
|
||||
assert all(
|
||||
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
|
||||
), "每个生存函数应包含x和y属性"
|
||||
|
||||
# 可选:测试绘图功能(不显示图表)
|
||||
analyzer.plot_survival(survival_functions, show_plot=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import conftest
|
||||
|
||||
conftest.run_this_test(__file__) # 自定义运行,类似示例
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Any
|
||||
from typing import List, Optional, Any, Dict
|
||||
from datetime import datetime
|
||||
from psycopg import AsyncConnection
|
||||
import pandas as pd
|
||||
@@ -23,7 +23,7 @@ class CompositeQueries:
|
||||
device_ids: List[str],
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
) -> List[Optional[Any]]:
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 SCADA 关联的 link/node 模拟值
|
||||
|
||||
@@ -39,12 +39,12 @@ class CompositeQueries:
|
||||
field: 要查询的字段名
|
||||
|
||||
Returns:
|
||||
模拟数据值列表,如果没有找到则对应位置返回 None
|
||||
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||
"""
|
||||
results = []
|
||||
result = {}
|
||||
# 1. 查询所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
|
||||
@@ -75,8 +75,11 @@ class CompositeQueries:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||
results.append(res)
|
||||
return results
|
||||
# 添加 scada_id 到每个数据项
|
||||
for item in res:
|
||||
item["scada_id"] = device_id
|
||||
result[device_id] = res
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_scada_associated_scheme_simulation_data(
|
||||
@@ -87,7 +90,7 @@ class CompositeQueries:
|
||||
end_time: datetime,
|
||||
scheme_type: str,
|
||||
scheme_name: str,
|
||||
) -> List[Optional[Any]]:
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
获取 SCADA 关联的 link/node 模拟值
|
||||
|
||||
@@ -103,12 +106,12 @@ class CompositeQueries:
|
||||
field: 要查询的字段名
|
||||
|
||||
Returns:
|
||||
模拟数据值列表,如果没有找到则对应位置返回 None
|
||||
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
|
||||
|
||||
Raises:
|
||||
ValueError: 当 SCADA 设备未找到或字段无效时
|
||||
"""
|
||||
results = []
|
||||
result = {}
|
||||
# 1. 查询所有 SCADA 信息
|
||||
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
|
||||
|
||||
@@ -151,8 +154,11 @@ class CompositeQueries:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown SCADA type: {scada_type}")
|
||||
results.append(res)
|
||||
return results
|
||||
# 添加 scada_id 到每个数据项
|
||||
for item in res:
|
||||
item["scada_id"] = device_id
|
||||
result[device_id] = res
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_element_associated_scada_data(
|
||||
@@ -282,31 +288,31 @@ class CompositeQueries:
|
||||
]
|
||||
|
||||
# 处理pressure数据
|
||||
# if pressure_ids:
|
||||
# pressure_df = df[pressure_ids]
|
||||
# # 重置索引,将 time 变为普通列
|
||||
# pressure_df = pressure_df.reset_index()
|
||||
# # 移除 time 列,准备输入给清洗方法
|
||||
# value_df = pressure_df.drop(columns=["time"])
|
||||
# # 调用清洗方法
|
||||
# cleaned_value_df = clean_pressure_data_df_km(value_df)
|
||||
# # 添加 time 列到首列
|
||||
# cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
||||
# # 将清洗后的数据写回数据库
|
||||
# for device_id in pressure_ids:
|
||||
# if device_id in cleaned_df.columns:
|
||||
# cleaned_values = cleaned_df[device_id].tolist()
|
||||
# time_values = cleaned_df["time"].tolist()
|
||||
# for i, time_str in enumerate(time_values):
|
||||
# time_dt = datetime.fromisoformat(time_str)
|
||||
# value = cleaned_values[i]
|
||||
# await ScadaRepository.update_scada_field(
|
||||
# timescale_conn,
|
||||
# time_dt,
|
||||
# device_id,
|
||||
# "cleaned_value",
|
||||
# value,
|
||||
# )
|
||||
if pressure_ids:
|
||||
pressure_df = df[pressure_ids]
|
||||
# 重置索引,将 time 变为普通列
|
||||
pressure_df = pressure_df.reset_index()
|
||||
# 移除 time 列,准备输入给清洗方法
|
||||
value_df = pressure_df.drop(columns=["time"])
|
||||
# 调用清洗方法
|
||||
cleaned_value_df = clean_pressure_data_df_km(value_df)
|
||||
# 添加 time 列到首列
|
||||
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
|
||||
# 将清洗后的数据写回数据库
|
||||
for device_id in pressure_ids:
|
||||
if device_id in cleaned_df.columns:
|
||||
cleaned_values = cleaned_df[device_id].tolist()
|
||||
time_values = cleaned_df["time"].tolist()
|
||||
for i, time_str in enumerate(time_values):
|
||||
time_dt = datetime.fromisoformat(time_str)
|
||||
value = cleaned_values[i]
|
||||
await ScadaRepository.update_scada_field(
|
||||
timescale_conn,
|
||||
time_dt,
|
||||
device_id,
|
||||
"cleaned_value",
|
||||
value,
|
||||
)
|
||||
|
||||
# 处理flow数据
|
||||
if flow_ids:
|
||||
|
||||
Reference in New Issue
Block a user